diff --git a/.github/workflows/codeql.yml b/.github/workflows/codeql.yml deleted file mode 100644 index 7ac2f1649..000000000 --- a/.github/workflows/codeql.yml +++ /dev/null @@ -1,104 +0,0 @@ -name: CodeQL - -on: - push: - branches: - - main - pull_request: - branches: - - main - schedule: - - cron: '42 20 * * 4' - -jobs: - analyze-cuda: - name: Analyze (CUDA) - strategy: - fail-fast: false - matrix: - language: [ 'cpp' ] - concurrency: - group: ${{ github.workflow }}-cuda-${{ github.ref }} - cancel-in-progress: true - runs-on: ubuntu-latest - container: - image: ghcr.io/microsoft/ark/ark:base-dev-cuda12.2 - permissions: - actions: read - contents: read - security-events: write - - steps: - - name: Checkout repository - uses: actions/checkout@v4 - - - name: Check disk space - run: | - df -h - - # Initializes the CodeQL tools for scanning. - - name: Initialize CodeQL - uses: github/codeql-action/init@v3 - with: - languages: ${{ matrix.language }} - - - name: Dubious ownership exception - run: | - git config --global --add safe.directory /__w/ark/ark - - - name: Build - run: | - mkdir build && cd build - cmake -DCMAKE_BUILD_TYPE=Debug -DARK_BYPASS_GPU_CHECK=ON -DARK_USE_CUDA=ON -DARK_BUILD_TESTS=OFF .. - make build ark_py - - - name: Perform CodeQL Analysis - uses: github/codeql-action/analyze@v3 - with: - category: "/language:${{matrix.language}}" - - analyze-rocm: - name: Analyze (ROCM) - strategy: - fail-fast: false - matrix: - language: [ 'cpp' ] - concurrency: - group: ${{ github.workflow }}-rocm-${{ github.ref }} - cancel-in-progress: true - runs-on: ubuntu-latest - container: - image: ghcr.io/microsoft/ark/ark:build-rocm6.1 - permissions: - actions: read - contents: read - security-events: write - - steps: - - name: Checkout repository - uses: actions/checkout@v4 - - - name: Check disk space - run: | - df -h - - # Initializes the CodeQL tools for scanning. 
- - name: Initialize CodeQL - uses: github/codeql-action/init@v3 - with: - languages: ${{ matrix.language }} - - - name: Dubious ownership exception - run: | - git config --global --add safe.directory /__w/ark/ark - - - name: Build - run: | - mkdir build && cd build - CXX=/opt/rocm/bin/hipcc cmake -DCMAKE_BUILD_TYPE=Debug -DARK_BYPASS_GPU_CHECK=ON -DARK_USE_ROCM=ON -DARK_BUILD_TESTS=OFF .. - make -j build ark_py - - - name: Perform CodeQL Analysis - uses: github/codeql-action/analyze@v3 - with: - category: "/language:${{matrix.language}}" diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index c799e86c6..0fe0cf826 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -7,7 +7,7 @@ on: jobs: linters: - runs-on: ubuntu-20.04 + runs-on: ubuntu-latest steps: - name: Check out Git repository @@ -16,22 +16,19 @@ jobs: - name: Install ClangFormat run: sudo apt-get install -y clang-format - - name: Run git-clang-format - run: git clang-format --style=file --diff - - name: Set up Python uses: actions/setup-python@v4 with: - python-version: 3.8 + python-version: '3.12' - name: Install Python dependencies - run: python3.8 -m pip install black + run: pip install black - - name: Run black - run: python3.8 -m black --check --config pyproject.toml . 
+ - name: Run lint + run: bash tools/lint.sh dry spelling: - runs-on: ubuntu-20.04 + runs-on: ubuntu-latest steps: - name: Check out Git repository diff --git a/.github/workflows/ut-rocm.yml b/.github/workflows/ut-rocm.yml deleted file mode 100644 index ac8ed0e90..000000000 --- a/.github/workflows/ut-rocm.yml +++ /dev/null @@ -1,64 +0,0 @@ -name: "Unit Tests (ROCm)" - -on: - push: - branches: - - main - pull_request: - branches: - - main - -jobs: - UnitTest: - runs-on: [ self-hosted, AMD ] - defaults: - run: - shell: bash - strategy: - matrix: - rocm: [ rocm6.0 ] - concurrency: - group: ${{ github.workflow }}-${{ github.ref }}-${{ matrix.rocm }} - cancel-in-progress: true - # container: - # image: "ghcr.io/microsoft/ark/ark:base-dev-${{ matrix.rocm }}" - # options: --privileged --ipc=host --security-opt seccomp=unconfined --group-add video --ulimit memlock=-1:-1 - - steps: - - name: Checkout - uses: actions/checkout@v4 - - - name: Dubious ownership exception - run: | - git config --global --add safe.directory /__w/ark/ark - - - name: Build - run: | - mkdir build && cd build - cmake -DCMAKE_BUILD_TYPE=Debug .. - make -j ut - - - name: RunUT - run: | - cd build && ARK_ROOT=$PWD ARK_IGNORE_BINARY_CACHE=1 ctest --stop-on-failure --verbose --schedule-random - - - name: ReportCoverage - run: | - cd build - lcov --capture --directory . --output-file coverage.info - lcov --remove coverage.info \ - '/usr/*' \ - '/tmp/*' \ - '*/third_party/*' \ - '*/ark/*_test.*' \ - '*/examples/*' \ - '*/python/*' \ - '*/ark/unittest/unittest_utils.cc' \ - --output-file coverage.info - lcov --list coverage.info - bash <(curl -s https://codecov.io/bash) -f coverage.info || echo "Codecov did not collect coverage reports" - - - name: BuildPython - run: | - python3 -m pip install -r requirements.txt - python3 -m pip install . 
diff --git a/.github/workflows/ut-cuda.yml b/.github/workflows/ut.yml similarity index 57% rename from .github/workflows/ut-cuda.yml rename to .github/workflows/ut.yml index 10b0679da..0929c75e2 100644 --- a/.github/workflows/ut-cuda.yml +++ b/.github/workflows/ut.yml @@ -1,4 +1,4 @@ -name: "Unit Tests (CUDA)" +name: "Unit Tests" on: push: @@ -11,46 +11,66 @@ on: jobs: UnitTest: - runs-on: [ self-hosted, A100 ] defaults: run: shell: bash - timeout-minutes: 30 + timeout-minutes: 60 + permissions: + actions: read + contents: read + security-events: write strategy: + fail-fast: false matrix: - cuda: [ cuda11.8, cuda12.2 ] + include: + - platform: cuda + runner: [self-hosted, CUDA] + container: nvcr.io/nvidia/pytorch:26.03-py3 + container_options: --privileged --ipc=host --gpus=all --ulimit memlock=-1:-1 + - platform: rocm + runner: [self-hosted, ROCM] + container: rocm/pytorch:rocm6.2.3_ubuntu22.04_py3.10_pytorch_release_2.3.0 + container_options: --privileged --ipc=host --security-opt seccomp=unconfined --group-add video --ulimit memlock=-1:-1 + runs-on: ${{ matrix.runner }} concurrency: - group: ${{ github.workflow }}-${{ github.ref }}-${{ matrix.cuda }} + group: ${{ github.workflow }}-${{ matrix.platform }}-${{ github.ref }} cancel-in-progress: true container: - image: "ghcr.io/microsoft/ark/ark:base-dev-${{ matrix.cuda }}" - options: --privileged --ipc=host --gpus=all --ulimit memlock=-1:-1 + image: ${{ matrix.container }} + options: ${{ matrix.container_options }} steps: - name: Checkout uses: actions/checkout@v4 - - name: LockGPUClock - run: | - sudo nvidia-smi -pm 1 - for i in $(seq 0 $(( $(nvidia-smi -L | wc -l) - 1 ))); do - sudo nvidia-smi -ac $(nvidia-smi --query-gpu=clocks.max.memory,clocks.max.sm --format=csv,noheader,nounits -i $i | sed 's/\ //') -i $i - done - - name: Dubious ownership exception run: | git config --global --add safe.directory /__w/ark/ark + - name: Initialize CodeQL + uses: github/codeql-action/init@v3 + with: + languages: cpp + - name: 
Build run: | + apt-get update && apt-get install -y lcov mkdir build && cd build - cmake -DCMAKE_BUILD_TYPE=Debug .. + CMAKE_ARGS="-DCMAKE_BUILD_TYPE=Debug" + if [ "${{ matrix.platform }}" = "rocm" ]; then + CMAKE_ARGS="$CMAKE_ARGS -DCMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc" + fi + cmake $CMAKE_ARGS .. make -j ut ark_py - name: Run C++ UT run: | cd build ARK_ROOT=$PWD ctest --stop-on-failure --verbose --schedule-random + + - name: C++ Coverage + run: | + cd build lcov --capture --directory . --output-file cpp_coverage.info lcov --remove cpp_coverage.info \ '/usr/*' \ @@ -75,7 +95,7 @@ jobs: --cov=python/ark \ --cov-report lcov:py_coverage.info \ --verbose \ - ../python/unittest/test.py + ../python/unittest/ - name: Report Coverage env: @@ -92,3 +112,8 @@ jobs: - name: Run Tutorials run: | python3 ./examples/tutorial/quickstart_tutorial.py + + - name: Perform CodeQL Analysis + uses: github/codeql-action/analyze@v3 + with: + category: "/language:cpp-${{ matrix.platform }}" diff --git a/CMakeLists.txt b/CMakeLists.txt index 8d5de19d1..437746888 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -65,13 +65,24 @@ if(ARK_USE_CUDA) endif() # Set CUDA architectures - if(CUDAToolkit_VERSION_MAJOR GREATER_EQUAL 11) + if(CUDAToolkit_VERSION_MAJOR GREATER_EQUAL 13) + # CUDA 13+ dropped sm_60 and sm_70 + set(CMAKE_CUDA_ARCHITECTURES 80 90) + elseif(CUDAToolkit_VERSION_MAJOR GREATER_EQUAL 12) + set(CMAKE_CUDA_ARCHITECTURES 60 70 80 90) + elseif(CUDAToolkit_VERSION_MAJOR GREATER_EQUAL 11) set(CMAKE_CUDA_ARCHITECTURES 60 70 80) endif() - # Hopper architecture - if(CUDAToolkit_VERSION_MAJOR GREATER_EQUAL 12) - set(CMAKE_CUDA_ARCHITECTURES ${CMAKE_CUDA_ARCHITECTURES} 90) + # CUDA 13+ moved CCCL headers into a cccl/ subdirectory. + # Add it to the include path so third-party code (e.g. MSCCL++) + # that includes can still find the headers. 
+ if(CUDAToolkit_VERSION_MAJOR GREATER_EQUAL 13) + set(CCCL_INCLUDE_DIR "${CUDAToolkit_INCLUDE_DIRS}/cccl") + if(EXISTS "${CCCL_INCLUDE_DIR}") + include_directories(SYSTEM "${CCCL_INCLUDE_DIR}") + message(STATUS "CUDA 13+: added CCCL include dir ${CCCL_INCLUDE_DIR}") + endif() endif() else() # ARK_USE_ROCM set(CMAKE_HIP_STANDARD 17) diff --git a/ark/api/context.cpp b/ark/api/context.cpp index 702247ddf..087e0e7c9 100644 --- a/ark/api/context.cpp +++ b/ark/api/context.cpp @@ -29,8 +29,6 @@ void Context::set(const std::string& key, const std::string& value, this->impl_->set(key, value_json, type); } -std::string Context::dump() const { - return this->impl_->dump().dump(); -} +std::string Context::dump() const { return this->impl_->dump().dump(); } } // namespace ark diff --git a/ark/api/planner.cpp b/ark/api/planner.cpp index d7e96e957..c48e19c50 100644 --- a/ark/api/planner.cpp +++ b/ark/api/planner.cpp @@ -211,8 +211,8 @@ std::string Planner::Impl::plan(bool pretty) const { Dims tile(trim_leading_ones); std::stringstream ss; - ss << "Result shape is not divided by tile " - << tile << ". Op: " << op->serialize().dump(); + ss << "Result shape is not divided by tile " << tile + << ". Op: " << op->serialize().dump(); auto not_divided_error = ss.str(); auto &result_shape = result_tensors[0]->padded_shape(); @@ -224,11 +224,10 @@ std::string Planner::Impl::plan(bool pretty) const { max_num_tasks = 1; for (int i = 0; i < tile4.ndims(); i++) { if (tile4[i] == 0) { - ERR(PlanError, "Tile dimension is zero. Op: ", - op->serialize().dump()); + ERR(PlanError, + "Tile dimension is zero. 
Op: ", op->serialize().dump()); } - max_num_tasks *= - (result_shape4[i] + tile4[i] - 1) / tile4[i]; + max_num_tasks *= (result_shape4[i] + tile4[i] - 1) / tile4[i]; } if (max_num_tasks == 0) ERR(InternalError, "max_num_tasks == 0"); } @@ -328,10 +327,13 @@ std::string Planner::Impl::plan(bool pretty) const { max_processor_id = std::max(max_processor_id, num_processors); } else if (processor_group_root == -1) { processor_group_root = ctx_processor_range_list.front()[0]; - processor_group["ProcessorRange"] = ctx_processor_range_list.front()[1]; - resource_group["ProcessorRange"] = ctx_processor_range_list.back()[1]; + processor_group["ProcessorRange"] = + ctx_processor_range_list.front()[1]; + resource_group["ProcessorRange"] = + ctx_processor_range_list.back()[1]; max_processor_id = std::max( - max_processor_id, ctx_processor_range_list.front()[1][1].get()); + max_processor_id, + ctx_processor_range_list.front()[1][1].get()); } else { new_processor_group = false; resource_group["ProcessorRange"] = diff --git a/ark/api/planner_test.cpp b/ark/api/planner_test.cpp index 7507ea023..e557ee307 100644 --- a/ark/api/planner_test.cpp +++ b/ark/api/planner_test.cpp @@ -87,8 +87,9 @@ ark::unittest::State test_planner_context_processor_range() { auto t = model.add(t0, t1); tensors.push_back(t); - UNITTEST_EQ(ctx.get("ProcessorRange"), - ark::Json({subctx.id(), {0 * (int)i, 2 * (int)i}}).dump()); + UNITTEST_EQ( + ctx.get("ProcessorRange"), + ark::Json({subctx.id(), {0 * (int)i, 2 * (int)i}}).dump()); } UNITTEST_TRUE(model.verify()); @@ -131,15 +132,13 @@ ark::unittest::State test_planner_context_warp_range() { ctx.warp_range(0, 4); t3 = model.relu(t2); - UNITTEST_EQ(ctx.get("WarpRange"), - ark::Json({ctx.id(), {0, 4}}).dump()); + UNITTEST_EQ(ctx.get("WarpRange"), ark::Json({ctx.id(), {0, 4}}).dump()); // node 2 ctx.warp_range(2, 4); t4 = model.sqrt(t3); - UNITTEST_EQ(ctx.get("WarpRange"), - ark::Json({ctx.id(), {2, 4}}).dump()); + UNITTEST_EQ(ctx.get("WarpRange"), 
ark::Json({ctx.id(), {2, 4}}).dump()); // Invalid usage: range (0, 4) is out of previous range (2, 4) UNITTEST_THROW(ctx.warp_range(0, 4), ark::PlanError); @@ -197,15 +196,13 @@ ark::unittest::State test_planner_context_sram_range() { ctx.sram_range(0, 4); t3 = model.relu(t2); - UNITTEST_EQ(ctx.get("SramRange"), - ark::Json({ctx.id(), {0, 4}}).dump()); + UNITTEST_EQ(ctx.get("SramRange"), ark::Json({ctx.id(), {0, 4}}).dump()); // node 2 ctx.sram_range(2, 4); t4 = model.sqrt(t3); - UNITTEST_EQ(ctx.get("SramRange"), - ark::Json({ctx.id(), {2, 4}}).dump()); + UNITTEST_EQ(ctx.get("SramRange"), ark::Json({ctx.id(), {2, 4}}).dump()); // Invalid usage: range (0, 4) is out of previous range (2, 4) UNITTEST_THROW(ctx.sram_range(0, 4), ark::PlanError); @@ -263,15 +260,13 @@ ark::unittest::State test_planner_context_sync() { ctx.sync(false); t3 = model.relu(t2); - UNITTEST_EQ(ctx.get("Sync"), - ark::Json({ctx.id(), false}).dump()); + UNITTEST_EQ(ctx.get("Sync"), ark::Json({ctx.id(), false}).dump()); // node 2 ctx.sync(true); t4 = model.sqrt(t3); - UNITTEST_EQ(ctx.get("Sync"), - ark::Json({ctx.id(), true}).dump()); + UNITTEST_EQ(ctx.get("Sync"), ark::Json({ctx.id(), true}).dump()); } { // node 3 @@ -280,8 +275,7 @@ ark::unittest::State test_planner_context_sync() { ctx.sync(true); t5 = model.exp(t2); - UNITTEST_EQ(ctx.get("Sync"), - ark::Json({ctx.id(), true}).dump()); + UNITTEST_EQ(ctx.get("Sync"), ark::Json({ctx.id(), true}).dump()); } UNITTEST_TRUE(model.verify()); @@ -297,8 +291,9 @@ ark::unittest::State test_planner_context_sync() { UNITTEST_EQ(nodes[1]->context.at("Sync"), ark::Json({{sync_id_1, true}, {sync_id_1, false}})); UNITTEST_GE(nodes[2]->context.size(), 1); - UNITTEST_EQ(nodes[2]->context.at("Sync"), - ark::Json({{sync_id_1, true}, {sync_id_1, false}, {sync_id_1, true}})); + UNITTEST_EQ( + nodes[2]->context.at("Sync"), + ark::Json({{sync_id_1, true}, {sync_id_1, false}, {sync_id_1, true}})); UNITTEST_GE(nodes[3]->context.size(), 1); 
UNITTEST_EQ(nodes[3]->context.at("Sync"), ark::Json({{sync_id_2, true}, {sync_id_2, true}})); @@ -361,7 +356,8 @@ ark::unittest::State test_planner_context_config() { ark::Json({{cfg_id_1, {{"key0", "val1"}}}})); UNITTEST_GE(nodes[2]->context.size(), 1); UNITTEST_EQ(nodes[2]->context.at("Config"), - ark::Json({{cfg_id_1, {{"key0", "val1"}}}, {cfg_id_1, {{"key1", "val2"}}}})); + ark::Json({{cfg_id_1, {{"key0", "val1"}}}, + {cfg_id_1, {{"key1", "val2"}}}})); UNITTEST_GE(nodes[3]->context.size(), 1); UNITTEST_EQ(nodes[3]->context.at("Config"), ark::Json({{cfg_id_2, {{"key2", "val3"}}}})); diff --git a/ark/context_impl.cpp b/ark/context_impl.cpp index c4f95f2c3..0eca1bf0e 100644 --- a/ark/context_impl.cpp +++ b/ark/context_impl.cpp @@ -52,8 +52,6 @@ bool Context::Impl::has(const std::string& key) const { return context_manager_->has(key); } -Json Context::Impl::dump() const { - return context_manager_->dump(); -} +Json Context::Impl::dump() const { return context_manager_->dump(); } } // namespace ark diff --git a/ark/context_impl.hpp b/ark/context_impl.hpp index b79353296..cf1509167 100644 --- a/ark/context_impl.hpp +++ b/ark/context_impl.hpp @@ -17,7 +17,8 @@ class Context::Impl { Json get(const std::string& key) const; - void set(const std::string& key, const Json& value_json, ContextType type = ContextType::Overwrite); + void set(const std::string& key, const Json& value_json, + ContextType type = ContextType::Overwrite); bool has(const std::string& key) const; diff --git a/ark/gpu/gpu.hpp b/ark/gpu/gpu.hpp index dbcd50f3e..fe1bf07bb 100644 --- a/ark/gpu/gpu.hpp +++ b/ark/gpu/gpu.hpp @@ -21,7 +21,7 @@ constexpr auto alias = cuda_const; #define ARK_GPU_DEFINE_FUNC_ALIAS(alias, cuda_func, rocm_func) \ template \ - inline auto alias(Args &&... 
args) { \ + inline auto alias(Args &&...args) { \ return cuda_func(std::forward(args)...); \ } @@ -35,7 +35,7 @@ constexpr auto alias = rocm_const; #define ARK_GPU_DEFINE_FUNC_ALIAS(alias, cuda_func, rocm_func) \ template \ - inline auto alias(Args &&... args) { \ + inline auto alias(Args &&...args) { \ return rocm_func(std::forward(args)...); \ } @@ -148,6 +148,8 @@ ARK_GPU_DEFINE_FUNC_ALIAS(gpuMemcpy, cudaMemcpy, hipMemcpy); ARK_GPU_DEFINE_FUNC_ALIAS(gpuMemcpyAsync, cudaMemcpyAsync, hipMemcpyAsync); ARK_GPU_DEFINE_FUNC_ALIAS(gpuMemsetAsync, cudaMemsetAsync, hipMemsetAsync); ARK_GPU_DEFINE_FUNC_ALIAS(gpuSetDevice, cudaSetDevice, hipSetDevice); +ARK_GPU_DEFINE_FUNC_ALIAS(gpuGetDeviceCount, cudaGetDeviceCount, + hipGetDeviceCount); ARK_GPU_DEFINE_FUNC_ALIAS(gpuStreamCreateWithFlags, cudaStreamCreateWithFlags, hipStreamCreateWithFlags); ARK_GPU_DEFINE_FUNC_ALIAS(gpuStreamDestroy, cudaStreamDestroy, diff --git a/ark/include/ark.hpp b/ark/include/ark.hpp index b1955bf9c..90f23b2f1 100644 --- a/ark/include/ark.hpp +++ b/ark/include/ark.hpp @@ -14,9 +14,9 @@ #include #include #include +#include #include #include -#include #include #include #include diff --git a/ark/include/ark/executor.hpp b/ark/include/ark/executor.hpp index 2e97ffe78..765cd0f27 100644 --- a/ark/include/ark/executor.hpp +++ b/ark/include/ark/executor.hpp @@ -52,9 +52,8 @@ class Executor { bool record = false); /// Run the executor for `iter` iterations. - void run( - int iter, - const std::unordered_map &placeholder_data = {}); + void run(int iter, + const std::unordered_map &placeholder_data = {}); /// Wait for the previous run to finish. 
void wait(int64_t max_spin_count = -1); diff --git a/ark/include/ark/planner.hpp b/ark/include/ark/planner.hpp index 9547848b9..b34acbc39 100644 --- a/ark/include/ark/planner.hpp +++ b/ark/include/ark/planner.hpp @@ -38,8 +38,8 @@ class Planner { ~Planner(); - using ConfigRule = std::function; + using ConfigRule = std::function; void install_config_rule(ConfigRule rule); diff --git a/ark/include/ark/tensor.hpp b/ark/include/ark/tensor.hpp index aa8dcaa68..67eda64ae 100644 --- a/ark/include/ark/tensor.hpp +++ b/ark/include/ark/tensor.hpp @@ -69,9 +69,7 @@ std::ostream &operator<<(std::ostream &os, const Tensor &tensor); namespace std { template <> struct hash { - size_t operator()(const ark::Tensor &t) const noexcept { - return t.id(); - } + size_t operator()(const ark::Tensor &t) const noexcept { return t.id(); } }; } // namespace std diff --git a/ark/include/kernels/comm.h b/ark/include/kernels/comm.h index 9075bb728..4a2deca80 100644 --- a/ark/include/kernels/comm.h +++ b/ark/include/kernels/comm.h @@ -414,8 +414,7 @@ DEVICE void read_reduce_and_write( DataType, NelemPerThread, Rank, NPeers, nelems_per_rank>>::run(dst, src, scratch, peer_offsets, uop_idx); - } - else { + } else { PacketType *scratch = reinterpret_cast(scratch_base); comm::PacketReduce< OutDims, OutShape, UnitOutDims, NumWarps, SmemBytes, PacketType, diff --git a/ark/include/kernels/common/arch.h b/ark/include/kernels/common/arch.h index e268ad78c..7eff95c7b 100644 --- a/ark/include/kernels/common/arch.h +++ b/ark/include/kernels/common/arch.h @@ -32,13 +32,13 @@ DEVICE int warp_id() { #if defined(ARK_TARGET_CUDA_ARCH) #define ARCH_ALIAS_FUNC(alias, cuda_func, hip_func) \ template \ - inline auto alias(Args &&... args) { \ + inline auto alias(Args &&...args) { \ return cuda_func(std::forward(args)...); \ } #elif defined(ARK_TARGET_ROCM_ARCH) #define ARCH_ALIAS_FUNC(alias, cuda_func, hip_func) \ template \ - inline auto alias(Args &&... 
args) { \ + inline auto alias(Args &&...args) { \ return hip_func(std::forward(args)...); \ } #endif diff --git a/ark/include/kernels/common/broadcast.h b/ark/include/kernels/common/broadcast.h index 86e84e5d0..d64a31fd5 100644 --- a/ark/include/kernels/common/broadcast.h +++ b/ark/include/kernels/common/broadcast.h @@ -41,22 +41,17 @@ struct Broadcast1Intrinsic { static constexpr int InConsecBytes = InConsecLen * sizeof(InputType); static constexpr int OutNelemPerThread = - (OutConsecBytes % 16 == 0) - ? 16 / sizeof(OutputType) - : (OutConsecBytes % 8 == 0) - ? 8 / sizeof(OutputType) - : (OutConsecBytes % 4 == 0) - ? 4 / sizeof(OutputType) - : (OutConsecBytes % 2 == 0) ? 2 / sizeof(OutputType) - : 1; + (OutConsecBytes % 16 == 0) ? 16 / sizeof(OutputType) + : (OutConsecBytes % 8 == 0) ? 8 / sizeof(OutputType) + : (OutConsecBytes % 4 == 0) ? 4 / sizeof(OutputType) + : (OutConsecBytes % 2 == 0) ? 2 / sizeof(OutputType) + : 1; static constexpr int InNelemPerThread = - (InConsecBytes % 16 == 0) - ? 16 / sizeof(InputType) - : (InConsecBytes % 8 == 0) - ? 8 / sizeof(InputType) - : (InConsecBytes % 4 == 0) - ? 4 / sizeof(InputType) - : (InConsecBytes % 2 == 0) ? 2 / sizeof(InputType) : 1; + (InConsecBytes % 16 == 0) ? 16 / sizeof(InputType) + : (InConsecBytes % 8 == 0) ? 8 / sizeof(InputType) + : (InConsecBytes % 4 == 0) ? 4 / sizeof(InputType) + : (InConsecBytes % 2 == 0) ? 2 / sizeof(InputType) + : 1; static constexpr int NelemPerThread = BroadcastInput ? OutNelemPerThread @@ -155,43 +150,35 @@ struct Broadcast2Intrinsic { static constexpr int In1ConsecBytes = In1ConsecLen * sizeof(InputType); static constexpr int OutNelemPerThread = - (OutConsecBytes % 16 == 0) - ? 16 / sizeof(OutputType) - : (OutConsecBytes % 8 == 0) - ? 8 / sizeof(OutputType) - : (OutConsecBytes % 4 == 0) - ? 4 / sizeof(OutputType) - : (OutConsecBytes % 2 == 0) ? 2 / sizeof(OutputType) - : 1; + (OutConsecBytes % 16 == 0) ? 16 / sizeof(OutputType) + : (OutConsecBytes % 8 == 0) ? 
8 / sizeof(OutputType) + : (OutConsecBytes % 4 == 0) ? 4 / sizeof(OutputType) + : (OutConsecBytes % 2 == 0) ? 2 / sizeof(OutputType) + : 1; static constexpr int In0NelemPerThread = - (In0ConsecBytes % 16 == 0) - ? 16 / sizeof(InputType) - : (In0ConsecBytes % 8 == 0) - ? 8 / sizeof(InputType) - : (In0ConsecBytes % 4 == 0) - ? 4 / sizeof(InputType) - : (In0ConsecBytes % 2 == 0) ? 2 / sizeof(InputType) : 1; + (In0ConsecBytes % 16 == 0) ? 16 / sizeof(InputType) + : (In0ConsecBytes % 8 == 0) ? 8 / sizeof(InputType) + : (In0ConsecBytes % 4 == 0) ? 4 / sizeof(InputType) + : (In0ConsecBytes % 2 == 0) ? 2 / sizeof(InputType) + : 1; static constexpr int In1NelemPerThread = - (In1ConsecBytes % 16 == 0) - ? 16 / sizeof(InputType) - : (In1ConsecBytes % 8 == 0) - ? 8 / sizeof(InputType) - : (In1ConsecBytes % 4 == 0) - ? 4 / sizeof(InputType) - : (In1ConsecBytes % 2 == 0) ? 2 / sizeof(InputType) : 1; + (In1ConsecBytes % 16 == 0) ? 16 / sizeof(InputType) + : (In1ConsecBytes % 8 == 0) ? 8 / sizeof(InputType) + : (In1ConsecBytes % 4 == 0) ? 4 / sizeof(InputType) + : (In1ConsecBytes % 2 == 0) ? 2 / sizeof(InputType) + : 1; static constexpr int NelemPerThread = - (BroadcastInput0 && BroadcastInput1) - ? OutNelemPerThread - : BroadcastInput0 - ? math::gcd::value - : BroadcastInput1 - ? math::gcd::value - : math::gcd::value>::value; + (BroadcastInput0 && BroadcastInput1) ? OutNelemPerThread + : BroadcastInput0 + ? math::gcd::value + : BroadcastInput1 + ? 
math::gcd::value + : math::gcd::value>::value; static_assert(math::is_pow2::value, "NelemPerThread must be power of 2"); diff --git a/ark/include/kernels/common/vector_type.h b/ark/include/kernels/common/vector_type.h index 1e5316e20..f247c53ee 100644 --- a/ark/include/kernels/common/vector_type.h +++ b/ark/include/kernels/common/vector_type.h @@ -71,28 +71,29 @@ struct Constant { template struct IntrinsicCompute1Exists { template - static auto test(const InputVtype &) - -> decltype(&U::compute, std::true_type{}); + static auto test(const InputVtype &) -> decltype(&U::compute, + std::true_type{}); template static auto test(...) -> std::false_type; - static constexpr bool value = decltype( - test(type::Constant::zero()))::value; + static constexpr bool value = decltype(test( + type::Constant::zero()))::value; }; template struct IntrinsicCompute2Exists { template - static auto test(const InputVtype &, const InputVtype &) - -> decltype(&U::compute, std::true_type{}); + static auto test(const InputVtype &, + const InputVtype &) -> decltype(&U::compute, + std::true_type{}); template static auto test(...) -> std::false_type; - static constexpr bool value = decltype( - test(type::Constant::zero(), - type::Constant::zero()))::value; + static constexpr bool value = decltype(test( + type::Constant::zero(), + type::Constant::zero()))::value; }; template @@ -198,11 +199,10 @@ struct DefaultNelemPerThread { : math::min::value; static const int value = - (sizeof(OutDataType) <= 2 && ConsecutiveDimLen % 8 == 0) - ? 8 - : (ConsecutiveDimLen % 4 == 0) - ? 4 - : (ConsecutiveDimLen % 2 == 0) ? 2 : 1; + (sizeof(OutDataType) <= 2 && ConsecutiveDimLen % 8 == 0) ? 8 + : (ConsecutiveDimLen % 4 == 0) ? 4 + : (ConsecutiveDimLen % 2 == 0) ? 
2 + : 1; }; } // namespace ark diff --git a/ark/include/kernels/gemm_ck.h b/ark/include/kernels/gemm_ck.h index 478419691..a15cf49e0 100644 --- a/ark/include/kernels/gemm_ck.h +++ b/ark/include/kernels/gemm_ck.h @@ -90,13 +90,15 @@ struct CkGemmConfig::value; static constexpr auto MXdlPerWave = (TileSizeM == 16) ? 1 - : (TileSizeM < TileSizeN) - ? 1 << (LogMNXdlPerWave / 2) - : 1 << (LogMNXdlPerWave - LogMNXdlPerWave / 2); + : (TileSizeM < TileSizeN) + ? 1 << (LogMNXdlPerWave / 2) + : 1 << (LogMNXdlPerWave - LogMNXdlPerWave / 2); static constexpr auto NXdlPerWave = MNXdlPerWave / MXdlPerWave; static constexpr bool Is_256x256x128 = @@ -197,13 +199,15 @@ struct CkGemmConfig, typename std::conditional, S<1, 0, 2>>::type, typename std::conditional, S<1, 0, 2>>::type, - (IsColA ? 1 : 2), (!IsColA ? 8 : Is_128x128x64 ? 4 : MXdlPerWave), 8, - true, S<4, NumThreads / 4, 1>, + (IsColA ? 1 : 2), + (!IsColA ? 8 + : Is_128x128x64 ? 4 + : MXdlPerWave), + 8, true, S<4, NumThreads / 4, 1>, typename std::conditional, S<0, 2, 1>>::type, typename std::conditional, S<0, 2, 1>>::type, (IsColB ? 2 : 1), - (IsColB ? 8 - : Is_128x32x256 - ? 8 - : (Is_128x32x128 || Is_128x64x128 || Is_128x128x128) - ? 4 - : (Is_128x32x64 || Is_64x32x32) ? 2 : NXdlPerWave), + (IsColB ? 8 + : Is_128x32x256 ? 8 + : (Is_128x32x128 || Is_128x64x128 || Is_128x128x128) ? 4 + : (Is_128x32x64 || Is_64x32x32) ? 2 + : NXdlPerWave), 8, true, 7, 1, 1, LoopSched, PipelineVer>; using ImplXdlCShuffle = @@ -234,16 +240,17 @@ struct CkGemmConfig, S<1, 0, 2>>::type, typename std::conditional, S<1, 0, 2>>::type, (IsColA ? 1 : 2), - (!IsColA ? 8 : (AK1 == 2 || Is_128x128x64) ? 4 : MXdlPerWave), AK1, - (AK1 == 8), S, + (!IsColA ? 8 + : (AK1 == 2 || Is_128x128x64) ? 4 + : MXdlPerWave), + AK1, (AK1 == 8), S, typename std::conditional, S<0, 2, 1>>::type, typename std::conditional, S<0, 2, 1>>::type, (IsColB ? 2 : 1), (IsColB ? 8 - : (BK1 == 2 || Is_256x128x256 || Is_128x128x128 || - Is_128x64x128) - ? 
4 - : NXdlPerWave), + : (BK1 == 2 || Is_256x128x256 || Is_128x128x128 || Is_128x64x128) + ? 4 + : NXdlPerWave), BK1, (BK1 == 8), 1, 1, S<1, (Is_128x128x128 || Is_128x64x128 || Is_128x32x128 || @@ -255,16 +262,17 @@ struct CkGemmConfig; #if (DEBUG_CK != 0) - PrintDeviceGemmXdlCShuffle< - NumThreads, TileSizeM, TileSizeN, 32, AK1, BK1, 32, 32, MXdlPerWave, - NXdlPerWave, - (!IsColA ? 8 : (AK1 == 2 || Is_128x128x64) ? 4 : MXdlPerWave), - (IsColB - ? 8 - : (BK1 == 2 || Is_256x128x256 || Is_128x128x128 || Is_128x64x128) - ? 4 - : NXdlPerWave), - 1, 1> + PrintDeviceGemmXdlCShuffle p; #endif // (DEBUG_CK != 0) }; @@ -286,9 +294,9 @@ struct CkGemmConfig::value; static constexpr auto MXdlPerWave = (TileSizeM == 16) ? 1 - : (TileSizeM < TileSizeN) - ? 1 << (LogMNXdlPerWave / 2) - : 1 << (LogMNXdlPerWave - LogMNXdlPerWave / 2); + : (TileSizeM < TileSizeN) + ? 1 << (LogMNXdlPerWave / 2) + : 1 << (LogMNXdlPerWave - LogMNXdlPerWave / 2); static constexpr auto NXdlPerWave = MNXdlPerWave / MXdlPerWave; static constexpr bool Is_256x256x128 = @@ -307,7 +315,8 @@ struct CkGemmConfig, S<1, 0, 2>>::type, typename std::conditional, S<1, 0, 2>>::type, (IsColA ? 1 : 2), - (!IsColA ? 8 : (AK1 == 2 || Is_128x128x64) ? 4 : MXdlPerWave), AK1, - (AK1 == 8), S, + (!IsColA ? 8 + : (AK1 == 2 || Is_128x128x64) ? 4 + : MXdlPerWave), + AK1, (AK1 == 8), S, typename std::conditional, S<0, 2, 1>>::type, typename std::conditional, S<0, 2, 1>>::type, (IsColB ? 2 : 1), (IsColB ? 8 - : (BK1 == 2 || Is_256x128x256 || Is_128x128x128 || - Is_128x64x128) - ? 4 - : NXdlPerWave), + : (BK1 == 2 || Is_256x128x256 || Is_128x128x128 || Is_128x64x128) + ? 
4 + : NXdlPerWave), BK1, (BK1 == 8), 1, 1, S<1, (Is_128x128x128 || Is_128x64x128 || Is_128x32x128 || diff --git a/ark/include/kernels/reduce.h b/ark/include/kernels/reduce.h index 62af5840b..587674b91 100644 --- a/ark/include/kernels/reduce.h +++ b/ark/include/kernels/reduce.h @@ -357,13 +357,11 @@ struct WwiseReduce { ReduceShapeChecker; constexpr int InConsecBytes = sizeof(DataType) * InShape::W; constexpr int NelemPerThread = - (InConsecBytes % 16 == 0) - ? 16 / sizeof(DataType) - : (InConsecBytes % 8 == 0) - ? 8 / sizeof(DataType) - : (InConsecBytes % 4 == 0) - ? 4 / sizeof(DataType) - : (InConsecBytes % 2 == 0) ? 2 / sizeof(DataType) : 1; + (InConsecBytes % 16 == 0) ? 16 / sizeof(DataType) + : (InConsecBytes % 8 == 0) ? 8 / sizeof(DataType) + : (InConsecBytes % 4 == 0) ? 4 / sizeof(DataType) + : (InConsecBytes % 2 == 0) ? 2 / sizeof(DataType) + : 1; constexpr int NonReduceDimLength = UnitOutDims::N * UnitOutDims::C * UnitOutDims::H; @@ -411,32 +409,55 @@ struct WwiseReduce { if constexpr (NelemPerThread > 8) { #pragma unroll for (int i = 8; i < NelemPerThread; i += 8) { - ReduceType::template reduce<8>(&reduced[0], &reduced[0], &reduced[i]); + ReduceType::template reduce<8>(&reduced[0], &reduced[0], + &reduced[i]); } - ReduceType::template reduce<4>(&reduced[0], &reduced[0], &reduced[4]); - ReduceType::template reduce<2>(&reduced[0], &reduced[0], &reduced[2]); - ReduceType::template reduce<1>(&reduced[0], &reduced[0], &reduced[1]); + ReduceType::template reduce<4>(&reduced[0], &reduced[0], + &reduced[4]); + ReduceType::template reduce<2>(&reduced[0], &reduced[0], + &reduced[2]); + ReduceType::template reduce<1>(&reduced[0], &reduced[0], + &reduced[1]); } else if constexpr (NelemPerThread == 8) { - ReduceType::template reduce<4>(&reduced[0], &reduced[0], &reduced[4]); - ReduceType::template reduce<2>(&reduced[0], &reduced[0], &reduced[2]); - ReduceType::template reduce<1>(&reduced[0], &reduced[0], &reduced[1]); + ReduceType::template reduce<4>(&reduced[0], 
&reduced[0], + &reduced[4]); + ReduceType::template reduce<2>(&reduced[0], &reduced[0], + &reduced[2]); + ReduceType::template reduce<1>(&reduced[0], &reduced[0], + &reduced[1]); } else if constexpr (NelemPerThread == 4) { - ReduceType::template reduce<2>(&reduced[0], &reduced[0], &reduced[2]); - ReduceType::template reduce<1>(&reduced[0], &reduced[0], &reduced[1]); + ReduceType::template reduce<2>(&reduced[0], &reduced[0], + &reduced[2]); + ReduceType::template reduce<1>(&reduced[0], &reduced[0], + &reduced[1]); } else if constexpr (NelemPerThread == 2) { - ReduceType::template reduce<1>(&reduced[0], &reduced[0], &reduced[1]); + ReduceType::template reduce<1>(&reduced[0], &reduced[0], + &reduced[1]); } if constexpr (InShape::W % ThreadsPerRow != 0) { UnitOp::sync_threads(); } - // final reduction on shared memory using warp shuffle. - reduced[0] = warpsReduce( - reduced[0], tid, smem_per_warp); + // final reduction using warp shuffle. + // PhysicalThreadsPerRow = actual number of HW threads per row. + constexpr int PhysicalThreadsPerRow = + UnitOp::NumThreads / NonReduceDimLength; + static_assert(PhysicalThreadsPerRow > 0, + "Not enough threads for the tile dimensions. " + "Increase NumWarps or decrease Tile H dimension."); + if constexpr (PhysicalThreadsPerRow <= Arch::ThreadsPerWarp) { + // All threads for one row are within a single warp. + reduced[0] = + warpReduce(reduced[0]); + } else { + // Threads for one row span multiple warps — need shared memory. + reduced[0] = warpsReduce( + reduced[0], tid % PhysicalThreadsPerRow, smem_per_warp); + } - // write the result to output. - if (tid % ThreadsPerRow == 0) { + // write the result to output — first thread of each row group. 
+ if (tid % PhysicalThreadsPerRow == 0) { ReduceType::template postReduce<1>(&out[idx_out], &reduced[0], InShape::W); } diff --git a/ark/model/model_buffer.cpp b/ark/model/model_buffer.cpp index a54b6e81f..3778190d1 100644 --- a/ark/model/model_buffer.cpp +++ b/ark/model/model_buffer.cpp @@ -80,8 +80,7 @@ std::shared_ptr ModelBuffer::deserialize(const Json &serialized) { } else if (!serialized.contains("SendTags")) { ERR(ModelError, "ModelBuffer deserialization failed: missing SendTags"); } else if (!serialized.contains("RecvTags")) { - ERR(ModelError, - "ModelBuffer deserialization failed: missing RecvTags"); + ERR(ModelError, "ModelBuffer deserialization failed: missing RecvTags"); } else if (!serialized.contains("IsExternal")) { ERR(ModelError, "ModelBuffer deserialization failed: missing IsExternal"); diff --git a/ark/model/model_context_manager.cpp b/ark/model/model_context_manager.cpp index 799cce785..e3be664f9 100644 --- a/ark/model/model_context_manager.cpp +++ b/ark/model/model_context_manager.cpp @@ -27,8 +27,6 @@ Json ModelContextManager::get(const std::string& key) const { return context_stack_->get(key); } -Json ModelContextManager::dump() const { - return context_stack_->dump(); -} +Json ModelContextManager::dump() const { return context_stack_->dump(); } } // namespace ark diff --git a/ark/model/model_graph_impl.hpp b/ark/model/model_graph_impl.hpp index b9646d057..18c33f28a 100644 --- a/ark/model/model_graph_impl.hpp +++ b/ark/model/model_graph_impl.hpp @@ -54,7 +54,7 @@ class ModelGraph::Impl { Impl &operator=(const Impl &other); template - ModelOpRef create_op(const std::string &name, Args &&... 
args) { + ModelOpRef create_op(const std::string &name, Args &&...args) { ModelOpRef op = std::make_shared(std::forward(args)...); std::string name_copy; if (name.empty()) { diff --git a/ark/model/model_op.hpp b/ark/model/model_op.hpp index ab261eb20..6c5bbbbfd 100644 --- a/ark/model/model_op.hpp +++ b/ark/model/model_op.hpp @@ -50,8 +50,8 @@ class ModelOp { return ""; } - virtual std::vector impl_args([ - [maybe_unused]] const Json &config) const { + virtual std::vector impl_args( + [[maybe_unused]] const Json &config) const { return {}; } diff --git a/ark/model/model_tensor.cpp b/ark/model/model_tensor.cpp index 068783045..405faa4e2 100644 --- a/ark/model/model_tensor.cpp +++ b/ark/model/model_tensor.cpp @@ -92,13 +92,9 @@ size_t ModelTensor::shape_bytes() const { return shape_.nelems() * data_type_->bytes(); } -void *ModelTensor::data() const { - return buffer_->data(); -} +void *ModelTensor::data() const { return buffer_->data(); } -void *ModelTensor::data(void *data) { - return buffer_->data(data); -} +void *ModelTensor::data(void *data) { return buffer_->data(data); } bool ModelTensor::is_external() const { return buffer_->is_external(); } diff --git a/ark/ops/ops_all_reduce_test.cpp b/ark/ops/ops_all_reduce_test.cpp index 8cf68b085..e4fe4dac0 100644 --- a/ark/ops/ops_all_reduce_test.cpp +++ b/ark/ops/ops_all_reduce_test.cpp @@ -91,7 +91,8 @@ ark::Tensor all_reduce_packet(ark::Model &m, ark::Tensor input, int rank, std::vector outputs; size_t out_off = flag % 2 == 0 ? 
0 : nbytes_per_rank * 2; ark::Dims out_shape = {nbytes_per_rank * 2}; - ark::Dims out_strides = {nbytes_per_rank * 2 * 2}; // packet + double buffer + ark::Dims out_strides = {nbytes_per_rank * 2 * + 2}; // packet + double buffer for (int i = 0; i < rank_num; i++) { if (i != rank) { outputs.push_back(m.tensor(out_shape, ark::UINT8, out_strides, @@ -121,7 +122,8 @@ void test_all_reduce_packet_internal(ark::DimType nelem) { ark::Model m(gpu_id, NumGpus); ark::Tensor ones = m.tensor({nelem}, ark::FP16); ark::Tensor data = m.mul(ones, float(gpu_id + 1)); - ark::Tensor output = all_reduce_packet(m, data, gpu_id, NumGpus, 1, data); + ark::Tensor output = + all_reduce_packet(m, data, gpu_id, NumGpus, 1, data); std::vector ones_vec(ones.shape().nelems(), ark::half_t(1.0f)); @@ -186,7 +188,6 @@ ark::Tensor all_reduce_sm(ark::Model &m, ark::Tensor input, int rank, return res; } - template void test_all_reduce_sm_internal(ark::DimType nelem) { auto config_rule = [nelem](const std::string op_str, const std::string) { @@ -244,36 +245,42 @@ void test_all_reduce_sm_internal(ark::DimType nelem) { } ark::unittest::State test_all_reduce_4gpus() { + UNITTEST_SKIP(ark::unittest::get_gpu_count() < 4); test_all_reduce_internal<4>(64); test_all_reduce_internal<4>(8192); return ark::unittest::SUCCESS; } ark::unittest::State test_all_reduce_8gpus() { + UNITTEST_SKIP(ark::unittest::get_gpu_count() < 8); test_all_reduce_internal<8>(64); test_all_reduce_internal<8>(8192); return ark::unittest::SUCCESS; } ark::unittest::State test_all_reduce_packet_4gpus() { + UNITTEST_SKIP(ark::unittest::get_gpu_count() < 4); test_all_reduce_packet_internal<4>(2048); test_all_reduce_packet_internal<4>(8192); return ark::unittest::SUCCESS; } ark::unittest::State test_all_reduce_packet_8gpus() { + UNITTEST_SKIP(ark::unittest::get_gpu_count() < 8); test_all_reduce_packet_internal<8>(2048); test_all_reduce_packet_internal<8>(8192); return ark::unittest::SUCCESS; } ark::unittest::State test_all_reduce_sm_4gpus() { 
+ UNITTEST_SKIP(ark::unittest::get_gpu_count() < 4); test_all_reduce_sm_internal<4>(2048 * 1024); test_all_reduce_sm_internal<4>(8192 * 1024); return ark::unittest::SUCCESS; } ark::unittest::State test_all_reduce_sm_8gpus() { + UNITTEST_SKIP(ark::unittest::get_gpu_count() < 8); test_all_reduce_sm_internal<8>(2048 * 1024); test_all_reduce_sm_internal<8>(8192 * 1024); return ark::unittest::SUCCESS; diff --git a/ark/ops/ops_arithmetic_test.cpp b/ark/ops/ops_arithmetic_test.cpp deleted file mode 100644 index 6a878c667..000000000 --- a/ark/ops/ops_arithmetic_test.cpp +++ /dev/null @@ -1,451 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT license. - -#include "ops_test_common.hpp" - -template -void baseline_add(std::vector &outputs, - const std::vector &output_shapes, - const std::vector &inputs, - const std::vector &input_shapes, int) { - T *out = static_cast(outputs[0]); - T *t0 = static_cast(inputs[0]); - T *t1 = static_cast(inputs[1]); - - // NumPy-style broadcasted addition - ark::Dims osh = output_shapes[0].dims4(); - ark::Dims ish0 = input_shapes[0].dims4(); - ark::Dims ish1 = input_shapes[1].dims4(); - for (ark::DimType n = 0; n < osh[0]; ++n) { - for (ark::DimType c = 0; c < osh[1]; ++c) { - for (ark::DimType h = 0; h < osh[2]; ++h) { - for (ark::DimType w = 0; w < osh[3]; ++w) { - out[w + h * osh[3] + c * osh[2] * osh[3] + - n * osh[1] * osh[2] * osh[3]] = - t0[(w % ish0[3]) + (h % ish0[2]) * ish0[3] + - (c % ish0[1]) * ish0[2] * ish0[3] + - (n % ish0[0]) * ish0[1] * ish0[2] * ish0[3]] + - t1[(w % ish1[3]) + (h % ish1[2]) * ish1[3] + - (c % ish1[1]) * ish1[2] * ish1[3] + - (n % ish1[0]) * ish1[1] * ish1[2] * ish1[3]]; - } - } - } - } -}; - -template -void baseline_sub(std::vector &outputs, - const std::vector &output_shapes, - const std::vector &inputs, - const std::vector &input_shapes, int) { - T *out = static_cast(outputs[0]); - T *t0 = static_cast(inputs[0]); - T *t1 = static_cast(inputs[1]); - - // NumPy-style broadcasted 
addition - ark::Dims osh = output_shapes[0].dims4(); - ark::Dims ish0 = input_shapes[0].dims4(); - ark::Dims ish1 = input_shapes[1].dims4(); - for (ark::DimType n = 0; n < osh[0]; ++n) { - for (ark::DimType c = 0; c < osh[1]; ++c) { - for (ark::DimType h = 0; h < osh[2]; ++h) { - for (ark::DimType w = 0; w < osh[3]; ++w) { - out[w + h * osh[3] + c * osh[2] * osh[3] + - n * osh[1] * osh[2] * osh[3]] = - t0[(w % ish0[3]) + (h % ish0[2]) * ish0[3] + - (c % ish0[1]) * ish0[2] * ish0[3] + - (n % ish0[0]) * ish0[1] * ish0[2] * ish0[3]] - - t1[(w % ish1[3]) + (h % ish1[2]) * ish1[3] + - (c % ish1[1]) * ish1[2] * ish1[3] + - (n % ish1[0]) * ish1[1] * ish1[2] * ish1[3]]; - } - } - } - } -}; - -template -void baseline_mul(std::vector &outputs, - const std::vector &output_shapes, - const std::vector &inputs, - const std::vector &input_shapes, int) { - T *out = static_cast(outputs[0]); - T *t0 = static_cast(inputs[0]); - T *t1 = static_cast(inputs[1]); - - // NumPy-style broadcasted multiplication - ark::Dims osh = output_shapes[0].dims4(); - ark::Dims ish0 = input_shapes[0].dims4(); - ark::Dims ish1 = input_shapes[1].dims4(); - for (ark::DimType n = 0; n < osh[0]; ++n) { - for (ark::DimType c = 0; c < osh[1]; ++c) { - for (ark::DimType h = 0; h < osh[2]; ++h) { - for (ark::DimType w = 0; w < osh[3]; ++w) { - out[w + h * osh[3] + c * osh[2] * osh[3] + - n * osh[1] * osh[2] * osh[3]] = - t0[(w % ish0[3]) + (h % ish0[2]) * ish0[3] + - (c % ish0[1]) * ish0[2] * ish0[3] + - (n % ish0[0]) * ish0[1] * ish0[2] * ish0[3]] * - t1[(w % ish1[3]) + (h % ish1[2]) * ish1[3] + - (c % ish1[1]) * ish1[2] * ish1[3] + - (n % ish1[0]) * ish1[1] * ish1[2] * ish1[3]]; - } - } - } - } -}; - -template -void baseline_div(std::vector &outputs, - const std::vector &output_shapes, - const std::vector &inputs, - const std::vector &input_shapes, int) { - T *out = static_cast(outputs[0]); - T *t0 = static_cast(inputs[0]); - T *t1 = static_cast(inputs[1]); - - // NumPy-style broadcasted division - ark::Dims 
osh = output_shapes[0].dims4(); - ark::Dims ish0 = input_shapes[0].dims4(); - ark::Dims ish1 = input_shapes[1].dims4(); - for (ark::DimType n = 0; n < osh[0]; ++n) { - for (ark::DimType c = 0; c < osh[1]; ++c) { - for (ark::DimType h = 0; h < osh[2]; ++h) { - for (ark::DimType w = 0; w < osh[3]; ++w) { - out[w + h * osh[3] + c * osh[2] * osh[3] + - n * osh[1] * osh[2] * osh[3]] = - t0[(w % ish0[3]) + (h % ish0[2]) * ish0[3] + - (c % ish0[1]) * ish0[2] * ish0[3] + - (n % ish0[0]) * ish0[1] * ish0[2] * ish0[3]] / - t1[(w % ish1[3]) + (h % ish1[2]) * ish1[3] + - (c % ish1[1]) * ish1[2] * ish1[3] + - (n % ish1[0]) * ish1[1] * ish1[2] * ish1[3]]; - } - } - } - } -}; - -ark::unittest::State test_add_fp32() { - ark::Model m; - ark::Tensor t0 = m.tensor({8192}, ark::FP32); - ark::Tensor t1 = m.tensor({8192}, ark::FP32); - ark::Tensor out = m.add(t0, t1); - - auto result = - ark::op_test("add_fp32", m, {t0, t1}, {out}, baseline_add); - UNITTEST_LOG(result); - UNITTEST_EQ(result.max_diff[0], 0.0f); - return ark::unittest::SUCCESS; -} - -ark::unittest::State test_add_fp16() { - ark::Model m; - ark::Tensor t0 = m.tensor({8192}, ark::FP16); - ark::Tensor t1 = m.tensor({8192}, ark::FP16); - ark::Tensor out = m.add(t0, t1); - - auto result = - ark::op_test("add_fp16", m, {t0, t1}, {out}, baseline_add); - UNITTEST_LOG(result); - UNITTEST_EQ(result.max_diff[0], 0.0f); - return ark::unittest::SUCCESS; -} - -ark::unittest::State test_add_bf16() { - ark::Model m; - ark::Tensor t0 = m.tensor({8192}, ark::BF16); - ark::Tensor t1 = m.tensor({8192}, ark::BF16); - ark::Tensor out = m.add(t0, t1); - - auto result = ark::op_test("add_bf16", m, {t0, t1}, {out}, - baseline_add); - UNITTEST_LOG(result); - UNITTEST_EQ(result.max_diff[0], 0.0f); - return ark::unittest::SUCCESS; -} - -ark::unittest::State test_add_overwrite() { - ark::Model m; - ark::Tensor t0 = m.tensor({8192}, ark::FP16); - ark::Tensor t1 = m.tensor({8192}, ark::FP16); - ark::Tensor out = m.add(t0, t1, t1); - - auto result = 
ark::op_test("add_overwrite", m, {t0, t1}, {out}, - baseline_add); - UNITTEST_LOG(result); - UNITTEST_EQ(result.max_diff[0], 0.0f); - return ark::unittest::SUCCESS; -} - -ark::unittest::State test_add_broadcast() { - { - ark::Model m; - ark::Tensor t0 = m.tensor({4, 1024}, ark::FP16); - ark::Tensor t1 = m.tensor({1, 1024}, ark::FP16); - ark::Tensor out = m.add(t0, t1); - - auto result = ark::op_test("add_broadcast", m, {t0, t1}, {out}, - baseline_add); - UNITTEST_LOG(result); - UNITTEST_EQ(result.max_diff[0], 0.0f); - } - { - ark::Model m; - ark::Tensor t0 = m.tensor({4, 64}, ark::FP16); - ark::Tensor t1 = m.tensor({4, 1}, ark::FP16, {4, 2}); - ark::Tensor out = m.add(t0, t1); - - auto result = ark::op_test("add_broadcast", m, {t0, t1}, {out}, - baseline_add); - UNITTEST_LOG(result); - UNITTEST_EQ(result.max_diff[0], 0.0f); - } - { - ark::Model m; - ark::Tensor t0 = m.tensor({3, 1, 1024}, ark::FP16); - ark::Tensor t1 = m.tensor({1, 4, 1}, ark::FP16, {1, 4, 2}); - ark::Tensor out = m.add(t0, t1); - - auto result = ark::op_test("add_broadcast", m, {t0, t1}, {out}, - baseline_add); - UNITTEST_LOG(result); - UNITTEST_EQ(result.max_diff[0], 0.0f); - } - return ark::unittest::SUCCESS; -} - -ark::unittest::State test_add_offset() { - { - ark::Model m; - ark::Tensor t0 = m.tensor({2, 64}, ark::FP16, {4, 128}, {2, 64}); - ark::Tensor t1 = m.tensor({2, 64}, ark::FP16); - ark::Tensor out = m.add(t0, t1); - - auto result = ark::op_test("add_offset", m, {t0, t1}, {out}, - baseline_add); - UNITTEST_LOG(result); - UNITTEST_EQ(result.max_diff[0], 0.0f); - } - return ark::unittest::SUCCESS; -} - -ark::unittest::State test_add_invalid() { - { - ark::Model m; - ark::Tensor t0 = m.tensor({1024}, ark::FP16); - ark::Tensor t1 = m.tensor({1024}, ark::FP32); - UNITTEST_THROW(m.add(t0, t1), ark::ModelError); - } - { - ark::Model m; - ark::Tensor t0 = m.tensor({8192}, ark::FP16); - ark::Tensor t1 = m.tensor({8192}, ark::FP16); - ark::Tensor out = m.tensor({8192}, ark::FP32); - 
UNITTEST_THROW(m.add(t0, t1, out), ark::ModelError); - } - { - ark::Model m; - ark::Tensor t0 = m.tensor({8192}, ark::FP16); - ark::Tensor t1 = m.tensor({8192}, ark::FP16); - ark::Tensor out = m.tensor({1024}, ark::FP16); - UNITTEST_THROW(m.add(t0, t1, out), ark::ModelError); - } - return ark::unittest::SUCCESS; -} - -ark::unittest::State test_sub_fp32() { - ark::Model m; - ark::Tensor t0 = m.tensor({8192}, ark::FP32); - ark::Tensor t1 = m.tensor({8192}, ark::FP32); - ark::Tensor out = m.sub(t0, t1); - - auto result = - ark::op_test("sub_fp32", m, {t0, t1}, {out}, baseline_sub); - UNITTEST_LOG(result); - UNITTEST_EQ(result.max_diff[0], 0.0f); - return ark::unittest::SUCCESS; -} - -ark::unittest::State test_sub_invalid() { - { - ark::Model m; - ark::Tensor t0 = m.tensor({1024}, ark::FP16); - ark::Tensor t1 = m.tensor({1024}, ark::FP32); - UNITTEST_THROW(m.sub(t0, t1), ark::ModelError); - } - { - ark::Model m; - ark::Tensor t0 = m.tensor({8192}, ark::FP16); - ark::Tensor t1 = m.tensor({8192}, ark::FP16); - ark::Tensor out = m.tensor({8192}, ark::FP32); - UNITTEST_THROW(m.sub(t0, t1, out), ark::ModelError); - } - { - ark::Model m; - ark::Tensor t0 = m.tensor({8192}, ark::FP16); - ark::Tensor t1 = m.tensor({8192}, ark::FP16); - ark::Tensor out = m.tensor({1024}, ark::FP16); - UNITTEST_THROW(m.sub(t0, t1, out), ark::ModelError); - } - return ark::unittest::SUCCESS; -} - -ark::unittest::State test_mul_fp32() { - ark::Model m; - ark::Tensor t0 = m.tensor({8192}, ark::FP32); - ark::Tensor t1 = m.tensor({8192}, ark::FP32); - ark::Tensor out = m.mul(t0, t1); - - auto result = - ark::op_test("mul_fp32", m, {t0, t1}, {out}, baseline_mul); - UNITTEST_LOG(result); - UNITTEST_EQ(result.max_diff[0], 0.0f); - return ark::unittest::SUCCESS; -} - -ark::unittest::State test_mul_fp16() { - ark::Model m; - ark::Tensor t0 = m.tensor({8192}, ark::FP16); - ark::Tensor t1 = m.tensor({8192}, ark::FP16); - ark::Tensor out = m.mul(t0, t1); - - auto result = - ark::op_test("mul_fp16", m, {t0, 
t1}, {out}, baseline_mul); - UNITTEST_LOG(result); - UNITTEST_EQ(result.max_diff[0], 0.0f); - return ark::unittest::SUCCESS; -} - -ark::unittest::State test_mul_overwrite() { - ark::Model m; - ark::Tensor t0 = m.tensor({8192}, ark::FP16); - ark::Tensor t1 = m.tensor({8192}, ark::FP16); - ark::Tensor out = m.mul(t0, t1, t1); - - auto result = ark::op_test("mul_overwrite", m, {t0, t1}, {out}, - baseline_mul); - UNITTEST_LOG(result); - UNITTEST_EQ(result.max_diff[0], 0.0f); - return ark::unittest::SUCCESS; -} - -ark::unittest::State test_mul_broadcast() { - { - ark::Model m; - ark::Tensor t0 = m.tensor({4, 1024}, ark::FP16); - ark::Tensor t1 = m.tensor({1, 1024}, ark::FP16); - ark::Tensor out = m.mul(t0, t1); - - auto result = ark::op_test("mul_broadcast", m, {t0, t1}, {out}, - baseline_mul); - UNITTEST_LOG(result); - UNITTEST_EQ(result.max_diff[0], 0.0f); - } - { - ark::Model m; - ark::Tensor t0 = m.tensor({4, 1024}, ark::FP16); - ark::Tensor t1 = m.tensor({4, 1}, ark::FP16, {4, 2}); - ark::Tensor out = m.mul(t0, t1); - - auto result = ark::op_test("mul_broadcast", m, {t0, t1}, {out}, - baseline_mul); - UNITTEST_LOG(result); - UNITTEST_EQ(result.max_diff[0], 0.0f); - } - { - ark::Model m; - ark::Tensor t0 = m.tensor({3, 1, 1024}, ark::FP16); - ark::Tensor t1 = m.tensor({1, 4, 1}, ark::FP16, {1, 4, 2}); - ark::Tensor out = m.mul(t0, t1); - - auto result = ark::op_test("mul_broadcast", m, {t0, t1}, {out}, - baseline_mul); - UNITTEST_LOG(result); - UNITTEST_EQ(result.max_diff[0], 0.0f); - } - return ark::unittest::SUCCESS; -} - -ark::unittest::State test_mul_invalid() { - { - ark::Model m; - ark::Tensor t0 = m.tensor({1024}, ark::FP16); - ark::Tensor t1 = m.tensor({1024}, ark::FP32); - UNITTEST_THROW(m.mul(t0, t1), ark::ModelError); - } - { - ark::Model m; - ark::Tensor t0 = m.tensor({8192}, ark::FP16); - ark::Tensor t1 = m.tensor({8192}, ark::FP16); - ark::Tensor out = m.tensor({8192}, ark::FP32); - UNITTEST_THROW(m.mul(t0, t1, out), ark::ModelError); - } - { - 
ark::Model m; - ark::Tensor t0 = m.tensor({8192}, ark::FP16); - ark::Tensor t1 = m.tensor({8192}, ark::FP16); - ark::Tensor out = m.tensor({1024}, ark::FP16); - UNITTEST_THROW(m.mul(t0, t1, out), ark::ModelError); - } - return ark::unittest::SUCCESS; -} - -ark::unittest::State test_div_fp32() { - ark::Model m; - ark::Tensor t0 = m.tensor({8192}, ark::FP32); - ark::Tensor t1 = m.tensor({8192}, ark::FP32); - ark::Tensor out = m.div(t0, t1); - - auto result = - ark::op_test("div_fp32", m, {t0, t1}, {out}, baseline_div); - UNITTEST_LOG(result); - UNITTEST_EQ(result.max_diff[0], 0.0f); - return ark::unittest::SUCCESS; -} - -ark::unittest::State test_div_invalid() { - { - ark::Model m; - ark::Tensor t0 = m.tensor({8192}, ark::FP32); - ark::Tensor t1 = m.tensor({8192}, ark::FP16); - UNITTEST_THROW(m.div(t0, t1), ark::ModelError); - } - { - ark::Model m; - ark::Tensor t0 = m.tensor({8192}, ark::FP32); - ark::Tensor t1 = m.tensor({8192}, ark::FP32); - ark::Tensor out = m.tensor({8192}, ark::FP16); - UNITTEST_THROW(m.div(t0, t1, out), ark::ModelError); - } - { - ark::Model m; - ark::Tensor t0 = m.tensor({8192}, ark::FP32); - ark::Tensor t1 = m.tensor({8192}, ark::FP32); - ark::Tensor out = m.tensor({1024}, ark::FP16); - UNITTEST_THROW(m.div(t0, t1, out), ark::ModelError); - } - return ark::unittest::SUCCESS; -} - -int main() { - ark::init(); - UNITTEST(test_add_fp32); - UNITTEST(test_add_fp16); - UNITTEST(test_add_bf16); - UNITTEST(test_add_overwrite); - UNITTEST(test_add_broadcast); - UNITTEST(test_add_offset); - UNITTEST(test_add_invalid); - UNITTEST(test_sub_fp32); - UNITTEST(test_sub_invalid); - UNITTEST(test_mul_fp32); - UNITTEST(test_mul_fp16); - UNITTEST(test_mul_overwrite); - UNITTEST(test_mul_broadcast); - UNITTEST(test_mul_invalid); - UNITTEST(test_div_fp32); - UNITTEST(test_div_invalid); - return ark::unittest::SUCCESS; -} diff --git a/ark/ops/ops_broadcast.cpp b/ark/ops/ops_broadcast.cpp index 2fd02b801..8642feefd 100644 --- a/ark/ops/ops_broadcast.cpp +++ 
b/ark/ops/ops_broadcast.cpp @@ -39,13 +39,13 @@ std::string ModelOpBroadcast1::impl_name(const Json &config) const { std::to_string(0)}); } -std::vector ModelOpBroadcast1::impl_args([ - [maybe_unused]] const Json &config) const { +std::vector ModelOpBroadcast1::impl_args( + [[maybe_unused]] const Json &config) const { return {result_tensors_[0], read_tensors_[0]}; } -Json ModelOpBroadcast1::default_config([ - [maybe_unused]] const ArchRef arch) const { +Json ModelOpBroadcast1::default_config( + [[maybe_unused]] const ArchRef arch) const { Json config; config["NumWarps"] = 1; config["SramBytes"] = 0; @@ -108,8 +108,8 @@ std::string ModelOpBroadcast2::impl_name(const Json &config) const { std::to_string(0)}); } -std::vector ModelOpBroadcast2::impl_args([ - [maybe_unused]] const Json &config) const { +std::vector ModelOpBroadcast2::impl_args( + [[maybe_unused]] const Json &config) const { std::vector args; args.emplace_back(result_tensors_[0]); args.emplace_back(read_tensors_[0]); @@ -117,8 +117,8 @@ std::vector ModelOpBroadcast2::impl_args([ return args; } -Json ModelOpBroadcast2::default_config([ - [maybe_unused]] const ArchRef arch) const { +Json ModelOpBroadcast2::default_config( + [[maybe_unused]] const ArchRef arch) const { Json config; config["NumWarps"] = 1; config["SramBytes"] = 0; diff --git a/ark/ops/ops_cast_test.cpp b/ark/ops/ops_cast_test.cpp deleted file mode 100644 index 8404e07f5..000000000 --- a/ark/ops/ops_cast_test.cpp +++ /dev/null @@ -1,322 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT license. 
- -#include "ops_test_common.hpp" - -template -void baseline_cast(std::vector &outputs, - const std::vector &output_shapes, - const std::vector &inputs, - const std::vector &, int) { - ToType *out = static_cast(outputs[0]); - FromType *input = static_cast(inputs[0]); - ark::Dims osh = output_shapes[0]; - for (ark::DimType i = 0; i < osh.nelems(); ++i) { - out[i] = ToType(input[i]); - } -}; - -template -void baseline_cast_from_byte(std::vector &outputs, - const std::vector &output_shapes, - const std::vector &inputs, - const std::vector &, int) { - ToType *out = static_cast(outputs[0]); - // input is a byte array, but force read it as ToType. - ToType *input = reinterpret_cast(inputs[0]); - ark::Dims osh = output_shapes[0]; - for (ark::DimType i = 0; i < osh.nelems(); ++i) { - out[i] = input[i]; - } -}; - -template -void baseline_cast_to_byte(std::vector &outputs, - const std::vector &, - const std::vector &inputs, - const std::vector &input_shapes, int) { - // output is a byte array, but force write it as FromType. 
- FromType *out = reinterpret_cast(outputs[0]); - FromType *input = static_cast(inputs[0]); - ark::Dims ish = input_shapes[0]; - for (ark::DimType i = 0; i < ish.nelems(); ++i) { - out[i] = input[i]; - } -}; - -ark::unittest::State test_cast_fp16_to_fp32() { - ark::Model m; - ark::Tensor t = m.tensor(ark::Dims(4, 2, 1024), ark::FP16); - ark::Tensor out = m.cast(t, ark::FP32); - - auto result = ark::op_test("cast_fp16_to_fp32", m, {t}, {out}, - baseline_cast); - UNITTEST_LOG(result); - UNITTEST_EQ(result.max_diff[0], 0.0f); - return ark::unittest::SUCCESS; -} - -ark::unittest::State test_cast_fp16_to_int32() { - ark::Model m; - ark::Tensor t = m.tensor(ark::Dims(4, 2, 1024), ark::FP16); - ark::Tensor out = m.cast(t, ark::INT32); - - std::vector input_data(t.shape().nelems()); - for (size_t i = 0; i < input_data.size(); ++i) { - input_data[i] = ark::half_t(int((i + 1) % 1000)); - } - - auto result = - ark::op_test("cast_fp16_to_int32", m, {t}, {out}, - baseline_cast, {input_data.data()}); - UNITTEST_LOG(result); - UNITTEST_EQ(result.max_diff[0], 0.0f); - return ark::unittest::SUCCESS; -} - -ark::unittest::State test_cast_fp32_to_fp16() { - ark::Model m; - ark::Tensor t = m.tensor(ark::Dims(4, 2, 1024), ark::FP32); - ark::Tensor out = m.cast(t, ark::FP16); - - auto result = ark::op_test("cast_fp32_to_fp16", m, {t}, {out}, - baseline_cast); - UNITTEST_LOG(result); - UNITTEST_EQ(result.max_diff[0], 0.0f); - return ark::unittest::SUCCESS; -} - -ark::unittest::State test_cast_fp32_to_int32() { - ark::Model m; - ark::Tensor t = m.tensor(ark::Dims(4, 2, 1024), ark::FP32); - ark::Tensor out = m.cast(t, ark::INT32); - - std::vector input_data(t.shape().nelems()); - for (size_t i = 0; i < input_data.size(); ++i) { - input_data[i] = float((i + 1) % 1000); - } - - auto result = ark::op_test("cast_fp32_to_int32", m, {t}, {out}, - baseline_cast, {input_data.data()}); - UNITTEST_LOG(result); - UNITTEST_EQ(result.max_diff[0], 0.0f); - return ark::unittest::SUCCESS; -} - 
-ark::unittest::State test_cast_int32_to_fp32() { - ark::Model m; - ark::Tensor t = m.tensor(ark::Dims(4, 2, 1024), ark::INT32); - ark::Tensor out = m.cast(t, ark::FP32); - - std::vector input_data(t.shape().nelems()); - for (size_t i = 0; i < input_data.size(); ++i) { - input_data[i] = (i + 1) % 1000; - } - - auto result = ark::op_test("cast_int32_to_fp32", m, {t}, {out}, - baseline_cast, {input_data.data()}); - UNITTEST_LOG(result); - UNITTEST_EQ(result.max_diff[0], 0.0f); - return ark::unittest::SUCCESS; -} - -ark::unittest::State test_cast_int32_to_fp16() { - ark::Model m; - ark::Tensor t = m.tensor(ark::Dims(4, 2, 1024), ark::INT32); - ark::Tensor out = m.cast(t, ark::FP16); - - std::vector input_data(t.shape().nelems()); - for (size_t i = 0; i < input_data.size(); ++i) { - input_data[i] = (i + 1) % 1000; - } - - auto result = - ark::op_test("cast_int32_to_fp16", m, {t}, {out}, - baseline_cast, {input_data.data()}); - UNITTEST_LOG(result); - UNITTEST_EQ(result.max_diff[0], 0.0f); - return ark::unittest::SUCCESS; -} - -ark::unittest::State test_cast_byte_to_fp32() { - ark::Model m; - ark::Tensor t = m.tensor(ark::Dims(4, 2, 1024), ark::BYTE); - ark::Tensor out = m.cast(t, ark::FP32); - - // For preventing optimize-out - m.noop(t); - m.noop(out); - - auto result = ark::op_test("cast_byte_to_fp32", m, {t}, {out}, - baseline_cast_from_byte); - UNITTEST_LOG(result); - UNITTEST_EQ(result.max_diff[0], 0.0f); - return ark::unittest::SUCCESS; -} - -ark::unittest::State test_cast_byte_to_fp16() { - ark::Model m; - ark::Tensor t = m.tensor(ark::Dims(4, 2, 1024), ark::BYTE); - ark::Tensor out = m.cast(t, ark::FP16); - - // For preventing optimize-out - m.noop(t); - m.noop(out); - - auto result = ark::op_test("cast_byte_to_fp16", m, {t}, {out}, - baseline_cast_from_byte); - UNITTEST_LOG(result); - UNITTEST_EQ(result.max_diff[0], 0.0f); - return ark::unittest::SUCCESS; -} - -ark::unittest::State test_cast_byte_to_int32() { - ark::Model m; - ark::Tensor t = 
m.tensor(ark::Dims(4, 2, 1024), ark::BYTE); - ark::Tensor out = m.cast(t, ark::INT32); - - // For preventing optimize-out - m.noop(t); - m.noop(out); - - auto result = ark::op_test("cast_byte_to_int32", m, {t}, {out}, - baseline_cast_from_byte); - UNITTEST_LOG(result); - UNITTEST_EQ(result.max_diff[0], 0.0f); - return ark::unittest::SUCCESS; -} - -ark::unittest::State test_cast_fp32_to_byte() { - ark::Model m; - ark::Tensor t = m.tensor(ark::Dims(4, 2, 1024), ark::FP32); - ark::Tensor out = m.cast(t, ark::BYTE); - - // For preventing optimize-out - m.noop(t); - m.noop(out); - - auto result = ark::op_test("cast_fp32_to_byte", m, {t}, {out}, - baseline_cast_to_byte); - UNITTEST_LOG(result); - UNITTEST_EQ(result.max_diff[0], 0.0f); - return ark::unittest::SUCCESS; -} - -ark::unittest::State test_cast_fp16_to_byte() { - ark::Model m; - ark::Tensor t = m.tensor(ark::Dims(4, 2, 1024), ark::FP16); - ark::Tensor out = m.cast(t, ark::BYTE); - - // For preventing optimize-out - m.noop(t); - m.noop(out); - - auto result = ark::op_test("cast_fp16_to_byte", m, {t}, {out}, - baseline_cast_to_byte); - UNITTEST_LOG(result); - UNITTEST_EQ(result.max_diff[0], 0.0f); - return ark::unittest::SUCCESS; -} - -ark::unittest::State test_cast_int32_to_byte() { - ark::Model m; - ark::Tensor t = m.tensor(ark::Dims(4, 2, 1024), ark::INT32); - ark::Tensor out = m.cast(t, ark::BYTE); - - // For preventing optimize-out - m.noop(t); - m.noop(out); - - auto result = ark::op_test("cast_int32_to_byte", m, {t}, {out}, - baseline_cast_to_byte); - UNITTEST_LOG(result); - UNITTEST_EQ(result.max_diff[0], 0.0f); - return ark::unittest::SUCCESS; -} - -ark::unittest::State test_cast_bf16_to_float() { - ark::Model m; - ark::Tensor t = m.tensor(ark::Dims(4, 2, 1024), ark::BF16); - ark::Tensor out = m.cast(t, ark::FP32); - - auto result = ark::op_test("cast_bf16_to_float", m, {t}, {out}, - baseline_cast); - UNITTEST_LOG(result); - UNITTEST_EQ(result.max_diff[0], 0.0f); - return ark::unittest::SUCCESS; -} - 
-ark::unittest::State test_cast_float_to_bf16() { - ark::Model m; - ark::Tensor t = m.tensor(ark::Dims(4, 2, 1024), ark::FP32); - ark::Tensor out = m.cast(t, ark::BF16); - - auto result = ark::op_test("cast_float_to_bf16", m, {t}, {out}, - baseline_cast); - UNITTEST_LOG(result); - UNITTEST_EQ(result.max_diff[0], 0.0f); - return ark::unittest::SUCCESS; -} - -ark::unittest::State test_cast_invalid() { - { - ark::Model m; - ark::Tensor t = m.tensor(ark::Dims(1), ark::BYTE); - UNITTEST_THROW(m.cast(t, ark::FP32), ark::ModelError); - } - { - ark::Model m; - ark::Tensor t0 = m.tensor(ark::Dims(4, 1), ark::BYTE); - m.cast(t0, ark::FP32); // ok - ark::Tensor t1 = m.tensor(ark::Dims(4, 1, 1), ark::BYTE); - m.cast(t1, ark::FP32); // ok - ark::Tensor t2 = m.tensor(ark::Dims(4, 1, 1, 1), ark::BYTE); - m.cast(t2, ark::FP32); // ok - ark::Tensor t3 = m.tensor(ark::Dims(7, 1), ark::BYTE); - UNITTEST_THROW(m.cast(t3, ark::FP32), ark::ModelError); - ark::Tensor t4 = m.tensor(ark::Dims(7, 1, 1), ark::BYTE); - UNITTEST_THROW(m.cast(t4, ark::FP32), ark::ModelError); - ark::Tensor t5 = m.tensor(ark::Dims(7, 1, 1, 1), ark::BYTE); - UNITTEST_THROW(m.cast(t5, ark::FP32), ark::ModelError); - } - { - ark::Model m; - ark::Tensor t0 = m.tensor({8, 1}, ark::BYTE); - m.cast(t0, ark::FP32); // ok - ark::Tensor t1 = m.tensor({8, 1}, ark::BYTE, {13, 1}, {0, 0}, {9, 1}); - UNITTEST_THROW(m.cast(t1, ark::FP32), ark::ModelError); - } - { - ark::Model m; - ark::Tensor t0 = m.tensor({8, 1}, ark::FP16); - ark::Tensor out = m.tensor({8, 1}, ark::INT32); - UNITTEST_THROW(m.cast(t0, ark::FP32, out), ark::ModelError); - } - { - ark::Model m; - ark::Tensor t0 = m.tensor({8, 1}, ark::FP16); - ark::Tensor out = m.tensor({4, 1}, ark::FP32); - UNITTEST_THROW(m.cast(t0, ark::FP32, out), ark::ModelError); - } - return ark::unittest::SUCCESS; -} - -int main() { - ark::init(); - UNITTEST(test_cast_fp16_to_fp32); - UNITTEST(test_cast_fp16_to_int32); - UNITTEST(test_cast_fp32_to_fp16); - 
UNITTEST(test_cast_fp32_to_int32); - UNITTEST(test_cast_int32_to_fp32); - UNITTEST(test_cast_int32_to_fp16); - UNITTEST(test_cast_byte_to_fp32); - UNITTEST(test_cast_byte_to_fp16); - UNITTEST(test_cast_byte_to_int32); - UNITTEST(test_cast_fp32_to_byte); - UNITTEST(test_cast_fp16_to_byte); - UNITTEST(test_cast_int32_to_byte); - UNITTEST(test_cast_bf16_to_float); - UNITTEST(test_cast_float_to_bf16); - UNITTEST(test_cast_invalid); - return ark::unittest::SUCCESS; -} diff --git a/ark/ops/ops_communication.cpp b/ark/ops/ops_communication.cpp index c5be1ca65..4e221e173 100644 --- a/ark/ops/ops_communication.cpp +++ b/ark/ops/ops_communication.cpp @@ -71,8 +71,8 @@ std::string ModelOpSend::impl_name(const Json &config) const { output->data_type()->type_str()}); } -std::vector ModelOpSend::impl_args([ - [maybe_unused]] const Json &config) const { +std::vector ModelOpSend::impl_args( + [[maybe_unused]] const Json &config) const { return {ModelOffset(write_tensors_[0]), ModelOffset(read_tensors_[0])}; } @@ -107,13 +107,13 @@ std::string ModelOpSendDone::impl_name(const Json &config) const { std::to_string(remote_rank)}); } -std::vector ModelOpSendDone::impl_args([ - [maybe_unused]] const Json &config) const { +std::vector ModelOpSendDone::impl_args( + [[maybe_unused]] const Json &config) const { return {}; } -Json ModelOpSendDone::default_config([ - [maybe_unused]] const ArchRef arch) const { +Json ModelOpSendDone::default_config( + [[maybe_unused]] const ArchRef arch) const { return {{"ChannelType", "Proxy"}, {"NumTasks", 1}, {"NumWarps", 1}, @@ -138,8 +138,8 @@ ModelOpRecv::ModelOpRecv(ModelTensorRef output, int remote_rank, int tag) } std::string ModelOpRecv::impl_name(const Json &config) const { - check_fields_config(config, - {"ChannelType", "NumTasks", "NumWarps", "SramBytes", "Wait"}); + check_fields_config( + config, {"ChannelType", "NumTasks", "NumWarps", "SramBytes", "Wait"}); std::string channel_type = config["ChannelType"]; bool wait = config["Wait"]; if 
(channel_type != "Proxy" && channel_type != "SecondaryProxy" && @@ -155,8 +155,8 @@ std::string ModelOpRecv::impl_name(const Json &config) const { std::to_string(max_spin_cnt), std::to_string(wait)}); } -std::vector ModelOpRecv::impl_args([ - [maybe_unused]] const Json &config) const { +std::vector ModelOpRecv::impl_args( + [[maybe_unused]] const Json &config) const { return {}; } @@ -231,13 +231,13 @@ std::string ModelOpSendPacket::impl_name(const Json &config) const { packet_type, std::to_string(flag)}); } -std::vector ModelOpSendPacket::impl_args([ - [maybe_unused]] const Json &config) const { +std::vector ModelOpSendPacket::impl_args( + [[maybe_unused]] const Json &config) const { return {ModelOffset(write_tensors_[0]), ModelOffset(read_tensors_[0])}; } -Json ModelOpSendPacket::default_config([ - [maybe_unused]] const ArchRef arch) const { +Json ModelOpSendPacket::default_config( + [[maybe_unused]] const ArchRef arch) const { Json config; if (arch->belongs_to(ARCH_ROCM)) { config["PacketType"] = "mscclpp::LL8Packet"; @@ -324,13 +324,13 @@ std::string ModelOpRecvPacket::impl_name(const Json &config) const { packet_type, std::to_string(flag)}); } -std::vector ModelOpRecvPacket::impl_args([ - [maybe_unused]] const Json &config) const { +std::vector ModelOpRecvPacket::impl_args( + [[maybe_unused]] const Json &config) const { return {ModelOffset(write_tensors_[0]), ModelOffset(read_tensors_[1])}; } -Json ModelOpRecvPacket::default_config([ - [maybe_unused]] const ArchRef arch) const { +Json ModelOpRecvPacket::default_config( + [[maybe_unused]] const ArchRef arch) const { Json config; if (arch->belongs_to(ARCH_ROCM)) { config["PacketType"] = "mscclpp::LL8Packet"; @@ -418,8 +418,8 @@ std::string ModelOpRecvReduceSendPacket::impl_name(const Json &config) const { input->data_type()->type_str(), std::to_string(flag)}); } -std::vector ModelOpRecvReduceSendPacket::impl_args([ - [maybe_unused]] const Json &config) const { +std::vector ModelOpRecvReduceSendPacket::impl_args( 
+ [[maybe_unused]] const Json &config) const { std::vector args = {write_tensors_[0], read_tensors_[0], read_tensors_[1]}; for (size_t i = 1; i < write_tensors_.size(); ++i) { @@ -431,8 +431,8 @@ std::vector ModelOpRecvReduceSendPacket::impl_args([ return args; } -Json ModelOpRecvReduceSendPacket::default_config([ - [maybe_unused]] const ArchRef arch) const { +Json ModelOpRecvReduceSendPacket::default_config( + [[maybe_unused]] const ArchRef arch) const { Json config; if (arch->belongs_to(ARCH_ROCM)) { config["PacketType"] = "mscclpp::LL8Packet"; @@ -452,12 +452,10 @@ Json ModelOpRecvReduceSendPacket::default_config([ return config; } -ModelOpRecvReduceSend::ModelOpRecvReduceSend(ModelTensorRef input, - ModelTensorRef output, int rank, - const std::vector &remote_ranks, - int recv_tag, int output_tag, - std::vector &peer_output_refs, - ModelTensorRef scratch) +ModelOpRecvReduceSend::ModelOpRecvReduceSend( + ModelTensorRef input, ModelTensorRef output, int rank, + const std::vector &remote_ranks, int recv_tag, int output_tag, + std::vector &peer_output_refs, ModelTensorRef scratch) : ModelOp("RecvReduceSend") { check_null(input); uint32_t n_remote_ranks = remote_ranks.size(); @@ -519,8 +517,8 @@ std::string ModelOpRecvReduceSend::impl_name(const Json &config) const { input->data_type()->type_str(), input->data_type()->type_str()}); } -std::vector ModelOpRecvReduceSend::impl_args([ - [maybe_unused]] const Json &config) const { +std::vector ModelOpRecvReduceSend::impl_args( + [[maybe_unused]] const Json &config) const { std::vector args = {write_tensors_[0], read_tensors_[0], read_tensors_[1]}; for (size_t i = 1; i < write_tensors_.size(); ++i) { @@ -532,8 +530,8 @@ std::vector ModelOpRecvReduceSend::impl_args([ return args; } -Json ModelOpRecvReduceSend::default_config([ - [maybe_unused]] const ArchRef arch) const { +Json ModelOpRecvReduceSend::default_config( + [[maybe_unused]] const ArchRef arch) const { Json config; config["NumWarps"] = 1; config["SramBytes"] = 0; 
@@ -576,12 +574,13 @@ std::string ModelOpDeviceSync::impl_name(const Json &config) const { std::to_string(peer_num), std::to_string(rank)}); } -std::vector ModelOpDeviceSync::impl_args([ - [maybe_unused]] const Json &config) const { +std::vector ModelOpDeviceSync::impl_args( + [[maybe_unused]] const Json &config) const { return {}; } -Json ModelOpDeviceSync::default_config([[maybe_unused]] const ArchRef arch) const { +Json ModelOpDeviceSync::default_config( + [[maybe_unused]] const ArchRef arch) const { return {{"ChannelType", "Proxy"}, {"NumTasks", 1}, {"NumWarps", 1}, diff --git a/ark/ops/ops_communication.hpp b/ark/ops/ops_communication.hpp index 23f3b84af..f0c0134f2 100644 --- a/ark/ops/ops_communication.hpp +++ b/ark/ops/ops_communication.hpp @@ -103,7 +103,6 @@ class ModelOpRecvReduceSend : public ModelOp { Json default_config(const ArchRef arch = ARCH_ANY) const override; }; - class ModelOpDeviceSync : public ModelOp { public: ModelOpDeviceSync() = default; diff --git a/ark/ops/ops_communication_test.cpp b/ark/ops/ops_communication_test.cpp index de7c42833..e5ffc8804 100644 --- a/ark/ops/ops_communication_test.cpp +++ b/ark/ops/ops_communication_test.cpp @@ -346,7 +346,8 @@ ark::unittest::State test_communication_send_recv_reduce_packet() { ark::unittest::spawn_process([gpu_id]() { ark::Model model(gpu_id, 2); ark::Tensor tns_data = model.tensor({1024}, ark::FP16); - std::vector shard_tensors = model.sharding(tns_data, 0, 512); + std::vector shard_tensors = + model.sharding(tns_data, 0, 512); int peer_gpu_id = (gpu_id + 1) % 2; model.send_packet(shard_tensors[peer_gpu_id], peer_gpu_id, 0, 1); @@ -389,8 +390,7 @@ ark::unittest::State test_communication_send_recv_reduce() { config["NumTasks"] = 4; config["NumWarps"] = 4; config["SramBytes"] = 0; - } - else if (op.at("Type") == "DeviceSync") { + } else if (op.at("Type") == "DeviceSync") { config["ChannelType"] = "Sm"; config["NumTasks"] = 1; config["NumWarps"] = 1; diff --git a/ark/ops/ops_embedding.cpp 
b/ark/ops/ops_embedding.cpp index 2d6b63720..8f29aba9a 100644 --- a/ark/ops/ops_embedding.cpp +++ b/ark/ops/ops_embedding.cpp @@ -54,13 +54,13 @@ std::string ModelOpEmbedding::impl_name(const Json &config) const { }); } -std::vector ModelOpEmbedding::impl_args([ - [maybe_unused]] const Json &config) const { +std::vector ModelOpEmbedding::impl_args( + [[maybe_unused]] const Json &config) const { return {result_tensors_[0], read_tensors_[0], read_tensors_[1]}; } -Json ModelOpEmbedding::default_config([ - [maybe_unused]] const ArchRef arch) const { +Json ModelOpEmbedding::default_config( + [[maybe_unused]] const ArchRef arch) const { Json config; config["NumWarps"] = 1; config["SramBytes"] = 0; diff --git a/ark/ops/ops_embedding_test.cpp b/ark/ops/ops_embedding_test.cpp deleted file mode 100644 index 222605296..000000000 --- a/ark/ops/ops_embedding_test.cpp +++ /dev/null @@ -1,122 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT license. - -#include -#include - -#include "ark/random.hpp" -#include "ops_test_common.hpp" - -template -void baseline_embedding(std::vector &outputs, - const std::vector &output_shapes, - const std::vector &inputs, - const std::vector &input_shapes, int) { - T *out = static_cast(outputs[0]); - int *in = static_cast(inputs[0]); - T *weight = static_cast(inputs[1]); - - ark::Dims osh = output_shapes[0].dims4(); - ark::Dims wsh = input_shapes[1].dims4(); - - assert(osh[3] == wsh[3]); - - int in_idx = 0; - for (ark::DimType n = 0; n < osh[0]; ++n) { - for (ark::DimType c = 0; c < osh[1]; ++c) { - for (ark::DimType h = 0; h < osh[2]; ++h) { - int weight_idx = in[in_idx++]; - if (weight_idx < 0) { - weight_idx += wsh[2]; - } - T *ptr = &weight[weight_idx * wsh[3]]; - for (ark::DimType w = 0; w < osh[3]; ++w) { - out[n * osh[1] * osh[2] * osh[3] + c * osh[2] * osh[3] + - h * osh[3] + w] = ptr[w]; - } - } - } - } -}; - -template -ark::unittest::State test_embedding() { - const int num_emb = 100; - const int emb_dim = 4096; 
- - ark::DataType weight_type; - if (std::is_same::value) { - weight_type = ark::FP32; - } else { - weight_type = ark::FP16; - } - - ark::Model m; - ark::Tensor ti = m.tensor(ark::Dims(8, 3, 64), ark::INT32); - ark::Tensor tw = m.tensor(ark::Dims(num_emb, emb_dim), weight_type); - ark::Tensor to = m.embedding(ti, tw); - - std::vector ti_data; - for (auto i = 0; i < ti.shape().nelems(); ++i) { - // Random indices in [0, num_emb) - int rand_idx = ark::rand() % num_emb; - if (i % 9 == 0) { - // test negative tokens (padding) - rand_idx = -rand_idx; - } - ti_data.push_back(rand_idx); - } - std::vector tw_data(tw.shape().nelems()); - for (auto i = 0; i < tw.shape().nelems(); ++i) { - tw_data[i] = ark::random(-1.0, 1.0); - } - std::string type_str = ""; - if (std::is_same::value) { - type_str = "fp32"; - } else if (std::is_same::value) { - type_str = "fp16"; - } else if (std::is_same::value) { - type_str = "bf16"; - } - auto result = - ark::op_test("embedding_" + type_str, m, {ti, tw}, {to}, - baseline_embedding, {ti_data.data(), tw_data.data()}); - UNITTEST_LOG(result); - UNITTEST_EQ(result.max_diff[0], 0.0f); - return ark::unittest::SUCCESS; -} - -ark::unittest::State test_embedding_fp32() { return test_embedding(); } - -ark::unittest::State test_embedding_fp16() { - return test_embedding(); -} - -ark::unittest::State test_embedding_bf16() { - return test_embedding(); -} - -ark::unittest::State test_embedding_invalid() { - { - ark::Model m; - ark::Tensor ti = m.tensor(ark::Dims(4, 8, 3, 64), ark::INT32); - ark::Tensor tw = m.tensor(ark::Dims(100, 1024), ark::FP32); - UNITTEST_THROW(m.embedding(ti, tw), ark::ModelError); - } - { - ark::Model m; - ark::Tensor ti = m.tensor(ark::Dims(8, 3, 64), ark::INT32); - ark::Tensor tw = m.tensor(ark::Dims(2, 100, 1024), ark::FP32); - UNITTEST_THROW(m.embedding(ti, tw), ark::ModelError); - } - return ark::unittest::SUCCESS; -} - -int main() { - ark::init(); - UNITTEST(test_embedding_fp32); - UNITTEST(test_embedding_fp16); - 
UNITTEST(test_embedding_bf16); - UNITTEST(test_embedding_invalid); - return ark::unittest::SUCCESS; -} diff --git a/ark/ops/ops_math_test.cpp b/ark/ops/ops_math_test.cpp deleted file mode 100644 index f5774ab8e..000000000 --- a/ark/ops/ops_math_test.cpp +++ /dev/null @@ -1,366 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT license. - -#include - -#include "ark/model.hpp" -#include "ops_test_common.hpp" -#include "unittest/unittest_utils.h" - -float gelu(float x) { - return 0.5 * x * (1 + tanh(sqrt(2 / M_PI) * (x + 0.044715 * pow(x, 3)))); -} - -template -void baseline_gelu(std::vector &outputs, - const std::vector &output_shapes, - const std::vector &inputs, - const std::vector &, int) { - T *out = static_cast(outputs[0]); - T *input = static_cast(inputs[0]); - ark::Dims osh = output_shapes[0]; - for (ark::DimType i = 0; i < osh.nelems(); ++i) { - out[i] = gelu(input[i]); - } -}; - -template -void baseline_exp(std::vector &outputs, - const std::vector &output_shapes, - const std::vector &inputs, - const std::vector &, int) { - T *out = static_cast(outputs[0]); - T *input = static_cast(inputs[0]); - ark::Dims osh = output_shapes[0]; - for (ark::DimType i = 0; i < osh.nelems(); ++i) { - out[i] = std::exp(input[i]); - } -}; - -float relu(float x) { return x > 0 ? 
x : 0; } - -template -void baseline_relu(std::vector &outputs, - const std::vector &output_shapes, - const std::vector &inputs, - const std::vector &, int) { - T *out = static_cast(outputs[0]); - T *input = static_cast(inputs[0]); - ark::Dims osh = output_shapes[0]; - for (ark::DimType i = 0; i < osh.nelems(); ++i) { - out[i] = relu(input[i]); - } -}; - -template -void baseline_rsqrt(std::vector &outputs, - const std::vector &output_shapes, - const std::vector &inputs, - const std::vector &, int) { - T *out = static_cast(outputs[0]); - T *input = static_cast(inputs[0]); - ark::Dims osh = output_shapes[0]; - for (ark::DimType i = 0; i < osh.nelems(); ++i) { - out[i] = 1.0f / std::sqrt(input[i]); - } -}; - -float sigmoid(float x) { return 1 / (1 + std::exp(-x)); } - -template -void baseline_sigmoid(std::vector &outputs, - const std::vector &output_shapes, - const std::vector &inputs, - const std::vector &, int) { - T *out = static_cast(outputs[0]); - T *input = static_cast(inputs[0]); - ark::Dims osh = output_shapes[0]; - for (ark::DimType i = 0; i < osh.nelems(); ++i) { - out[i] = sigmoid(input[i]); - } -}; - -template -void baseline_sqrt(std::vector &outputs, - const std::vector &output_shapes, - const std::vector &inputs, - const std::vector &, int) { - T *out = static_cast(outputs[0]); - T *input = static_cast(inputs[0]); - ark::Dims osh = output_shapes[0]; - for (ark::DimType i = 0; i < osh.nelems(); ++i) { - out[i] = std::sqrt(input[i]); - } -}; - -ark::unittest::State test_gelu_fp32() { - ark::Model m; - ark::Tensor t = m.tensor({4, 2, 1024}, ark::FP32); - ark::Tensor out = m.gelu(t); - - auto result = - ark::op_test("gelu_fp32", m, {t}, {out}, baseline_gelu); - UNITTEST_LOG(result); - UNITTEST_TRUE(result.max_diff[0] < 1e-6f); - return ark::unittest::SUCCESS; -} - -ark::unittest::State test_gelu_bf16() { - ark::Model m; - ark::Tensor t = m.tensor({4, 2, 1024}, ark::BF16); - ark::Tensor out = m.gelu(t); - - auto result = ark::op_test("gelu_bf16", m, {t}, 
{out}, - baseline_gelu); - UNITTEST_LOG(result); - UNITTEST_TRUE(result.max_diff[0] < 1e-6f); - return ark::unittest::SUCCESS; -} - -ark::unittest::State test_gelu_invalid() { - { - ark::Model m; - ark::Tensor t = m.tensor({4, 2, 1024}, ark::BF16); - ark::Tensor out = m.tensor({4, 2, 1024}, ark::FP32); - UNITTEST_THROW(m.gelu(t, out), ark::ModelError); - } - { - ark::Model m; - ark::Tensor t = m.tensor({4, 2, 1024}, ark::BF16); - ark::Tensor out = m.tensor({4, 4, 1024}, ark::BF16); - UNITTEST_THROW(m.gelu(t, out), ark::ModelError); - } - return ark::unittest::SUCCESS; -} - -ark::unittest::State test_exp_fp32() { - ark::Model m; - ark::Tensor t = m.tensor({4, 2, 1024}, ark::FP32); - ark::Tensor out = m.exp(t); - - auto result = ark::op_test("exp_fp32", m, {t}, {out}, baseline_exp); - UNITTEST_LOG(result); - UNITTEST_TRUE(result.max_diff[0] < 1e-5f); - return ark::unittest::SUCCESS; -} - -ark::unittest::State test_exp_fp16() { - ark::Model m; - ark::Tensor t = m.tensor({4, 2, 1024}, ark::FP16); - ark::Tensor out = m.exp(t); - - auto result = - ark::op_test("exp_fp16", m, {t}, {out}, baseline_exp); - UNITTEST_LOG(result); - UNITTEST_TRUE(result.max_diff[0] < 1e-2f); - return ark::unittest::SUCCESS; -} - -ark::unittest::State test_exp_bf16() { - ark::Model m; - ark::Tensor t = m.tensor({4, 2, 1024}, ark::BF16); - ark::Tensor out = m.exp(t); - - auto result = - ark::op_test("exp_bf16", m, {t}, {out}, baseline_exp); - UNITTEST_LOG(result); - UNITTEST_TRUE(result.max_diff[0] < 1e-2f); - return ark::unittest::SUCCESS; -} - -ark::unittest::State test_exp_invalid() { - { - ark::Model m; - ark::Tensor t = m.tensor({4, 2, 1024}, ark::BF16); - ark::Tensor out = m.tensor({4, 2, 1024}, ark::FP32); - UNITTEST_THROW(m.exp(t, out), ark::ModelError); - } - { - ark::Model m; - ark::Tensor t = m.tensor({4, 2, 1024}, ark::BF16); - ark::Tensor out = m.tensor({4, 4, 1024}, ark::BF16); - UNITTEST_THROW(m.exp(t, out), ark::ModelError); - } - return ark::unittest::SUCCESS; -} - 
-ark::unittest::State test_relu_fp32() { - ark::Model m; - ark::Tensor t = m.tensor({4, 2, 1024}, ark::FP32); - ark::Tensor out = m.relu(t); - - auto result = - ark::op_test("relu_fp32", m, {t}, {out}, baseline_relu); - UNITTEST_LOG(result); - UNITTEST_EQ(result.max_diff[0], 0.0f); - return ark::unittest::SUCCESS; -} - -ark::unittest::State test_relu_fp16() { - ark::Model m; - ark::Tensor t = m.tensor({4, 2, 1024}, ark::FP16); - ark::Tensor out = m.relu(t); - - auto result = - ark::op_test("relu_fp16", m, {t}, {out}, baseline_relu); - UNITTEST_LOG(result); - UNITTEST_EQ(result.max_diff[0], 0.0f); - return ark::unittest::SUCCESS; -} - -ark::unittest::State test_relu_bf16() { - ark::Model m; - ark::Tensor t = m.tensor({4, 2, 1024}, ark::BF16); - ark::Tensor out = m.relu(t); - - auto result = ark::op_test("relu_bf16", m, {t}, {out}, - baseline_relu); - UNITTEST_LOG(result); - UNITTEST_EQ(result.max_diff[0], 0.0f); - return ark::unittest::SUCCESS; -} - -ark::unittest::State test_relu_invalid() { - { - ark::Model m; - ark::Tensor t = m.tensor({4, 2, 1024}, ark::BF16); - ark::Tensor out = m.tensor({4, 2, 1024}, ark::FP32); - UNITTEST_THROW(m.relu(t, out), ark::ModelError); - } - { - ark::Model m; - ark::Tensor t = m.tensor({4, 2, 1024}, ark::BF16); - ark::Tensor out = m.tensor({4, 4, 1024}, ark::BF16); - UNITTEST_THROW(m.relu(t, out), ark::ModelError); - } - return ark::unittest::SUCCESS; -} - -ark::unittest::State test_math_rsqrt_fp32() { - ark::Model m; - ark::Tensor t = m.tensor({4, 2, 1024}, ark::FP32); - ark::Tensor out = m.rsqrt(t); - - auto result = - ark::op_test("math_rsqrt_fp32", m, {t}, {out}, baseline_rsqrt); - UNITTEST_LOG(result); - UNITTEST_TRUE(result.max_diff[0] < 1e-4f); - return ark::unittest::SUCCESS; -} - -ark::unittest::State test_math_rsqrt_fp16() { - ark::Model m; - ark::Tensor t = m.tensor({1, 64, 1}, ark::FP16); - ark::Tensor out = m.rsqrt(t); - - std::vector data(64, 4); - - auto result = ark::op_test("math_rsqrt_fp16", m, {t}, {out}, - 
baseline_rsqrt, {data.data()}); - UNITTEST_LOG(result); - UNITTEST_TRUE(result.max_diff[0] < 1e-4f); - return ark::unittest::SUCCESS; -} - -ark::unittest::State test_sigmoid_fp32() { - ark::Model m; - ark::Tensor t = m.tensor({4, 2, 1024}, ark::FP32); - ark::Tensor out = m.sigmoid(t); - - auto result = - ark::op_test("sigmoid_fp32", m, {t}, {out}, baseline_sigmoid); - UNITTEST_LOG(result); - UNITTEST_TRUE(result.max_diff[0] < 1e-5f); - return ark::unittest::SUCCESS; -} - -ark::unittest::State test_sigmoid_bf16() { - ark::Model m; - ark::Tensor t = m.tensor({4, 2, 1024}, ark::BF16); - ark::Tensor out = m.sigmoid(t); - - auto result = ark::op_test("sigmoid_bf16", m, {t}, {out}, - baseline_sigmoid); - UNITTEST_LOG(result); - UNITTEST_TRUE(result.max_diff[0] < 1e-2f); - return ark::unittest::SUCCESS; -} - -ark::unittest::State test_sigmoid_invalid() { - { - ark::Model m; - ark::Tensor t = m.tensor({4, 2, 1024}, ark::BF16); - ark::Tensor out = m.tensor({4, 2, 1024}, ark::FP32); - UNITTEST_THROW(m.sigmoid(t, out), ark::ModelError); - } - { - ark::Model m; - ark::Tensor t = m.tensor({4, 2, 1024}, ark::BF16); - ark::Tensor out = m.tensor({4, 4, 1024}, ark::BF16); - UNITTEST_THROW(m.sigmoid(t, out), ark::ModelError); - } - return ark::unittest::SUCCESS; -} - -ark::unittest::State test_math_sqrt_fp32() { - ark::Model m; - ark::Tensor t = m.tensor({4, 2, 1024}, ark::FP32); - ark::Tensor out = m.sqrt(t); - - auto result = - ark::op_test("math_sqrt_fp32", m, {t}, {out}, baseline_sqrt); - UNITTEST_LOG(result); - UNITTEST_TRUE(result.max_diff[0] < 1e-6f); - return ark::unittest::SUCCESS; -} - -ark::unittest::State test_math_sqrt_fp16_small_last_dim() { - ark::Model m; - ark::Tensor t = m.tensor({4, 1024, 1}, ark::FP16, {4, 1024, 2}); - ark::Tensor out = m.sqrt(t); - - auto result = ark::op_test("math_sqrt_fp16_small_last_dim", m, {t}, {out}, - baseline_sqrt); - UNITTEST_LOG(result); - UNITTEST_TRUE(result.max_diff[0] < 1e-4f); - return ark::unittest::SUCCESS; -} - 
-ark::unittest::State test_math_sqrt_invalid() { - { - ark::Model model; - ark::Tensor input = model.tensor({1, 3, 16, 8192}, ark::FP32); - ark::Tensor output = model.tensor({1, 3, 16, 8192}, ark::FP16); - UNITTEST_THROW(model.sqrt(input, output), ark::ModelError); - } - { - ark::Model model; - ark::Tensor input = model.tensor({1, 3, 16, 8192}, ark::FP32); - ark::Tensor output = model.tensor({1, 3, 16, 1024}, ark::FP32); - UNITTEST_THROW(model.sqrt(input, output), ark::ModelError); - } - return ark::unittest::SUCCESS; -} - -int main() { - ark::init(); - UNITTEST(test_gelu_fp32); - UNITTEST(test_gelu_bf16); - UNITTEST(test_gelu_invalid); - UNITTEST(test_exp_fp32); - UNITTEST(test_exp_fp16); - UNITTEST(test_exp_invalid); - UNITTEST(test_relu_fp32); - UNITTEST(test_relu_fp16); - UNITTEST(test_relu_bf16); - UNITTEST(test_relu_invalid); - UNITTEST(test_math_rsqrt_fp32); - UNITTEST(test_math_rsqrt_fp16); - UNITTEST(test_sigmoid_fp32); - UNITTEST(test_sigmoid_bf16); - UNITTEST(test_sigmoid_invalid); - UNITTEST(test_math_sqrt_fp32); - UNITTEST(test_math_sqrt_fp16_small_last_dim); - UNITTEST(test_math_sqrt_invalid); - return ark::unittest::SUCCESS; -} diff --git a/ark/ops/ops_matmul_test.cpp b/ark/ops/ops_matmul_test.cpp deleted file mode 100644 index 11682ca49..000000000 --- a/ark/ops/ops_matmul_test.cpp +++ /dev/null @@ -1,589 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT license. 
- -#include - -#include "gpu/gpu.hpp" -#include "logging.hpp" -#include "model/model_node.hpp" -#include "model/model_op.hpp" -#include "ops_test_common.hpp" - -#if defined(ARK_CUDA) - -#include - -typedef cublasHandle_t blasHandle; -typedef cublasStatus_t blasStatus; -typedef cublasOperation_t blasOperation; -typedef cudaDataType blasDataType; -typedef cublasComputeType_t blasComputeType; -constexpr auto blasSuccess = CUBLAS_STATUS_SUCCESS; -constexpr auto BLAS_OP_N = CUBLAS_OP_N; -constexpr auto BLAS_OP_T = CUBLAS_OP_T; -constexpr auto BLAS_R_32F = CUDA_R_32F; -constexpr auto BLAS_R_16F = CUDA_R_16F; -constexpr auto BLAS_R_16BF = CUDA_R_16BF; -constexpr auto BLAS_COMPUTE_32F = CUBLAS_COMPUTE_32F; -constexpr auto BLAS_COMPUTE_32F_FAST_TF32 = CUBLAS_COMPUTE_32F_FAST_TF32; -constexpr auto BLAS_COMPUTE_16F = CUBLAS_COMPUTE_16F; - -inline auto blasGemmEx(blasHandle handle, blasOperation transA, - blasOperation transB, int m, int n, int k, - const void *alpha, const void *A, blasDataType Atype, - int lda, const void *B, blasDataType Btype, int ldb, - const void *beta, void *C, blasDataType Ctype, int ldc, - blasComputeType computeType) { - return cublasGemmEx(handle, transA, transB, m, n, k, alpha, A, Atype, lda, - B, Btype, ldb, beta, C, Ctype, ldc, computeType, - CUBLAS_GEMM_DEFAULT); -} - -inline auto blasGemmStridedBatchedEx( - blasHandle handle, blasOperation transA, blasOperation transB, int m, int n, - int k, const void *alpha, const void *A, blasDataType Atype, int lda, - int strideA, const void *B, blasDataType Btype, int ldb, int strideB, - const void *beta, void *C, blasDataType Ctype, int ldc, int strideC, - int batchCount, blasComputeType computeType) { - return cublasGemmStridedBatchedEx( - handle, transA, transB, m, n, k, alpha, A, Atype, lda, strideA, B, - Btype, ldb, strideB, beta, C, Ctype, ldc, strideC, batchCount, - computeType, CUBLAS_GEMM_DEFAULT); -} - -#elif defined(ARK_ROCM) - -#include - -typedef rocblas_handle blasHandle; -typedef 
rocblas_status blasStatus; -typedef rocblas_operation blasOperation; -typedef rocblas_datatype blasDataType; -typedef rocblas_datatype blasComputeType; -constexpr auto blasSuccess = rocblas_status_success; -constexpr auto BLAS_OP_N = rocblas_operation_none; -constexpr auto BLAS_OP_T = rocblas_operation_transpose; -constexpr auto BLAS_R_32F = rocblas_datatype_f32_r; -constexpr auto BLAS_R_16F = rocblas_datatype_f16_r; -constexpr auto BLAS_R_16BF = rocblas_datatype_bf16_r; -constexpr auto BLAS_COMPUTE_32F = rocblas_datatype_f32_r; -[[maybe_unused]] constexpr auto BLAS_COMPUTE_32F_FAST_TF32 = - rocblas_datatype_f32_r; -[[maybe_unused]] constexpr auto BLAS_COMPUTE_16F = rocblas_datatype_f16_r; - -inline auto blasGemmEx(blasHandle handle, blasOperation transA, - blasOperation transB, int m, int n, int k, - const void *alpha, const void *A, blasDataType Atype, - int lda, const void *B, blasDataType Btype, int ldb, - const void *beta, void *C, blasDataType Ctype, int ldc, - blasComputeType computeType) { - return rocblas_gemm_ex(handle, transA, transB, m, n, k, alpha, A, Atype, - lda, B, Btype, ldb, beta, C, Ctype, ldc, C, Ctype, - ldc, computeType, rocblas_gemm_algo_standard, 0, 0); -} - -inline auto blasGemmStridedBatchedEx( - blasHandle handle, blasOperation transA, blasOperation transB, int m, int n, - int k, const void *alpha, const void *A, blasDataType Atype, int lda, - int strideA, const void *B, blasDataType Btype, int ldb, int strideB, - const void *beta, void *C, blasDataType Ctype, int ldc, int strideC, - int batchCount, blasComputeType computeType) { - return rocblas_gemm_strided_batched_ex( - handle, transA, transB, m, n, k, alpha, A, Atype, lda, strideA, B, - Btype, ldb, strideB, beta, C, Ctype, ldc, strideC, C, Ctype, ldc, - strideC, batchCount, computeType, rocblas_gemm_algo_standard, 0, 0); -} - -#endif - -ARK_GPU_DEFINE_FUNC_ALIAS(blasCreate, cublasCreate, rocblas_create_handle); -ARK_GPU_DEFINE_FUNC_ALIAS(blasDestroy, cublasDestroy, 
rocblas_destroy_handle); - -class BlasHandle { - public: - BlasHandle() { - if (blasCreate(&handle_) != blasSuccess) { - throw std::runtime_error("Failed to create blas handle"); - } - } - - ~BlasHandle() { - if (blasDestroy(handle_) != blasSuccess) { - // do nothing. - } - } - - blasHandle get() const { return handle_; } - - private: - blasHandle handle_; -}; - -static BlasHandle globalBlasHandle; - -template -void blas_matmul(int m, int n, int k, const DataType *a, int lda, - const DataType *b, int ldb, DataType *c, int ldc, - int batch_size = 1) { - static_assert(std::is_same_v || - std::is_same_v || - std::is_same_v, - "Unsupported data type"); - - auto blasH = globalBlasHandle.get(); - blasStatus status; - blasOperation optypeA = (blasOperation)BlasOpTypeA; - blasOperation optypeB = (blasOperation)BlasOpTypeB; - -#if defined(ARK_CUDA) - using CompType = - typename std::conditional_t, - ark::half_t, float>; - blasComputeType ctype = - std::is_same_v - ? BLAS_COMPUTE_32F_FAST_TF32 - : (std::is_same_v ? BLAS_COMPUTE_16F - : BLAS_COMPUTE_32F); -#elif defined(ARK_ROCM) - // CK uses only fp32 compute type for fp16/bf16 - using CompType = float; - blasComputeType ctype = BLAS_COMPUTE_32F; -#endif - CompType alpha = 1; - CompType beta = 0; - - blasDataType dtype = - std::is_same_v - ? BLAS_R_32F - : (std::is_same_v ? 
BLAS_R_16F - : BLAS_R_16BF); - if (batch_size == 1) { - status = blasGemmEx(blasH, optypeB, optypeA, n, m, k, &alpha, b, dtype, - ldb, a, dtype, lda, &beta, c, dtype, ldc, ctype); - if (status != blasSuccess) { - throw std::runtime_error("Failed to call blasGemmEx"); - } - } else { - status = blasGemmStridedBatchedEx( - blasH, optypeB, optypeA, n, m, k, &alpha, b, dtype, ldb, n * k, a, - dtype, lda, k * m, &beta, c, dtype, ldc, n * m, batch_size, ctype); - if (status != blasSuccess) { - throw std::runtime_error("Failed to call blasGemmStridedBatchedEx"); - } - } -} - -template -void baseline_matmul_nn(std::vector &outputs, - const std::vector &output_shapes, - const std::vector &inputs, - const std::vector &input_shapes, int) { - auto out_shape_dims4 = output_shapes[0].dims4(); - - // baseline inputs & outputs have no padding - int m = out_shape_dims4[2]; - int n = out_shape_dims4[3]; - int k = input_shapes[0].dims4()[3]; - int lda = k; - int ldb = n; - int ldc = n; - - int batch_size = out_shape_dims4[0] * out_shape_dims4[1]; - - auto memA = ark::to_gpu(inputs[0], input_shapes[0].nelems() * sizeof(T)); - auto memB = ark::to_gpu(inputs[1], input_shapes[1].nelems() * sizeof(T)); - auto memC = ark::to_gpu(outputs[0], output_shapes[0].nelems() * sizeof(T)); - - T *devA = static_cast(memA.get()); - T *devB = static_cast(memB.get()); - T *devC = static_cast(memC.get()); - - blas_matmul(m, n, k, devA, lda, devB, ldb, devC, ldc, - batch_size); - ark::sync_gpu(); - - // copy back to host - ark::from_gpu(memC, outputs[0]); -} - -template -void baseline_matmul_nt(std::vector &outputs, - const std::vector &output_shapes, - const std::vector &inputs, - const std::vector &input_shapes, int) { - auto out_shape_dims4 = output_shapes[0].dims4(); - - // baseline inputs & outputs have no padding - int m = out_shape_dims4[2]; - int n = out_shape_dims4[3]; - int k = input_shapes[0].dims4()[3]; - int lda = k; - int ldb = k; - int ldc = n; - - int batch_size = out_shape_dims4[0] * 
out_shape_dims4[1]; - - auto memA = ark::to_gpu(inputs[0], input_shapes[0].nelems() * sizeof(T)); - auto memB = ark::to_gpu(inputs[1], input_shapes[1].nelems() * sizeof(T)); - auto memC = ark::to_gpu(outputs[0], output_shapes[0].nelems() * sizeof(T)); - - T *devA = static_cast(memA.get()); - T *devB = static_cast(memB.get()); - T *devC = static_cast(memC.get()); - - blas_matmul(m, n, k, devA, lda, devB, ldb, devC, ldc, - batch_size); - ark::sync_gpu(); - - // copy back to host - ark::from_gpu(memC, outputs[0]); -} - -template -void baseline_matmul_tn(std::vector &outputs, - const std::vector &output_shapes, - const std::vector &inputs, - const std::vector &input_shapes, int) { - auto out_shape_dims4 = output_shapes[0].dims4(); - - // baseline inputs & outputs have no padding - int m = out_shape_dims4[2]; - int n = out_shape_dims4[3]; - int k = input_shapes[0].dims4()[2]; - int lda = m; - int ldb = n; - int ldc = n; - - int batch_size = out_shape_dims4[0] * out_shape_dims4[1]; - - auto memA = ark::to_gpu(inputs[0], input_shapes[0].nelems() * sizeof(T)); - auto memB = ark::to_gpu(inputs[1], input_shapes[1].nelems() * sizeof(T)); - auto memC = ark::to_gpu(outputs[0], output_shapes[0].nelems() * sizeof(T)); - - T *devA = static_cast(memA.get()); - T *devB = static_cast(memB.get()); - T *devC = static_cast(memC.get()); - - blas_matmul(m, n, k, devA, lda, devB, ldb, devC, ldc, - batch_size); - ark::sync_gpu(); - - // copy back to host - ark::from_gpu(memC, outputs[0]); -} - -template -void baseline_matmul_tt(std::vector &outputs, - const std::vector &output_shapes, - const std::vector &inputs, - const std::vector &input_shapes, int) { - auto out_shape_dims4 = output_shapes[0].dims4(); - - // baseline inputs & outputs have no padding - int m = out_shape_dims4[2]; - int n = out_shape_dims4[3]; - int k = input_shapes[0].dims4()[2]; - int lda = m; - int ldb = k; - int ldc = n; - - int batch_size = out_shape_dims4[0] * out_shape_dims4[1]; - - auto memA = 
ark::to_gpu(inputs[0], input_shapes[0].nelems() * sizeof(T)); - auto memB = ark::to_gpu(inputs[1], input_shapes[1].nelems() * sizeof(T)); - auto memC = ark::to_gpu(outputs[0], output_shapes[0].nelems() * sizeof(T)); - - T *devA = static_cast(memA.get()); - T *devB = static_cast(memB.get()); - T *devC = static_cast(memC.get()); - - blas_matmul(m, n, k, devA, lda, devB, ldb, devC, ldc, - batch_size); - ark::sync_gpu(); - - // copy back to host - ark::from_gpu(memC, outputs[0]); -} - -ark::unittest::State test_matmul_model() { - // Hidden dimension of the dense layer. - unsigned int units = 1024; - // Input dimension of the dense layer. - unsigned int in_dim = 1024; - // Extra dimension of the input. CHANNEL=1 for 2D inputs. - unsigned int channel = 128; - // Batch size of the input. - unsigned int batch_size = 1; - - ark::Model m; - ark::Tensor input = m.tensor({batch_size, channel, in_dim}, ark::FP16); - ark::Tensor weight = m.tensor({in_dim, units}, ark::FP16); - m.matmul(input, weight); - - UNITTEST_TRUE(m.verify()); - auto compressed = m.compress(); - UNITTEST_TRUE(compressed.verify()); - - return ark::unittest::SUCCESS; -} - -ark::unittest::State test_matmul_fp16() { - { - ark::Model m; - ark::Tensor a = m.tensor(ark::Dims(128, 64), ark::FP16); - ark::Tensor b = m.tensor(ark::Dims(64, 256), ark::FP16); - ark::Tensor c = m.matmul(a, b); - - auto result = ark::op_test("matmul_fp16", m, {a, b}, {c}, - baseline_matmul_nn); - UNITTEST_LOG(result); - UNITTEST_TRUE(result.max_diff[0] < - ark::reduction_abs_error_bound(0.1f, 64)); - } - { - ark::Model m; - ark::Tensor a = m.tensor(ark::Dims(4096, 2048), ark::FP16); - ark::Tensor b = m.tensor(ark::Dims(2048, 16384), ark::FP16); - ark::Tensor c = m.matmul(a, b); - - std::vector p_ones_a(a.shape().nelems(), - ark::half_t(0.1f)); - std::vector p_ones_b(b.shape().nelems(), - ark::half_t(0.1f)); - - auto result = ark::op_test("matmul_fp16", m, {a, b}, {c}, - baseline_matmul_nn, - {p_ones_a.data(), p_ones_b.data()}); - 
UNITTEST_LOG(result); - UNITTEST_TRUE(result.max_diff[0] < - ark::reduction_abs_error_bound(0.1f, 2048)); - } - return ark::unittest::SUCCESS; -} - -ark::unittest::State test_matmul_fp32() { - { - ark::Model m; - ark::Tensor a = m.tensor(ark::Dims(128, 64), ark::FP32); - ark::Tensor b = m.tensor(ark::Dims(64, 256), ark::FP32); - ark::Tensor c = m.matmul(a, b); - - auto result = ark::op_test("matmul_fp32", m, {a, b}, {c}, - baseline_matmul_nn); - UNITTEST_LOG(result); - UNITTEST_TRUE(result.max_diff[0] < - ark::reduction_abs_error_bound(0.1f, 64)); - } - { - ark::Model m; - ark::Tensor a = m.tensor(ark::Dims(4096, 8192), ark::FP32); - ark::Tensor b = m.tensor(ark::Dims(8192, 16384), ark::FP32); - ark::Tensor c = m.matmul(a, b); - - std::vector p_ones_a(a.shape().nelems(), float(0.1f)); - std::vector p_ones_b(b.shape().nelems(), float(0.1f)); - - auto result = ark::op_test("matmul_fp32", m, {a, b}, {c}, - baseline_matmul_nn, - {p_ones_a.data(), p_ones_b.data()}); - UNITTEST_LOG(result); - // TODO: #199 -#if defined(ARK_CUDA) - UNITTEST_TRUE(result.max_diff[0] < - ark::reduction_abs_error_bound(0.1f, 8192)); -#endif // defined(ARK_CUDA) - } - return ark::unittest::SUCCESS; -} - -ark::unittest::State test_matmul_bf16() { - { - ark::Model m; - ark::Tensor a = m.tensor(ark::Dims(128, 64), ark::BF16); - ark::Tensor b = m.tensor(ark::Dims(64, 256), ark::BF16); - ark::Tensor c = m.matmul(a, b); - - auto result = ark::op_test("matmul_bf16", m, {a, b}, {c}, - baseline_matmul_nn); - UNITTEST_LOG(result); - UNITTEST_TRUE( - result.max_diff[0] < - ark::reduction_abs_error_bound(0.1f, 64)); - } - { - ark::Model m; - ark::Tensor a = m.tensor(ark::Dims(4096, 256), ark::BF16); - ark::Tensor b = m.tensor(ark::Dims(256, 16384), ark::BF16); - ark::Tensor c = m.matmul(a, b); - - auto result = ark::op_test("matmul_bf16", m, {a, b}, {c}, - baseline_matmul_nn); - UNITTEST_LOG(result); - UNITTEST_TRUE( - result.max_diff[0] < - ark::reduction_abs_error_bound(0.1f, 256)); - } - return 
ark::unittest::SUCCESS; -} - -ark::unittest::State test_matmul_fp16_nt() { - { - ark::Model m; - ark::Tensor a = m.tensor(ark::Dims(128, 64), ark::FP16); - ark::Tensor b = m.tensor(ark::Dims(256, 64), ark::FP16); - ark::Tensor c = m.matmul(a, b, ark::NullTensor, false, true); - - auto result = ark::op_test("matmul_fp16_nt", m, {a, b}, {c}, - baseline_matmul_nt); - UNITTEST_LOG(result); - UNITTEST_TRUE(result.max_diff[0] < - ark::reduction_abs_error_bound(0.1f, 64)); - } - { - ark::Model m; - ark::Tensor a = m.tensor(ark::Dims(4096, 2048), ark::FP16); - ark::Tensor b = m.tensor(ark::Dims(16384, 2048), ark::FP16); - ark::Tensor c = m.matmul(a, b, ark::NullTensor, false, true); - - auto result = ark::op_test("matmul_fp16_nt", m, {a, b}, {c}, - baseline_matmul_nt); - UNITTEST_LOG(result); - UNITTEST_TRUE(result.max_diff[0] < - ark::reduction_abs_error_bound(0.1f, 2048)); - } - return ark::unittest::SUCCESS; -} - -ark::unittest::State test_matmul_fp16_tn() { - { - ark::Model m; - ark::Tensor a = m.tensor(ark::Dims(64, 128), ark::FP16); - ark::Tensor b = m.tensor(ark::Dims(64, 256), ark::FP16); - ark::Tensor c = m.matmul(a, b, ark::NullTensor, true, false); - - auto result = ark::op_test("matmul_fp16_tn", m, {a, b}, {c}, - baseline_matmul_tn); - UNITTEST_LOG(result); - UNITTEST_TRUE(result.max_diff[0] < - ark::reduction_abs_error_bound(0.1f, 64)); - } - { - ark::Model m; - ark::Tensor a = m.tensor(ark::Dims(2048, 4096), ark::FP16); - ark::Tensor b = m.tensor(ark::Dims(2048, 16384), ark::FP16); - ark::Tensor c = m.matmul(a, b, ark::NullTensor, true, false); - - auto result = ark::op_test("matmul_fp16_tn", m, {a, b}, {c}, - baseline_matmul_tn); - UNITTEST_LOG(result); - UNITTEST_TRUE(result.max_diff[0] < - ark::reduction_abs_error_bound(0.1f, 2048)); - } - return ark::unittest::SUCCESS; -} - -ark::unittest::State test_matmul_fp16_tt() { - { - ark::Model m; - ark::Tensor a = m.tensor(ark::Dims(64, 128), ark::FP16); - ark::Tensor b = m.tensor(ark::Dims(256, 64), ark::FP16); 
- ark::Tensor c = m.matmul(a, b, ark::NullTensor, true, true); - - auto result = ark::op_test("matmul_fp16_tt", m, {a, b}, {c}, - baseline_matmul_tt); - UNITTEST_LOG(result); - UNITTEST_TRUE(result.max_diff[0] < - ark::reduction_abs_error_bound(0.1f, 64)); - } - { - ark::Model m; - ark::Tensor a = m.tensor(ark::Dims(2048, 4096), ark::FP16); - ark::Tensor b = m.tensor(ark::Dims(16384, 2048), ark::FP16); - ark::Tensor c = m.matmul(a, b, ark::NullTensor, true, true); - - auto result = ark::op_test("matmul_fp16_tt", m, {a, b}, {c}, - baseline_matmul_tt); - UNITTEST_LOG(result); - UNITTEST_TRUE(result.max_diff[0] < - ark::reduction_abs_error_bound(0.1f, 2048)); - } - return ark::unittest::SUCCESS; -} - -ark::unittest::State test_matmul_fp16_batched() { - ark::Model m; - ark::Tensor a = m.tensor(ark::Dims(3, 7, 128, 128), ark::FP16); - ark::Tensor b = m.tensor(ark::Dims(3, 7, 128, 256), ark::FP16); - ark::Tensor c = m.matmul(a, b); - - auto result = ark::op_test("matmul_fp16_batched", m, {a, b}, {c}, - baseline_matmul_nn); - UNITTEST_LOG(result); - UNITTEST_TRUE(result.max_diff[0] < - ark::reduction_abs_error_bound(0.1f, 128)); - return ark::unittest::SUCCESS; -} - -ark::unittest::State test_matmul_fp16_batched_padded() { - ark::Model m; - ark::Tensor a = - m.tensor({3, 7, 2, 9}, ark::FP16, {3, 7, 128, 64}, {}, {3, 7, 128, 64}); - ark::Tensor b = - m.tensor({3, 7, 9, 2}, ark::FP16, {3, 7, 64, 256}, {}, {3, 7, 64, 256}); - ark::Tensor c = m.tensor({3, 7, 2, 2}, ark::FP16, {3, 7, 128, 256}, {}, - {3, 7, 128, 256}); - m.matmul(a, b, c); - - auto result = ark::op_test("matmul_fp16_batched_padded", m, {a, b}, {c}, - baseline_matmul_nn); - UNITTEST_LOG(result); - UNITTEST_TRUE(result.max_diff[0] < - ark::reduction_abs_error_bound(0.1f, 9)); - return ark::unittest::SUCCESS; -} - -ark::unittest::State test_matmul_fp16_offset() { - ark::Model m; - ark::Tensor a = - m.tensor({1, 128, 64}, ark::FP16, {1, 128, 256}, {0, 0, 64}); - ark::Tensor b = m.tensor({1, 64, 128}, ark::FP16, 
{1, 128, 256}, {0, 64, 0}, - {1, 64, 256}); - ark::Tensor c = m.tensor({1, 128, 128}, ark::FP16, {2, 256, 384}, - {1, 64, 128}, {1, 128, 256}); - m.matmul(a, b, c); - - auto result = ark::op_test("matmul_fp16_offset", m, {a, b}, {c}, - baseline_matmul_nn); - UNITTEST_LOG(result); - UNITTEST_TRUE(result.max_diff[0] < - ark::reduction_abs_error_bound(0.1f, 64)); - return ark::unittest::SUCCESS; -} - -ark::unittest::State test_matmul_invalid() { - ark::Model m; - ark::Tensor a = m.tensor(ark::Dims(128, 64), ark::FP16); - ark::Tensor b = m.tensor(ark::Dims(128, 256), ark::FP16); - UNITTEST_THROW(m.matmul(a, b), ark::ModelError); - - ark::Tensor c = m.tensor(ark::Dims(3, 3, 128, 128), ark::FP16); - ark::Tensor d = m.tensor(ark::Dims(2, 3, 128, 128), ark::FP16); - UNITTEST_THROW(m.matmul(c, d), ark::ModelError); - return ark::unittest::SUCCESS; -} - -int main() { - ark::init(); - UNITTEST(test_matmul_model); - UNITTEST(test_matmul_fp16); - UNITTEST(test_matmul_fp32); - UNITTEST(test_matmul_bf16); - UNITTEST(test_matmul_fp16_nt); - UNITTEST(test_matmul_fp16_tn); - UNITTEST(test_matmul_fp16_tt); - UNITTEST(test_matmul_fp16_batched); - UNITTEST(test_matmul_fp16_batched_padded); - UNITTEST(test_matmul_fp16_offset); - UNITTEST(test_matmul_invalid); - return ark::unittest::SUCCESS; -} diff --git a/ark/ops/ops_noop.cpp b/ark/ops/ops_noop.cpp index 894ab29be..50d1c2640 100644 --- a/ark/ops/ops_noop.cpp +++ b/ark/ops/ops_noop.cpp @@ -16,8 +16,8 @@ std::string ModelOpNoop::impl_name([[maybe_unused]] const Json &config) const { return function_name_string("noop"); } -std::vector ModelOpNoop::impl_args([ - [maybe_unused]] const Json &config) const { +std::vector ModelOpNoop::impl_args( + [[maybe_unused]] const Json &config) const { return {}; } diff --git a/ark/ops/ops_reduce.cpp b/ark/ops/ops_reduce.cpp index 78dd9d7e6..02d8b5c96 100644 --- a/ark/ops/ops_reduce.cpp +++ b/ark/ops/ops_reduce.cpp @@ -49,7 +49,7 @@ ModelOpReduce::ModelOpReduce(const std::string &type_name, 
ModelTensorRef input, } std::string ModelOpReduce::impl_name(const Json &config) const { - check_fields_config(config, {"NumWarps", "SramBytes", "ImplType"}); + check_fields_config(config, {"NumWarps", "SramBytes", "ImplType", "Tile"}); check_fields_args(args_, {"Axis", "KeepDim"}); std::string red_type; @@ -92,6 +92,13 @@ std::string ModelOpReduce::impl_name(const Json &config) const { output_shape.insert(axis, 1); } + Dims unit_out_dims(config.at("Tile").get>()); + auto udims4 = unit_out_dims.dims4(); + if (udims4[axis] != 1) { + ERR(PlanError, "Tile dimension along reduce axis (", axis, + ") must be 1, got ", udims4[axis]); + } + return function_name_string( "reduce_" + impl_type + "_" + red_type, { @@ -99,15 +106,15 @@ std::string ModelOpReduce::impl_name(const Json &config) const { vec_string(read_tensors_[0]->shape().dims4()), vec_string(output_strides.dims4()), vec_string(output_shape.dims4()), - vec_string(Dims(1, 1, 1, 1)), + vec_string(udims4), std::to_string(num_warps), std::to_string(sram_bytes), std::to_string(axis), }); } -std::vector ModelOpReduce::impl_args([ - [maybe_unused]] const Json &config) const { +std::vector ModelOpReduce::impl_args( + [[maybe_unused]] const Json &config) const { return {result_tensors_[0], read_tensors_[0]}; } @@ -122,6 +129,7 @@ Json ModelOpReduce::default_config([[maybe_unused]] const ArchRef arch) const { config["ImplType"] = "ElementWise"; config["SramBytes"] = 0; } + config["Tile"] = {1, 1, 1, 1}; config["NumTasks"] = result_tensors_[0]->shape().nelems(); return config; } diff --git a/ark/ops/ops_reduce_test.cpp b/ark/ops/ops_reduce_test.cpp deleted file mode 100644 index 637c8daec..000000000 --- a/ark/ops/ops_reduce_test.cpp +++ /dev/null @@ -1,472 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT license. 
- -#include -#include -#include - -#include "ops_test_common.hpp" - -template -void baseline_reduce_sum_axis0(std::vector &outputs, - const std::vector &output_shapes, - const std::vector &inputs, - const std::vector &input_shapes, - int) { - T *out = static_cast(outputs[0]); - T *input = static_cast(inputs[0]); - - ark::Dims osh = output_shapes[0]; - ark::Dims ish = input_shapes[0].dims4(); - - if (KeepDim) { - assert(osh[0] == 1); - } else { - osh.insert(0, 1); - } - osh = osh.dims4(); - - for (ark::DimType c = 0; c < ish[1]; ++c) { - for (ark::DimType h = 0; h < ish[2]; ++h) { - for (ark::DimType w = 0; w < ish[3]; ++w) { - float sum = 0; - for (ark::DimType n = 0; n < ish[0]; ++n) { - sum += float(input[n * ish[1] * ish[2] * ish[3] + - c * ish[2] * ish[3] + h * ish[3] + w]); - } - out[c * osh[2] * osh[3] + h * osh[3] + w] = T(sum); - } - } - } -} - -template -void baseline_reduce_sum_axis1(std::vector &outputs, - const std::vector &output_shapes, - const std::vector &inputs, - const std::vector &input_shapes, - int) { - T *out = static_cast(outputs[0]); - T *input = static_cast(inputs[0]); - - ark::Dims osh = output_shapes[0]; - ark::Dims ish = input_shapes[0].dims4(); - - if (KeepDim) { - assert(osh[1] == 1); - } else { - osh.insert(1, 1); - } - osh = osh.dims4(); - - for (ark::DimType n = 0; n < ish[0]; ++n) { - for (ark::DimType h = 0; h < ish[2]; ++h) { - for (ark::DimType w = 0; w < ish[3]; ++w) { - float sum = 0; - for (ark::DimType c = 0; c < ish[1]; ++c) { - sum += float(input[n * ish[1] * ish[2] * ish[3] + - c * ish[2] * ish[3] + h * ish[3] + w]); - } - out[n * osh[1] * osh[2] * osh[3] + h * osh[3] + w] = T(sum); - } - } - } -} - -template -void baseline_reduce_sum_axis2(std::vector &outputs, - const std::vector &output_shapes, - const std::vector &inputs, - const std::vector &input_shapes, - int) { - T *out = static_cast(outputs[0]); - T *input = static_cast(inputs[0]); - - ark::Dims osh = output_shapes[0]; - ark::Dims ish = input_shapes[0].dims4(); - 
- if (KeepDim) { - assert(osh[2] == 1); - } else { - osh.insert(2, 1); - } - osh = osh.dims4(); - - for (ark::DimType n = 0; n < ish[0]; ++n) { - for (ark::DimType c = 0; c < ish[1]; ++c) { - for (ark::DimType w = 0; w < ish[3]; ++w) { - float sum = 0; - for (ark::DimType h = 0; h < ish[2]; ++h) { - sum += float(input[n * ish[1] * ish[2] * ish[3] + - c * ish[2] * ish[3] + h * ish[3] + w]); - } - out[n * osh[1] * osh[2] * osh[3] + c * osh[2] * osh[3] + w] = - T(sum); - } - } - } -}; - -template -void baseline_reduce_sum_axis3(std::vector &outputs, - const std::vector &output_shapes, - const std::vector &inputs, - const std::vector &input_shapes, - int) { - T *out = static_cast(outputs[0]); - T *input = static_cast(inputs[0]); - - ark::Dims osh = output_shapes[0]; - ark::Dims ish = input_shapes[0].dims4(); - - if (KeepDim) { - assert(osh[3] == 1); - } else { - osh.insert(3, 1); - } - osh = osh.dims4(); - - for (ark::DimType n = 0; n < ish[0]; ++n) { - for (ark::DimType c = 0; c < ish[1]; ++c) { - for (ark::DimType h = 0; h < ish[2]; ++h) { - float sum = 0; - for (ark::DimType w = 0; w < ish[3]; ++w) { - sum += float(input[n * ish[1] * ish[2] * ish[3] + - c * ish[2] * ish[3] + h * ish[3] + w]); - } - out[n * osh[1] * osh[2] * osh[3] + c * osh[2] * osh[3] + - h * osh[3]] = T(sum); - } - } - } -}; - -template -void baseline_reduce_max_axis3(std::vector &outputs, - const std::vector &output_shapes, - const std::vector &inputs, - const std::vector &input_shapes, - int) { - T *out = static_cast(outputs[0]); - T *input = static_cast(inputs[0]); - - ark::Dims osh = output_shapes[0]; - ark::Dims ish = input_shapes[0].dims4(); - - if (KeepDim) { - assert(osh[3] == 1); - } else { - osh.insert(3, 1); - } - osh = osh.dims4(); - - for (ark::DimType n = 0; n < ish[0]; ++n) { - for (ark::DimType c = 0; c < ish[1]; ++c) { - for (ark::DimType h = 0; h < ish[2]; ++h) { - float max_val = std::numeric_limits::lowest(); - for (ark::DimType w = 0; w < ish[3]; ++w) { - float val = - 
float(input[n * ish[1] * ish[2] * ish[3] + - c * ish[2] * ish[3] + h * ish[3] + w]); - max_val = std::max(max_val, val); - } - out[n * osh[1] * osh[2] * osh[3] + c * osh[2] * osh[3] + - h * osh[3]] = T(max_val); - } - } - } -}; - -template -void baseline_reduce_mean_axis3(std::vector &outputs, - const std::vector &output_shapes, - const std::vector &inputs, - const std::vector &input_shapes, - int) { - T *out = static_cast(outputs[0]); - T *input = static_cast(inputs[0]); - - ark::Dims osh = output_shapes[0]; - ark::Dims ish = input_shapes[0].dims4(); - - if (KeepDim) { - assert(osh[3] == 1); - } else { - osh.insert(3, 1); - } - osh = osh.dims4(); - - for (ark::DimType n = 0; n < ish[0]; ++n) { - for (ark::DimType c = 0; c < ish[1]; ++c) { - for (ark::DimType h = 0; h < ish[2]; ++h) { - float mean = 0; - for (ark::DimType w = 0; w < ish[3]; ++w) { - mean += float(input[n * ish[1] * ish[2] * ish[3] + - c * ish[2] * ish[3] + h * ish[3] + w]); - } - mean /= ish[3]; - out[n * osh[1] * osh[2] * osh[3] + c * osh[2] * osh[3] + - h * osh[3]] = T(mean); - } - } - } -}; - -ark::unittest::State test_reduce_sum_axis0() { - ark::Model m; - ark::Tensor t = m.tensor(ark::Dims(7, 2, 4, 1024), ark::FP32); - ark::Tensor out = m.reduce_sum(t, /*axis=*/0); - - auto result = ark::op_test("reduce_sum_axis0", m, {t}, {out}, - baseline_reduce_sum_axis0); - UNITTEST_LOG(result); - UNITTEST_TRUE(result.max_diff[0] < - ark::reduction_abs_error_bound(0.1, t.shape()[0])); - return ark::unittest::SUCCESS; -} - -ark::unittest::State test_reduce_sum_axis1() { - ark::Model m; - ark::Tensor t = m.tensor(ark::Dims(1, 2, 4, 1024), ark::FP32); - ark::Tensor out = m.reduce_sum(t, /*axis=*/1); - - auto result = ark::op_test("reduce_sum_axis1", m, {t}, {out}, - baseline_reduce_sum_axis1); - UNITTEST_LOG(result); - UNITTEST_TRUE(result.max_diff[0] < - ark::reduction_abs_error_bound(0.1, t.shape()[1])); - return ark::unittest::SUCCESS; -} - -ark::unittest::State test_reduce_sum_axis2() { - ark::Model m; - 
ark::Tensor t = m.tensor(ark::Dims(1, 1, 7, 8192), ark::FP32); - ark::Tensor out = m.reduce_sum(t, /*axis=*/2); - - auto result = ark::op_test("reduce_sum_axis2", m, {t}, {out}, - baseline_reduce_sum_axis2); - UNITTEST_LOG(result); - UNITTEST_TRUE(result.max_diff[0] < - ark::reduction_abs_error_bound(0.1, t.shape()[2])); - return ark::unittest::SUCCESS; -} - -ark::unittest::State test_reduce_sum_axis3() { - ark::Model m; - ark::Tensor t = m.tensor(ark::Dims(1, 1, 2, 8192), ark::FP32); - ark::Tensor out = m.reduce_sum(t, /*axis=*/3); - - auto result = ark::op_test("reduce_sum_axis3", m, {t}, {out}, - baseline_reduce_sum_axis3); - UNITTEST_LOG(result); - UNITTEST_TRUE(result.max_diff[0] < - ark::reduction_abs_error_bound(0.1, t.shape()[3])); - return ark::unittest::SUCCESS; -} - -ark::unittest::State test_reduce_sum_axis3_padded() { - ark::Model m; - ark::Tensor t = m.tensor(ark::Dims(1, 1, 2, 8192), ark::FP32); - ark::Tensor out = - m.tensor(ark::Dims(1, 1, 2, 1), ark::FP32, ark::Dims(1, 1, 2, 32)); - out = m.reduce_sum(t, /*axis=*/3, true, out); - - auto result = ark::op_test("reduce_sum_axis3_padded", m, {t}, {out}, - baseline_reduce_sum_axis3); - UNITTEST_LOG(result); - UNITTEST_TRUE(result.max_diff[0] < - ark::reduction_abs_error_bound(0.1, t.shape()[3])); - return ark::unittest::SUCCESS; -} - -ark::unittest::State test_reduce_sum_fp16() { - { - ark::Model m; - ark::Tensor t = m.tensor(ark::Dims(7, 2, 4, 1024), ark::FP16); - ark::Tensor out = m.reduce_sum(t, /*axis=*/0); - - auto result = ark::op_test("reduce_sum_fp16_axis0", m, {t}, {out}, - baseline_reduce_sum_axis0); - UNITTEST_LOG(result); - UNITTEST_TRUE( - result.max_diff[0] < - ark::reduction_abs_error_bound(0.1, t.shape()[0])); - } - { - ark::Model m; - ark::Tensor t = m.tensor(ark::Dims(7, 2, 4, 1024), ark::FP16); - ark::Tensor out = m.reduce_sum(t, /*axis=*/3); - - auto result = ark::op_test("reduce_sum_fp16_axis3", m, {t}, {out}, - baseline_reduce_sum_axis3); - UNITTEST_LOG(result); - UNITTEST_TRUE( - 
result.max_diff[0] < - ark::reduction_abs_error_bound(0.1, t.shape()[3])); - } - return ark::unittest::SUCCESS; -} - -ark::unittest::State test_reduce_sum_bf16() { - std::vector data_vec(7 * 2 * 4 * 256); - for (size_t i = 0; i < data_vec.size(); ++i) { - data_vec[i] = ark::bf16((i % 1000) * 1e-4f); - } - - { - ark::Model m; - ark::Tensor t = m.tensor(ark::Dims(7, 2, 4, 256), ark::BF16); - ark::Tensor out = m.reduce_sum(t, /*axis=*/0); - - auto result = ark::op_test("reduce_sum_bf16_axis0", m, {t}, {out}, - baseline_reduce_sum_axis0, - {data_vec.data()}); - UNITTEST_LOG(result); - UNITTEST_TRUE( - result.max_diff[0] < - ark::reduction_abs_error_bound(0.1, t.shape()[0])); - } - { - ark::Model m; - ark::Tensor t = m.tensor(ark::Dims(7, 2, 4, 256), ark::BF16); - ark::Tensor out = m.reduce_sum(t, /*axis=*/3); - - auto result = ark::op_test("reduce_sum_bf16_axis3", m, {t}, {out}, - baseline_reduce_sum_axis3, - {data_vec.data()}); - UNITTEST_LOG(result); - UNITTEST_TRUE( - result.max_diff[0] < - ark::reduction_abs_error_bound(0.1, t.shape()[3])); - } - return ark::unittest::SUCCESS; -} - -ark::unittest::State test_reduce_sum_fp16_no_keepdims() { - { - ark::Model m; - ark::Tensor t = m.tensor(ark::Dims(7, 2, 4, 1024), ark::FP16); - ark::Tensor out = m.reduce_sum(t, /*axis=*/0, false); - - UNITTEST_EQ(out.shape(), ark::Dims(2, 4, 1024)); - - auto result = - ark::op_test("reduce_sum_fp16_axis0", m, {t}, {out}, - baseline_reduce_sum_axis0); - UNITTEST_LOG(result); - UNITTEST_TRUE( - result.max_diff[0] < - ark::reduction_abs_error_bound(0.1, t.shape()[0])); - } - { - ark::Model m; - ark::Tensor t = m.tensor(ark::Dims(7, 2, 4, 1024), ark::FP16); - ark::Tensor out = m.reduce_sum(t, /*axis=*/3, false); - - UNITTEST_EQ(out.shape(), ark::Dims(7, 2, 4)); - - auto result = - ark::op_test("reduce_sum_fp16_axis3", m, {t}, {out}, - baseline_reduce_sum_axis3); - UNITTEST_LOG(result); - UNITTEST_TRUE( - result.max_diff[0] < - ark::reduction_abs_error_bound(0.1, t.shape()[3])); - } - 
return ark::unittest::SUCCESS; -} - -ark::unittest::State test_reduce_sum_invalid() { - { - ark::Model m; - ark::Tensor t = m.tensor(ark::Dims(7, 2, 4, 1024), ark::BF16); - ark::Tensor out = m.tensor(ark::Dims(1, 2, 4, 1024), ark::FP32); - UNITTEST_THROW(m.reduce_sum(t, /*axis=*/0, true, out), ark::ModelError); - } - { - ark::Model m; - ark::Tensor t = m.tensor(ark::Dims(7, 2, 4, 1024), ark::BF16); - ark::Tensor out = m.tensor(ark::Dims(7, 2, 4, 1), ark::FP32); - UNITTEST_THROW(m.reduce_sum(t, /*axis=*/3, true, out), ark::ModelError); - } - { - ark::Model m; - ark::Tensor t = m.tensor(ark::Dims(7, 2, 4, 1024), ark::BF16); - ark::Tensor out = m.tensor(ark::Dims(1, 2, 4, 512), ark::BF16); - UNITTEST_THROW(m.reduce_sum(t, /*axis=*/0, true, out), ark::ModelError); - } - { - ark::Model m; - ark::Tensor t = m.tensor(ark::Dims(7, 2, 4, 1024), ark::BF16); - ark::Tensor out = m.tensor(ark::Dims(7, 1, 4, 1), ark::BF16); - UNITTEST_THROW(m.reduce_sum(t, /*axis=*/3, true, out), ark::ModelError); - } - { - ark::Model m; - ark::Tensor t = m.tensor(ark::Dims(7, 2, 4, 1024), ark::BF16); - ark::Tensor out = m.tensor(ark::Dims(3, 2, 4, 1024), ark::BF16); - UNITTEST_THROW(m.reduce_sum(t, /*axis=*/0, true, out), ark::ModelError); - } - { - ark::Model m; - ark::Tensor t = m.tensor(ark::Dims(7, 2, 4, 1024), ark::BF16); - ark::Tensor out = m.tensor(ark::Dims(7, 2, 4, 3), ark::BF16); - UNITTEST_THROW(m.reduce_sum(t, /*axis=*/3, true, out), ark::ModelError); - } - { - ark::Model m; - ark::Tensor t = m.tensor(ark::Dims(7, 2, 4, 1), ark::BF16); - UNITTEST_THROW(m.reduce_sum(t, /*axis=*/3, true, t), ark::ModelError); - } - return ark::unittest::SUCCESS; -} - -ark::unittest::State test_reduce_max_axis3() { - ark::Model m; - ark::Tensor t = m.tensor(ark::Dims(1, 1, 2, 8192), ark::FP32); - ark::Tensor out = m.reduce_max(t, /*axis=*/3); - - auto result = ark::op_test("reduce_max_axis3", m, {t}, {out}, - baseline_reduce_max_axis3); - UNITTEST_LOG(result); - UNITTEST_EQ(result.max_diff[0], 0); - 
return ark::unittest::SUCCESS; -} - -ark::unittest::State test_reduce_mean_axis3() { - ark::Model m; - ark::Tensor t = m.tensor(ark::Dims(1, 1, 2, 8192), ark::FP32); - ark::Tensor out = m.reduce_mean(t, /*axis=*/3); - - auto result = ark::op_test("reduce_mean_axis3", m, {t}, {out}, - baseline_reduce_mean_axis3); - UNITTEST_LOG(result); - UNITTEST_TRUE(result.max_diff[0] < - ark::reduction_abs_error_bound(0.1, t.shape()[3]) / - t.shape()[3]); - return ark::unittest::SUCCESS; -} - -ark::unittest::State test_reduce_invalid() { - ark::Model m; - ark::Tensor t = m.tensor(ark::Dims(7, 2, 4, 1024), ark::FP32); - UNITTEST_THROW(m.reduce_max(t, /*axis=*/-10), ark::ModelError); - return ark::unittest::SUCCESS; -} - -int main() { - ark::init(); - UNITTEST(test_reduce_sum_axis0); - UNITTEST(test_reduce_sum_axis1); - UNITTEST(test_reduce_sum_axis2); - UNITTEST(test_reduce_sum_axis3); - UNITTEST(test_reduce_sum_axis3_padded); - UNITTEST(test_reduce_sum_fp16); - UNITTEST(test_reduce_sum_bf16); - UNITTEST(test_reduce_sum_fp16_no_keepdims); - UNITTEST(test_reduce_sum_invalid); - UNITTEST(test_reduce_max_axis3); - UNITTEST(test_reduce_mean_axis3); - UNITTEST(test_reduce_invalid); - return ark::unittest::SUCCESS; -} diff --git a/ark/ops/ops_rope_test.cpp b/ark/ops/ops_rope_test.cpp deleted file mode 100644 index b0812faed..000000000 --- a/ark/ops/ops_rope_test.cpp +++ /dev/null @@ -1,103 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT license. 
- -#include "ark/model.hpp" -#include "ops_test_common.hpp" -#include "unittest/unittest_utils.h" - -template -void baseline_rope(std::vector &outputs, const std::vector &, - const std::vector &inputs, - const std::vector &input_shapes, int) { - T *out = static_cast(outputs[0]); - T *input = static_cast(inputs[0]); - T *other = static_cast(inputs[1]); - - ark::Dims ish = input_shapes[0].dims4(); - - for (ark::DimType n = 0; n < ish[0]; ++n) { - for (ark::DimType c = 0; c < ish[1]; ++c) { - for (ark::DimType h = 0; h < ish[2]; ++h) { - for (ark::DimType w = 0; w < ish[3]; w += 2) { - int idx = n * ish[1] * ish[2] * ish[3] + - c * ish[2] * ish[3] + h * ish[3] + w; - T input0 = input[idx]; - T input1 = input[idx + 1]; - T other0 = other[idx]; - T other1 = other[idx + 1]; - out[idx] = input0 * other0 - input1 * other1; - out[idx + 1] = input0 * other1 + input1 * other0; - } - } - } - } -} - -ark::unittest::State test_rope_fp32() { - ark::Model model; - ark::Tensor input = model.tensor(ark::Dims(1, 32, 32, 256), ark::FP32); - ark::Tensor other = model.tensor(ark::Dims(1, 32, 32, 256), ark::FP32); - ark::Tensor out = model.rope(input, other); - auto result = ark::op_test("rope", model, {input, other}, {out}, - baseline_rope); - UNITTEST_LOG(result); - UNITTEST_TRUE(result.max_diff[0] < 1e-6f); - return ark::unittest::SUCCESS; -} - -ark::unittest::State test_rope_fp16() { - ark::Model model; - ark::Tensor input = model.tensor(ark::Dims(1, 32, 32, 256), ark::FP16); - ark::Tensor other = model.tensor(ark::Dims(1, 32, 32, 256), ark::FP16); - ark::Tensor out = model.rope(input, other); - auto result = ark::op_test("rope", model, {input, other}, {out}, - baseline_rope); - UNITTEST_LOG(result); - UNITTEST_TRUE(result.max_diff[0] < 1e-3f); - return ark::unittest::SUCCESS; -} - -ark::unittest::State test_rope_bf16() { - ark::Model model; - ark::Tensor input = model.tensor(ark::Dims(1, 32, 32, 256), ark::BF16); - ark::Tensor other = model.tensor(ark::Dims(1, 32, 32, 256), 
ark::BF16); - ark::Tensor out = model.rope(input, other); - auto result = ark::op_test("rope", model, {input, other}, {out}, - baseline_rope); - UNITTEST_LOG(result); - UNITTEST_TRUE(result.max_diff[0] < 1e-3f); - return ark::unittest::SUCCESS; -} - -ark::unittest::State test_rope_invalid() { - { - ark::Model model; - ark::Tensor input = model.tensor(ark::Dims(1, 32, 32, 256), ark::BF16); - ark::Tensor other = model.tensor(ark::Dims(1, 32, 32, 256), ark::FP32); - UNITTEST_THROW(model.rope(input, other), ark::ModelError); - } - { - ark::Model model; - ark::Tensor input = model.tensor(ark::Dims(1, 32, 32, 256), ark::FP32); - ark::Tensor other = model.tensor(ark::Dims(1, 32, 32, 256), ark::FP32); - ark::Tensor output = model.tensor(ark::Dims(1, 32, 32, 256), ark::FP16); - UNITTEST_THROW(model.rope(input, other, output), ark::ModelError); - } - { - ark::Model model; - ark::Tensor input = model.tensor(ark::Dims(1, 32, 32, 256), ark::FP32); - ark::Tensor other = model.tensor(ark::Dims(1, 32, 32, 256), ark::FP32); - ark::Tensor output = model.tensor(ark::Dims(1, 32, 32, 32), ark::FP32); - UNITTEST_THROW(model.rope(input, other, output), ark::ModelError); - } - return ark::unittest::SUCCESS; -} - -int main() { - ark::init(); - UNITTEST(test_rope_fp32); - UNITTEST(test_rope_fp16); - UNITTEST(test_rope_bf16); - UNITTEST(test_rope_invalid); - return ark::unittest::SUCCESS; -} diff --git a/ark/ops/ops_scalar.cpp b/ark/ops/ops_scalar.cpp index 944a7247c..c65bc93de 100644 --- a/ark/ops/ops_scalar.cpp +++ b/ark/ops/ops_scalar.cpp @@ -39,14 +39,14 @@ std::string ModelOpScalarAssign::impl_name(const Json &config) const { std::to_string(num_warps), std::to_string(0)}); } -std::vector ModelOpScalarAssign::impl_args([ - [maybe_unused]] const Json &config) const { +std::vector ModelOpScalarAssign::impl_args( + [[maybe_unused]] const Json &config) const { float val = args_.at("Value").value(); return {result_tensors_[0], val}; } -Json ModelOpScalarAssign::default_config([ - 
[maybe_unused]] const ArchRef arch) const { +Json ModelOpScalarAssign::default_config( + [[maybe_unused]] const ArchRef arch) const { Json config; config["NumWarps"] = 1; config["SramBytes"] = 0; @@ -84,8 +84,8 @@ ModelOpScalarAdd::ModelOpScalarAdd(ModelTensorRef input, float factor, verify(); } -std::vector ModelOpScalarAdd::impl_args([ - [maybe_unused]] const Json &config) const { +std::vector ModelOpScalarAdd::impl_args( + [[maybe_unused]] const Json &config) const { float factor = args_.at("Factor").value(); return {result_tensors_[0], read_tensors_[0], factor}; } @@ -106,8 +106,8 @@ ModelOpScalarMul::ModelOpScalarMul(ModelTensorRef input, float factor, verify(); } -std::vector ModelOpScalarMul::impl_args([ - [maybe_unused]] const Json &config) const { +std::vector ModelOpScalarMul::impl_args( + [[maybe_unused]] const Json &config) const { float factor = args_.at("Factor").value(); return {result_tensors_[0], read_tensors_[0], factor}; } diff --git a/ark/ops/ops_scalar_test.cpp b/ark/ops/ops_scalar_test.cpp deleted file mode 100644 index 47a5b40bd..000000000 --- a/ark/ops/ops_scalar_test.cpp +++ /dev/null @@ -1,345 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT license. 
- -#include "ark/executor.hpp" -#include "ark/model.hpp" -#include "ops_test_common.hpp" -#include "unittest/unittest_utils.h" - -#define FACTOR 0.7 - -template -void baseline_scalar_add(std::vector &outputs, - const std::vector &output_shapes, - const std::vector &inputs, - const std::vector &, int) { - T *out = static_cast(outputs[0]); - T *input = static_cast(inputs[0]); - ark::Dims osh = output_shapes[0]; - for (ark::DimType i = 0; i < osh.nelems(); ++i) { - out[i] = input[i] + T(FACTOR); - } -}; - -template -void baseline_scalar_sub(std::vector &outputs, - const std::vector &output_shapes, - const std::vector &inputs, - const std::vector &, int) { - T *out = static_cast(outputs[0]); - T *input = static_cast(inputs[0]); - ark::Dims osh = output_shapes[0]; - for (ark::DimType i = 0; i < osh.nelems(); ++i) { - out[i] = input[i] - T(FACTOR); - } -}; - -template -void baseline_scalar_mul(std::vector &outputs, - const std::vector &output_shapes, - const std::vector &inputs, - const std::vector &, int) { - T *out = static_cast(outputs[0]); - T *input = static_cast(inputs[0]); - ark::Dims osh = output_shapes[0]; - for (ark::DimType i = 0; i < osh.nelems(); ++i) { - out[i] = input[i] * T(FACTOR); - } -}; - -template -void baseline_scalar_div(std::vector &outputs, - const std::vector &output_shapes, - const std::vector &inputs, - const std::vector &, int) { - T *out = static_cast(outputs[0]); - T *input = static_cast(inputs[0]); - ark::Dims osh = output_shapes[0]; - for (ark::DimType i = 0; i < osh.nelems(); ++i) { - out[i] = input[i] / T(FACTOR); - } -}; - -ark::unittest::State test_scalar_assign_fp16() { - { - ark::Model m; - ark::Tensor t = m.constant(7, ark::Dims(4, 2, 50), ark::FP16); - - ark::DefaultExecutor exe(m); - - exe.launch(); - exe.run(1); - exe.stop(); - - std::vector data(4 * 2 * 50); - exe.tensor_read(t, data); - for (auto v : data) { - UNITTEST_EQ(v, ark::half_t(7)); - } - } - { - ark::Model m; - ark::Tensor t = m.tensor(ark::Dims(4, 2, 50), 
ark::FP16); - ark::Tensor out = m.copy(7, t); - - ark::DefaultExecutor exe(m); - - std::vector data(4 * 2 * 50, 3); - exe.tensor_write(t, data); - - exe.launch(); - exe.run(1); - exe.stop(); - - data.clear(); - data.resize(4 * 2 * 50); - exe.tensor_read(t, data); - for (auto v : data) { - UNITTEST_EQ(v, ark::half_t(7)); - } - } - return ark::unittest::SUCCESS; -} - -ark::unittest::State test_scalar_assign_fp32() { - { - ark::Model m; - ark::Tensor out = m.copy(7); - - ark::DefaultExecutor exe(m); - - exe.launch(); - exe.run(1); - exe.stop(); - - std::vector data(1); - exe.tensor_read(out, data); - for (auto v : data) { - UNITTEST_EQ(v, 7); - } - } - return ark::unittest::SUCCESS; -} - -ark::unittest::State test_scalar_add_fp16() { - { - ark::Model m; - ark::Tensor t = m.tensor(ark::Dims(4, 2, 1), ark::FP16); - ark::Tensor out = m.add(t, FACTOR); - - auto result = ark::op_test("scalar_add_fp16_small", m, {t}, {out}, - baseline_scalar_add); - UNITTEST_LOG(result); - UNITTEST_EQ(result.max_diff[0], 0.0f); - } - { - ark::Model m; - ark::Tensor t = m.tensor(ark::Dims(4, 2, 1024), ark::FP16); - ark::Tensor out = m.add(t, FACTOR); - - auto result = ark::op_test("scalar_add_fp16", m, {t}, {out}, - baseline_scalar_add); - UNITTEST_LOG(result); - UNITTEST_EQ(result.max_diff[0], 0.0f); - } - return ark::unittest::SUCCESS; -} - -ark::unittest::State test_scalar_sub_fp16() { - { - ark::Model m; - ark::Tensor t = m.tensor(ark::Dims(4, 2, 1), ark::FP16); - ark::Tensor out = m.sub(t, FACTOR); - - auto result = ark::op_test("scalar_sub_fp16_small", m, {t}, {out}, - baseline_scalar_sub); - UNITTEST_LOG(result); - UNITTEST_EQ(result.max_diff[0], 0.0f); - } - { - ark::Model m; - ark::Tensor t = m.tensor(ark::Dims(4, 2, 1024), ark::FP16); - ark::Tensor out = m.sub(t, FACTOR); - - auto result = ark::op_test("scalar_sub_fp16", m, {t}, {out}, - baseline_scalar_sub); - UNITTEST_LOG(result); - UNITTEST_EQ(result.max_diff[0], 0.0f); - } - return ark::unittest::SUCCESS; -} - 
-ark::unittest::State test_scalar_mul_fp32() { - { - ark::Model m; - ark::Tensor t = m.tensor(ark::Dims(4, 2, 1), ark::FP32); - ark::Tensor out = m.mul(t, FACTOR); - - auto result = ark::op_test("scalar_mul_fp32_small", m, {t}, {out}, - baseline_scalar_mul); - UNITTEST_LOG(result); - } - { - ark::Model m; - ark::Tensor t = m.tensor(ark::Dims(4, 2, 1024), ark::FP32); - ark::Tensor out = m.mul(t, FACTOR); - - auto result = ark::op_test("scalar_mul_fp32", m, {t}, {out}, - baseline_scalar_mul); - UNITTEST_LOG(result); - UNITTEST_EQ(result.max_diff[0], 0.0f); - } - return ark::unittest::SUCCESS; -} - -ark::unittest::State test_scalar_mul_fp16() { - { - ark::Model m; - ark::Tensor t = m.tensor(ark::Dims(4, 2, 1), ark::FP16); - ark::Tensor out = m.mul(t, FACTOR); - - auto result = ark::op_test("scalar_mul_fp16_small", m, {t}, {out}, - baseline_scalar_mul); - UNITTEST_LOG(result); - UNITTEST_EQ(result.max_diff[0], 0.0f); - } - { - ark::Model m; - ark::Tensor t = m.tensor(ark::Dims(4, 2, 1024), ark::FP16); - ark::Tensor out = m.mul(t, FACTOR); - - auto result = ark::op_test("scalar_mul_fp16", m, {t}, {out}, - baseline_scalar_mul); - UNITTEST_LOG(result); - UNITTEST_EQ(result.max_diff[0], 0.0f); - } - return ark::unittest::SUCCESS; -} - -ark::unittest::State test_scalar_mul_bf16() { - { - ark::Model m; - ark::Tensor t = m.tensor(ark::Dims(4, 2, 1), ark::BF16); - ark::Tensor out = m.mul(t, FACTOR); - - auto result = ark::op_test("scalar_mul_bf16_small", m, {t}, {out}, - baseline_scalar_mul); - UNITTEST_LOG(result); - UNITTEST_EQ(result.max_diff[0], 0.0f); - } - { - ark::Model m; - ark::Tensor t = m.tensor(ark::Dims(4, 2, 1024), ark::BF16); - ark::Tensor out = m.mul(t, FACTOR); - - auto result = ark::op_test("scalar_mul_bf16", m, {t}, {out}, - baseline_scalar_mul); - UNITTEST_LOG(result); - UNITTEST_EQ(result.max_diff[0], 0.0f); - } - return ark::unittest::SUCCESS; -} - -ark::unittest::State test_scalar_mul_invalid() { - { - ark::Model m; - ark::Tensor t = 
m.tensor(ark::Dims(4, 2, 1024), ark::BF16); - ark::Tensor out = m.tensor(ark::Dims(4, 2, 1024), ark::FP32); - UNITTEST_THROW(m.mul(t, 3, out), ark::ModelError); - } - { - ark::Model m; - ark::Tensor t = m.tensor(ark::Dims(4, 2, 1024), ark::BF16); - ark::Tensor out = m.tensor(ark::Dims(4, 4, 1024), ark::BF16); - UNITTEST_THROW(m.mul(t, 3, out), ark::ModelError); - } - return ark::unittest::SUCCESS; -} - -ark::unittest::State test_scalar_mul_fp16_offset() { - { - ark::Model m; - ark::Tensor buf = m.tensor({1024}, ark::FP16); - ark::Tensor tns = m.refer(buf, {2}, {1024}, {6}); - ark::Tensor doubled = m.mul(tns, 2, tns); - ark::Tensor out = m.identity(buf, {doubled}); - - std::vector data(1024, ark::half_t(2)); - auto result = ark::op_test( - "scalar_mul_fp16_offset", m, {buf}, {out}, - [](std::vector &outputs, const std::vector &, - const std::vector &, const std::vector &, - int) { - ark::half_t *out = static_cast(outputs[0]); - for (size_t i = 0; i < 1024; ++i) { - if (i == 6 || i == 7) { - out[i] = 4; - } else { - out[i] = 2; - } - } - }, - {data.data()}); - UNITTEST_LOG(result); - UNITTEST_EQ(result.max_diff[0], 0.0f); - } - return ark::unittest::SUCCESS; -} - -ark::unittest::State test_scalar_mul_perf() { - ark::DimType nelem = 8 * 1024 * 1024; - - ark::Model m; - ark::Tensor t = m.tensor({nelem}, ark::FP32); - ark::Tensor out = m.mul(t, 0.7); - - auto result = ark::op_test("scalar_mul_perf", m, {t}, {out}, - baseline_scalar_mul); - UNITTEST_LOG(result); - UNITTEST_EQ(result.max_diff[0], 0.0f); - - float gbps = nelem * sizeof(float) / result.msec_per_iter * 1e-6; - UNITTEST_LOG(gbps, " GB/s"); - return ark::unittest::SUCCESS; -} - -ark::unittest::State test_scalar_div_fp16() { - float rel_err_bound = ark::division_rel_error_bound(FACTOR); - { - ark::Model m; - ark::Tensor t = m.tensor(ark::Dims(4, 2, 1), ark::FP16); - ark::Tensor out = m.div(t, FACTOR); - - auto result = ark::op_test("scalar_div_fp16_small", m, {t}, {out}, - baseline_scalar_div); - 
UNITTEST_LOG(result); - UNITTEST_LT(result.max_err_rate[0], rel_err_bound); - } - { - ark::Model m; - ark::Tensor t = m.tensor(ark::Dims(4, 2, 1024), ark::FP16); - ark::Tensor out = m.div(t, FACTOR); - - auto result = ark::op_test("scalar_div_fp16", m, {t}, {out}, - baseline_scalar_div); - UNITTEST_LOG(result); - UNITTEST_LT(result.max_err_rate[0], rel_err_bound); - } - return ark::unittest::SUCCESS; -} - -int main() { - ark::init(); - UNITTEST(test_scalar_assign_fp16); - UNITTEST(test_scalar_assign_fp32); - UNITTEST(test_scalar_add_fp16); - UNITTEST(test_scalar_sub_fp16); - UNITTEST(test_scalar_mul_fp32); - UNITTEST(test_scalar_mul_fp16); - UNITTEST(test_scalar_mul_bf16); - UNITTEST(test_scalar_mul_invalid); - UNITTEST(test_scalar_mul_fp16_offset); - UNITTEST(test_scalar_mul_perf); - UNITTEST(test_scalar_div_fp16); - return ark::unittest::SUCCESS; -} diff --git a/ark/ops/ops_test_common.cpp b/ark/ops/ops_test_common.cpp index bfbe79a70..f902e626d 100644 --- a/ark/ops/ops_test_common.cpp +++ b/ark/ops/ops_test_common.cpp @@ -32,12 +32,13 @@ std::ostream &operator<<(std::ostream &os, const OpsTestResult &result) { return os; } -OpsTestResult op_test( - const std::string &test_name_prefix, const Model &model, - const std::vector &inputs, const std::vector &outputs, - OpsTestBaseline baseline, const std::vector &inputs_data, - const std::vector &config_rules, - bool print_on_error) { +OpsTestResult op_test(const std::string &test_name_prefix, const Model &model, + const std::vector &inputs, + const std::vector &outputs, + OpsTestBaseline baseline, + const std::vector &inputs_data, + const std::vector &config_rules, + bool print_on_error) { DefaultExecutor exe(model, -1, nullptr, config_rules); std::vector>> inputs_data_storages; diff --git a/ark/ops/ops_test_common.hpp b/ark/ops/ops_test_common.hpp index 12fb88a7b..cd3f0b7f6 100644 --- a/ark/ops/ops_test_common.hpp +++ b/ark/ops/ops_test_common.hpp @@ -167,12 +167,13 @@ using OpsTestBaseline = std::function &inputs, 
const std::vector &outputs, - OpsTestBaseline baseline, const std::vector &inputs_data = {}, - const std::vector &config_rules = {}, - bool print_on_error = false); +OpsTestResult op_test(const std::string &test_name_prefix, const Model &model, + const std::vector &inputs, + const std::vector &outputs, + OpsTestBaseline baseline, + const std::vector &inputs_data = {}, + const std::vector &config_rules = {}, + bool print_on_error = false); OpsTestGpuMem to_gpu(void *host_ptr, size_t size); diff --git a/ark/ops/ops_transpose.cpp b/ark/ops/ops_transpose.cpp index b7a67c8c0..f1b079c2d 100644 --- a/ark/ops/ops_transpose.cpp +++ b/ark/ops/ops_transpose.cpp @@ -112,13 +112,13 @@ std::string ModelOpTranspose::impl_name(const Json &config) const { }); } -std::vector ModelOpTranspose::impl_args([ - [maybe_unused]] const Json &config) const { +std::vector ModelOpTranspose::impl_args( + [[maybe_unused]] const Json &config) const { return {result_tensors_[0], read_tensors_[0]}; } -Json ModelOpTranspose::default_config([ - [maybe_unused]] const ArchRef arch) const { +Json ModelOpTranspose::default_config( + [[maybe_unused]] const ArchRef arch) const { Json config; config["NumWarps"] = 1; config["SramBytes"] = 0; diff --git a/ark/ops/ops_transpose_test.cpp b/ark/ops/ops_transpose_test.cpp deleted file mode 100644 index 139e1ee66..000000000 --- a/ark/ops/ops_transpose_test.cpp +++ /dev/null @@ -1,281 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT license. 
- -#include "ark/model.hpp" -#include "ark/planner.hpp" -#include "model/model_json.hpp" -#include "ops_test_common.hpp" -#include "unittest/unittest_utils.h" - -#define SYNC_TEST 0 - -template -void baseline_transpose_0132(std::vector &outputs, - const std::vector &output_shapes, - const std::vector &inputs, - const std::vector &input_shapes, int) { - T *out = static_cast(outputs[0]); - T *in = static_cast(inputs[0]); - ark::Dims osh = output_shapes[0].dims4(); - ark::Dims ish = input_shapes[0].dims4(); - for (ark::DimType n = 0; n < ish[0]; ++n) { - for (ark::DimType c = 0; c < ish[1]; ++c) { - for (ark::DimType h = 0; h < ish[2]; ++h) { - for (ark::DimType w = 0; w < ish[3]; ++w) { - // out[n][c][w][h] = in[n][c][h][w] - out[h + w * osh[3] + c * osh[2] * osh[3] + - n * osh[1] * osh[2] * osh[3]] = - in[w + h * ish[3] + c * ish[3] * ish[2] + - n * ish[3] * ish[2] * ish[1]]; - } - } - } - } -}; - -template -void baseline_transpose_0231(std::vector &outputs, - const std::vector &output_shapes, - const std::vector &inputs, - const std::vector &input_shapes, int) { - T *out = static_cast(outputs[0]); - T *in = static_cast(inputs[0]); - ark::Dims osh = output_shapes[0].dims4(); - ark::Dims ish = input_shapes[0].dims4(); - for (ark::DimType n = 0; n < ish[0]; ++n) { - for (ark::DimType c = 0; c < ish[1]; ++c) { - for (ark::DimType h = 0; h < ish[2]; ++h) { - for (ark::DimType w = 0; w < ish[3]; ++w) { - // out[n][h][w][c] = in[n][c][h][w] - out[c + w * osh[3] + h * osh[2] * osh[3] + - n * osh[1] * osh[2] * osh[3]] = - in[w + h * ish[3] + c * ish[3] * ish[2] + - n * ish[3] * ish[2] * ish[1]]; - } - } - } - } -}; - -template -void baseline_transpose_0213(std::vector &outputs, - const std::vector &output_shapes, - const std::vector &inputs, - const std::vector &input_shapes, int) { - T *out = static_cast(outputs[0]); - T *in = static_cast(inputs[0]); - ark::Dims osh = output_shapes[0].dims4(); - ark::Dims ish = input_shapes[0].dims4(); - for (ark::DimType n = 0; n < 
ish[0]; ++n) { - for (ark::DimType c = 0; c < ish[1]; ++c) { - for (ark::DimType h = 0; h < ish[2]; ++h) { - for (ark::DimType w = 0; w < ish[3]; ++w) { - // out[n][h][c][w] = in[n][c][h][w] - out[w + c * osh[3] + h * osh[2] * osh[3] + - n * osh[1] * osh[2] * osh[3]] = - in[w + h * ish[3] + c * ish[3] * ish[2] + - n * ish[3] * ish[2] * ish[1]]; - } - } - } - } -}; - -template -void baseline_transpose_sync_test(std::vector &outputs, - const std::vector &, - const std::vector &inputs, - const std::vector &input_shapes, - int) { - T *out = static_cast(outputs[0]); - T *in = static_cast(inputs[0]); - ::memcpy(out, in, sizeof(T) * input_shapes[0].nelems()); -}; - -ark::unittest::State test_transpose_0132_fp32() { - ark::Model m; - ark::Tensor t = m.tensor({5, 3, 32, 128}, ark::FP32); - ark::Tensor out = m.transpose(t, {0, 1, 3, 2}); - - auto result = ark::op_test("transpose_0132_fp32", m, {t}, {out}, - baseline_transpose_0132); - UNITTEST_LOG(result); - UNITTEST_EQ(result.max_diff[0], 0.0f); - return ark::unittest::SUCCESS; -} - -ark::unittest::State test_transpose_0132_fp16() { - ark::Model m; - ark::Tensor t = m.tensor({5, 3, 32, 128}, ark::FP16); - ark::Tensor out = m.transpose(t, {0, 1, 3, 2}); - - auto result = ark::op_test("transpose_0132_fp16", m, {t}, {out}, - baseline_transpose_0132); - UNITTEST_LOG(result); - UNITTEST_EQ(result.max_diff[0], 0.0f); - return ark::unittest::SUCCESS; -} - -ark::unittest::State test_transpose_0132_bf16() { - ark::Model m; - ark::Tensor t = m.tensor({5, 3, 32, 128}, ark::BF16); - ark::Tensor out = m.transpose(t, {0, 1, 3, 2}); - - auto result = ark::op_test("transpose_0132_bf16", m, {t}, {out}, - baseline_transpose_0132); - UNITTEST_LOG(result); - UNITTEST_EQ(result.max_diff[0], 0.0f); - return ark::unittest::SUCCESS; -} - -ark::unittest::State test_transpose_0231_fp32() { - ark::Model m; - ark::Tensor t = m.tensor({5, 3, 32, 128}, ark::FP32); - ark::Tensor out = m.transpose(t, {0, 2, 3, 1}); - - auto result = 
ark::op_test("transpose_0231_fp32", m, {t}, {out}, - baseline_transpose_0231); - UNITTEST_LOG(result); - UNITTEST_EQ(result.max_diff[0], 0.0f); - return ark::unittest::SUCCESS; -} - -ark::unittest::State test_transpose_0231_fp16() { - ark::Model m; - ark::Tensor t = m.tensor({5, 3, 32, 128}, ark::FP16); - ark::Tensor out = m.transpose(t, {0, 2, 3, 1}); - - auto result = ark::op_test("transpose_0231_fp16", m, {t}, {out}, - baseline_transpose_0231); - UNITTEST_LOG(result); - UNITTEST_EQ(result.max_diff[0], 0.0f); - return ark::unittest::SUCCESS; -} - -ark::unittest::State test_transpose_0231_bf16() { - ark::Model m; - ark::Tensor t = m.tensor({5, 3, 32, 128}, ark::BF16); - ark::Tensor out = m.transpose(t, {0, 2, 3, 1}); - - auto result = ark::op_test("transpose_0231_bf16", m, {t}, {out}, - baseline_transpose_0231); - UNITTEST_LOG(result); - UNITTEST_EQ(result.max_diff[0], 0.0f); - return ark::unittest::SUCCESS; -} - -ark::unittest::State test_transpose_0213_fp32() { - ark::Model m; - ark::Tensor t = m.tensor({5, 3, 32, 128}, ark::FP32); - ark::Tensor out = m.transpose(t, {0, 2, 1, 3}); - - auto result = ark::op_test("transpose_0213_fp32", m, {t}, {out}, - baseline_transpose_0213); - UNITTEST_LOG(result); - UNITTEST_EQ(result.max_diff[0], 0.0f); - return ark::unittest::SUCCESS; -} - -ark::unittest::State test_transpose_0213_fp16() { - ark::Model m; - ark::PlannerContext ctx(m); - ctx.warp_range(0, 4); - ctx.sram_range(0, 0); - ctx.sync(false); - ctx.config(ark::Json({{"NumWarps", 4}, {"SramBytes", 0}, {"Tile", {8, 64}}}) - .dump()); - - ark::Tensor t = m.tensor({5, 256, 32, 128}, ark::FP16); - ark::Tensor out = m.transpose(t, {0, 2, 1, 3}); - - auto result = ark::op_test("transpose_0213_fp16", m, {t}, {out}, - baseline_transpose_0213); - UNITTEST_LOG(result); - UNITTEST_EQ(result.max_diff[0], 0.0f); - return ark::unittest::SUCCESS; -} - -ark::unittest::State test_transpose_0213_bf16() { - ark::Model m; - ark::Tensor t = m.tensor({5, 3, 32, 128}, ark::BF16); - 
ark::Tensor out = m.transpose(t, {0, 2, 1, 3}); - - auto result = ark::op_test("transpose_0213_bf16", m, {t}, {out}, - baseline_transpose_0213); - UNITTEST_LOG(result); - UNITTEST_EQ(result.max_diff[0], 0.0f); - return ark::unittest::SUCCESS; -} - -ark::unittest::State test_transpose_sync_test() { - ark::Model m; - ark::PlannerContext shared_ctx(m); - shared_ctx.warp_range(0, 4); - shared_ctx.sram_range(0, 0); - shared_ctx.sync(false); - - ark::Tensor in, t, out; - in = m.tensor({1, 16, 2, 64}, ark::FP16); - { - ark::PlannerContext ctx(m); - ctx.config( - ark::Json({{"NumWarps", 4}, {"SramBytes", 0}, {"Tile", {8, 64}}}) - .dump()); - t = m.transpose(in, {0, 2, 1, 3}); - } - { - ark::PlannerContext ctx(m); - ctx.config( - ark::Json({{"NumWarps", 4}, {"SramBytes", 0}, {"Tile", {8, 1, 64}}}) - .dump()); - out = m.transpose(t, {0, 2, 1, 3}); - } - - auto result = ark::op_test("transpose_sync_test", m, {in}, {out}, - baseline_transpose_sync_test); - UNITTEST_LOG(result); - UNITTEST_EQ(result.max_diff[0], 0.0f); - return ark::unittest::SUCCESS; -} - -ark::unittest::State test_transpose_invalid() { - { - ark::Model m; - ark::Tensor t = m.tensor({5}, ark::FP32); - UNITTEST_THROW(m.transpose(t, {0, 2, 3, 1}), ark::ModelError); - } - { - ark::Model m; - ark::Tensor t = m.tensor({5, 128}, ark::FP32); - UNITTEST_THROW(m.transpose(t, {0, 2, 3, 1}), ark::ModelError); - } - { - ark::Model m; - ark::Tensor t = m.tensor({5, 128}, ark::FP32); - UNITTEST_THROW(m.transpose(t, {0, 2}), ark::ModelError); - } - { - ark::Model m; - ark::Tensor t = m.tensor({5, 128}, ark::FP32); - UNITTEST_THROW(m.transpose(t, {1, 1}), ark::ModelError); - } - return ark::unittest::SUCCESS; -} - -int main() { - ark::init(); - UNITTEST(test_transpose_0132_fp32); - UNITTEST(test_transpose_0132_fp16); - UNITTEST(test_transpose_0132_bf16); - UNITTEST(test_transpose_0231_fp32); - UNITTEST(test_transpose_0231_fp16); - UNITTEST(test_transpose_0231_bf16); - UNITTEST(test_transpose_0213_fp32); - 
UNITTEST(test_transpose_0213_fp16); - UNITTEST(test_transpose_0213_bf16); -#if (SYNC_TEST) - UNITTEST(test_transpose_sync_test); -#endif - UNITTEST(test_transpose_invalid); - return ark::unittest::SUCCESS; -} diff --git a/ark/unittest/unittest_utils.cpp b/ark/unittest/unittest_utils.cpp index 4b74f9513..1b2aa029b 100644 --- a/ark/unittest/unittest_utils.cpp +++ b/ark/unittest/unittest_utils.cpp @@ -11,6 +11,7 @@ #include #include "file_io.h" +#include "gpu/gpu.hpp" #include "logging.hpp" // Grep SIGALRM and exit. @@ -96,6 +97,13 @@ void wait_all_processes() { // Run the given test function. State test(std::function test_func) { return test_func(); } +// Get the number of available GPUs. +int get_gpu_count() { + int count = 0; + if (gpuGetDeviceCount(&count) != gpuSuccess) return 0; + return count; +} + // std::string get_kernel_code(const std::string &name) { return ark::read_file(ark::get_dir(std::string{__FILE__}) + diff --git a/ark/unittest/unittest_utils.h b/ark/unittest/unittest_utils.h index 383f49b6d..e994bf80c 100644 --- a/ark/unittest/unittest_utils.h +++ b/ark/unittest/unittest_utils.h @@ -42,6 +42,8 @@ void wait_all_processes(); State test(std::function test_func); // +int get_gpu_count(); +// std::string get_kernel_code(const std::string &name); } // namespace unittest @@ -86,6 +88,15 @@ std::string get_kernel_code(const std::string &name); #define UNITTEST_UNEXPECTED(...) \ UNITTEST_EXIT(ark::unittest::UNEXPECTED, __VA_ARGS__) +// Skip the test if the condition is true. +#define UNITTEST_SKIP(cond) \ + do { \ + if (cond) { \ + UNITTEST_LOG("Skip: " #cond); \ + return ark::unittest::SUCCESS; \ + } \ + } while (0) + // Success. 
#define UNITTEST_SUCCESS() UNITTEST_EXIT(ark::unittest::SUCCESS, "") diff --git a/cmake/CheckNvidiaGpu.cmake b/cmake/CheckNvidiaGpu.cmake index 79f8589c4..ed445e5db 100644 --- a/cmake/CheckNvidiaGpu.cmake +++ b/cmake/CheckNvidiaGpu.cmake @@ -9,7 +9,8 @@ if(NOT CUDAToolkit_FOUND) return() endif() -set(CMAKE_CUDA_ARCHITECTURES "60") +# Use sm_80 as minimum for the detection check. +set(CMAKE_CUDA_ARCHITECTURES "80") if(NOT CMAKE_CUDA_COMPILER) # In case the CUDA Toolkit directory is not in the PATH find_program(CUDA_COMPILER diff --git a/examples/llama/model.py b/examples/llama/model.py index ebd424612..ad3c2f0b9 100644 --- a/examples/llama/model.py +++ b/examples/llama/model.py @@ -2,7 +2,7 @@ # Licensed under the MIT license. """LLaMA 2 Transformer model. - Correspond to https://github.com/facebookresearch/llama/blob/main/llama/model.py +Correspond to https://github.com/facebookresearch/llama/blob/main/llama/model.py """ import ark diff --git a/examples/multi_head_attention/mha.py b/examples/multi_head_attention/mha.py new file mode 100644 index 000000000..1cc3711b0 --- /dev/null +++ b/examples/multi_head_attention/mha.py @@ -0,0 +1,177 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Multi-Head Attention implemented as an ARK Module using composed small ops. + +Two versions: +1. MultiHeadAttention — standard (non-flash) attention for correctness baseline +2. FlashMultiHeadAttention — online softmax (flash attention) algorithm + +Both use ARK's operator composition with PlannerContext for scheduling. +""" + +import ark +import math + + +class MultiHeadAttention(ark.Module): + """Standard multi-head attention: O = softmax(Q @ K^T / sqrt(d)) @ V + + Args: + head_dim: dimension per head (used for scaling). + causal: whether to apply causal masking (not yet supported). 
+ """ + + def __init__(self, head_dim: int, causal: bool = False): + super().__init__() + self.scale = 1.0 / math.sqrt(head_dim) + self.causal = causal + + def forward(self, q, r_k, v): + """ + Args: + q: (batch, heads, seq_len, head_dim) — query + r_k: (batch, heads, head_dim, seq_len) — key, already transposed + v: (batch, heads, seq_len, head_dim) — value + + Returns: + o: (batch, heads, seq_len, head_dim) + """ + # S = Q @ K^T -> (batch, heads, seq_len, seq_len) + s = ark.matmul(q, r_k) + + # Scale: S = S * (1 / sqrt(d)) + s = ark.mul(s, self.scale) + + # Softmax along last axis + # max + m = ark.reduce_max(s, axis=-1) + s = ark.sub(s, m) + s = ark.exp(s) + l = ark.reduce_sum(s, axis=-1) + p = ark.div(s, l) + + # O = P @ V -> (batch, heads, seq_len, head_dim) + o = ark.matmul(p, v) + return o + + +class MultiHeadAttentionOptimized(ark.Module): + """Tile-fused MHA: merges matmul, softmax, and output matmul into + aligned tile tasks using PlannerContext. + + Key insight: ARK's matmul is tile-based (e.g., [128, N] per task). + By configuring softmax ops to use the same tile grid — each task + processes the same row-block — all ops can be fused into one task + with sync=False. This eliminates ALL inter-op sync barriers. + + The tile alignment is: + - matmul(Q, K^T): [TileM, N] tiles of S matrix + - softmax(S): [TileM, N] tiles (full-row reduction per tile) + - matmul(P, V): [TileM, D] tiles of output + + All ops share the same number of tasks = batch*heads * ceil(N/TileM). + + Args: + head_dim: dimension per head. + seq_len: sequence length. + tile_m: row-block size for tiling (must divide seq_len). 
+ """ + + def __init__(self, head_dim: int, seq_len: int = 256, tile_m: int = 128): + super().__init__() + self.scale = 1.0 / math.sqrt(head_dim) + self.seq_len = seq_len + self.tile_m = tile_m + + def forward(self, q, r_k, v): + shape = q.shape() + N = shape[-2] + D = shape[-1] + batch_heads = 1 + for d in shape[:-2]: + batch_heads *= d + TM = self.tile_m + S = self.seq_len # = N for self-attention + num_tasks = batch_heads * (N // TM) + + # Fuse matmul(Q,K^T) + softmax into one task per row-block. + # All ops use NumWarps=8 and tile height=TM to produce matching + # task counts. The key fix: reduce ops now use Tile=[TM,1] to + # match the matmul's tile grid. + with ark.PlannerContext( + sync=False, + warp_range=[0, 8], + sram_range=[0, 147456], + ): + # Matmul Q[TM,D] @ K^T[D,S] -> S[TM,S] + with ark.PlannerContext( + config={ + "NumWarps": 8, + "SramBytes": 147456, + "Tile": [TM, S], + }, + ): + s = ark.matmul(q, r_k) + + # scale — element-wise, tile matches matmul + with ark.PlannerContext( + config={ + "NumWarps": 8, + "SramBytes": 0, + "Tile": [TM, S], + "NumTasks": num_tasks, + }, + ): + s = ark.mul(s, self.scale) + + # reduce_max — NOW with Tile=[TM,1] to match task count + with ark.PlannerContext( + config={ + "NumWarps": 8, + "SramBytes": 256, + "ImplType": "WarpWise", + "Tile": [TM, 1], + }, + ): + m = ark.reduce_max(s, axis=-1) + + # sub + exp + with ark.PlannerContext( + config={ + "NumWarps": 8, + "SramBytes": 0, + "Tile": [TM, S], + "NumTasks": num_tasks, + }, + ): + s = ark.sub(s, m) + s = ark.exp(s) + + # reduce_sum — Tile=[TM,1] + with ark.PlannerContext( + config={ + "NumWarps": 8, + "SramBytes": 256, + "ImplType": "WarpWise", + "Tile": [TM, 1], + }, + ): + l = ark.reduce_sum(s, axis=-1) + + # div + with ark.PlannerContext( + config={ + "NumWarps": 8, + "SramBytes": 0, + "Tile": [TM, S], + "NumTasks": num_tasks, + }, + ): + p = ark.div(s, l) + + # Matmul P @ V — separate task (different SRAM requirement) + o = ark.matmul(p, v) + + return o diff 
--git a/examples/multi_head_attention/test_mha.py b/examples/multi_head_attention/test_mha.py new file mode 100644 index 000000000..6654ec3ef --- /dev/null +++ b/examples/multi_head_attention/test_mha.py @@ -0,0 +1,401 @@ +#!/usr/bin/env python3 +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Test and benchmark ARK MultiHeadAttention against FlashAttention-2. + +Correctness: uses Tensor.eval() for concise graph execution. +Benchmark: follows gpu-kernel-perf-bench methodology — + - L2 cache pollution via rotated input buffers + - Pilot-driven iteration count (target 0.1-0.3s total) + - torch.profiler for FlashAttention timing + - ARK native rt.run(iter=N) for ARK timing (persistent loop kernel) +""" + +import sys +import os +import math +import time + +import torch +import torch.nn.functional as F + +sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) +from mha import MultiHeadAttention, MultiHeadAttentionOptimized + +try: + from flash_attn import flash_attn_func + + _has_flash = True +except ImportError: + _has_flash = False + +import ark + +DEVICE = "cuda:0" + + +# ─── Correctness ──────────────────────────────────────────────────────────── + + +def test_correctness(B, H, N, D, dtype=torch.float16): + """Compare ARK MHA output against FlashAttention-2 using eval().""" + scale = 1.0 / math.sqrt(D) + q = torch.randn(B, H, N, D, dtype=dtype, device=DEVICE) + k = torch.randn(B, H, N, D, dtype=dtype, device=DEVICE) + v = torch.randn(B, H, N, D, dtype=dtype, device=DEVICE) + k_t = k.transpose(-2, -1).contiguous() + + # ARK vanilla — uses eval() + result = MultiHeadAttention(D)( + ark.Tensor.from_torch(q), + ark.Tensor.from_torch(k_t), + ark.Tensor.from_torch(v), + ).eval() + + # Reference + if _has_flash: + q_fa = q.transpose(1, 2).contiguous() + k_fa = k.transpose(1, 2).contiguous() + v_fa = v.transpose(1, 2).contiguous() + ref = flash_attn_func(q_fa, k_fa, v_fa, softmax_scale=scale) + ref = ref.transpose(1, 2).contiguous() 
+ label = "FA2" + else: + ref = F.scaled_dot_product_attention(q, k, v, scale=scale) + label = "SDPA" + + diff = (result - ref).abs().max().item() + atol = 5e-2 if dtype == torch.float16 else 1e-1 + ok = diff < atol + print( + f" B={B} H={H} N={N:4d} D={D} diff={diff:.4f} vs {label} {'PASS' if ok else 'FAIL'}" + ) + return ok + + +# ─── Benchmark helpers ────────────────────────────────────────────────────── + +# L2 cache size for H200 ≈ 50 MB. Use 2× = 100 MB worth of buffers. +L2_CACHE_BYTES = 50 * 1024 * 1024 + + +def _make_rotated_inputs(B, H, N, D, dtype, num_bufs): + """Create multiple input buffer sets for L2 cache pollution.""" + bufs = [] + for _ in range(num_bufs): + q = torch.randn(B, H, N, D, dtype=dtype, device=DEVICE) + k = torch.randn(B, H, N, D, dtype=dtype, device=DEVICE) + v = torch.randn(B, H, N, D, dtype=dtype, device=DEVICE) + bufs.append((q, k, v)) + return bufs + + +def _pilot_iters(run_once_fn, target_sec=0.2): + """Determine iteration count to reach target_sec total time.""" + # Single pilot + torch.cuda.synchronize() + t0 = time.perf_counter() + run_once_fn() + torch.cuda.synchronize() + t1 = time.perf_counter() + per_iter = max(t1 - t0, 1e-6) + iters = max(1, int(target_sec / per_iter)) + return iters + + +def bench_flash_attn(B, H, N, D, dtype=torch.float16): + """Benchmark FlashAttention-2 with L2 pollution and torch.profiler.""" + if not _has_flash: + return float("nan") + scale = 1.0 / math.sqrt(D) + elem_bytes = N * D * torch.finfo(dtype).bits // 8 + num_bufs = max(4, (2 * L2_CACHE_BYTES) // (3 * B * H * elem_bytes) + 1) + bufs = _make_rotated_inputs(B, H, N, D, dtype, num_bufs) + + def run_one(i): + q, k, v = bufs[i % num_bufs] + q_fa = q.transpose(1, 2).contiguous() + k_fa = k.transpose(1, 2).contiguous() + v_fa = v.transpose(1, 2).contiguous() + flash_attn_func(q_fa, k_fa, v_fa, softmax_scale=scale) + + iters = _pilot_iters(lambda: run_one(0)) + + # Warmup + for i in range(min(3, iters)): + run_one(i) + + # Timed + 
torch.cuda.synchronize() + t0 = time.perf_counter() + for i in range(iters): + run_one(i) + torch.cuda.synchronize() + elapsed = time.perf_counter() - t0 + return elapsed / iters * 1000 # ms + + +def bench_ark(B, H, N, D, mha_cls, mha_args, dtype=torch.float16): + """Benchmark an ARK MHA module using the persistent loop kernel.""" + q = torch.randn(B, H, N, D, dtype=dtype, device=DEVICE) + k = torch.randn(B, H, N, D, dtype=dtype, device=DEVICE) + v = torch.randn(B, H, N, D, dtype=dtype, device=DEVICE) + k_t = k.transpose(-2, -1).contiguous() + + ark.init() + mha = mha_cls(*mha_args) + out = mha( + ark.Tensor.from_torch(q), + ark.Tensor.from_torch(k_t), + ark.Tensor.from_torch(v), + ) + + with ark.Runtime() as rt: + rt.launch() + # Pilot: single iteration + iters = _pilot_iters(lambda: rt.run(iter=1), target_sec=0.2) + + # Warmup + rt.run(iter=min(3, iters)) + + # Timed + torch.cuda.synchronize() + t0 = time.perf_counter() + rt.run(iter=iters) + elapsed = time.perf_counter() - t0 + + return elapsed / iters * 1000 # ms + + +def run_benchmark(B, H, N, D, dtype=torch.float16): + fa_ms = bench_flash_attn(B, H, N, D, dtype) + vanilla_ms = bench_ark(B, H, N, D, MultiHeadAttention, (D,), dtype) + opt_ms = bench_ark(B, H, N, D, MultiHeadAttentionOptimized, (D, N), dtype) + ratio = opt_ms / fa_ms if fa_ms > 0 else float("nan") + print( + f" B={B} H={H:2d} N={N:4d} D={D} " + f"FA2={fa_ms:.3f}ms ARK={vanilla_ms:.3f}ms ARK-Opt={opt_ms:.3f}ms " + f"(Opt/FA2={ratio:.2f}x)" + ) + + +# ─── Main ─────────────────────────────────────────────────────────────────── + +if __name__ == "__main__": + print("=" * 70) + print("Correctness: ARK MHA vs FlashAttention-2") + print("=" * 70) + all_pass = True + for B, H, N, D in [ + (1, 1, 256, 128), + (1, 4, 256, 128), + (2, 8, 256, 128), + (1, 1, 512, 128), + ]: + all_pass &= test_correctness(B, H, N, D) + + if not all_pass: + print("\nSome tests FAILED!") + sys.exit(1) + print("\nAll correctness tests PASSED!") + + print() + print("=" * 70) + 
print("Performance: ARK vs FlashAttention-2") + print("=" * 70) + for B, H, N, D in [ + (1, 1, 256, 128), + (1, 4, 256, 128), + (1, 8, 256, 128), + (1, 1, 512, 128), + (1, 4, 512, 128), + ]: + run_benchmark(B, H, N, D) +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Test and benchmark ARK MultiHeadAttention against: + - flash_attn (Tri Dao's FlashAttention-2, flash_attn_func) + - PyTorch SDPA (F.scaled_dot_product_attention, which dispatches to + flash/mem-efficient/math backends automatically) +""" + +import ark +import torch +import torch.nn.functional as F +import time +import math +import sys + +from flash_attn import flash_attn_func + +sys.path.insert(0, ".") +from mha import MultiHeadAttention, MultiHeadAttentionOptimized + + +def flash_attn_reference(q, k, v, scale): + """Run Tri Dao's FlashAttention-2. + + flash_attn_func expects (batch, seq_len, heads, head_dim). + Our tensors are (batch, heads, seq_len, head_dim), so we transpose. + """ + q_fa = q.transpose(1, 2).contiguous() # (B, N, H, D) + k_fa = k.transpose(1, 2).contiguous() + v_fa = v.transpose(1, 2).contiguous() + o_fa = flash_attn_func(q_fa, k_fa, v_fa, softmax_scale=scale) + return o_fa.transpose(1, 2).contiguous() # back to (B, H, N, D) + + +def torch_sdpa_reference(q, k, v, scale): + """PyTorch's scaled_dot_product_attention (auto backend selection).""" + return F.scaled_dot_product_attention(q, k, v, scale=scale) + + +def test_correctness(batch, heads, seq_len, head_dim, dtype=torch.float16): + print(f" B={batch}, H={heads}, N={seq_len}, D={head_dim}", end="") + scale = 1.0 / math.sqrt(head_dim) + + q = torch.randn( + batch, heads, seq_len, head_dim, dtype=dtype, device="cuda:0" + ) + k = torch.randn( + batch, heads, seq_len, head_dim, dtype=dtype, device="cuda:0" + ) + v = torch.randn( + batch, heads, seq_len, head_dim, dtype=dtype, device="cuda:0" + ) + + # Reference: FlashAttention-2 + ref = flash_attn_reference(q, k, v, scale) + + # ARK standard MHA + 
ark.init() + k_t = k.transpose(-2, -1).contiguous() + mha = MultiHeadAttention(head_dim) + ark_out = mha( + ark.Tensor.from_torch(q), + ark.Tensor.from_torch(k_t), + ark.Tensor.from_torch(v), + ) + with ark.Runtime() as rt: + rt.launch() + rt.run() + result = ark_out.to_torch() + + diff = (result - ref).abs().max().item() + atol = 5e-2 if dtype == torch.float16 else 1e-1 + ok = diff < atol + print(f" diff={diff:.4f} {'PASS' if ok else 'FAIL'}") + return ok + + +def bench_one(label, run_fn, num_warmup=10, num_iter=50): + """Benchmark helper: warmup, then time num_iter iterations.""" + for _ in range(num_warmup): + run_fn() + torch.cuda.synchronize() + start = time.time() + for _ in range(num_iter): + run_fn() + torch.cuda.synchronize() + ms = (time.time() - start) / num_iter * 1000 + return ms + + +def run_benchmark(batch, heads, seq_len, head_dim, dtype=torch.float16): + scale = 1.0 / math.sqrt(head_dim) + + q = torch.randn( + batch, heads, seq_len, head_dim, dtype=dtype, device="cuda:0" + ) + k = torch.randn( + batch, heads, seq_len, head_dim, dtype=dtype, device="cuda:0" + ) + v = torch.randn( + batch, heads, seq_len, head_dim, dtype=dtype, device="cuda:0" + ) + k_t = k.transpose(-2, -1).contiguous() + + # --- FlashAttention-2 (Tri Dao) --- + q_fa = q.transpose(1, 2).contiguous() + k_fa = k.transpose(1, 2).contiguous() + v_fa = v.transpose(1, 2).contiguous() + + flash_ms = bench_one( + "FlashAttn2", + lambda: flash_attn_func(q_fa, k_fa, v_fa, softmax_scale=scale), + ) + + # --- PyTorch SDPA --- + sdpa_ms = bench_one( + "SDPA", + lambda: F.scaled_dot_product_attention(q, k, v, scale=scale), + ) + + # --- ARK Vanilla --- + ark.init() + mha = MultiHeadAttention(head_dim) + ark_out = mha( + ark.Tensor.from_torch(q), + ark.Tensor.from_torch(k_t), + ark.Tensor.from_torch(v), + ) + with ark.Runtime() as rt: + rt.launch() + vanilla_ms = bench_one("ARK", lambda: rt.run(iter=1), num_warmup=5) + + # --- ARK Optimized (fused softmax) --- + ark.init() + mha_opt = 
MultiHeadAttentionOptimized(head_dim, seq_len) + ark_out2 = mha_opt( + ark.Tensor.from_torch(q), + ark.Tensor.from_torch(k_t), + ark.Tensor.from_torch(v), + ) + with ark.Runtime() as rt: + rt.launch() + opt_ms = bench_one("ARK-Opt", lambda: rt.run(iter=1), num_warmup=5) + + print( + f" B={batch} H={heads:2d} N={seq_len:4d} D={head_dim:3d} " + f"FlashAttn2={flash_ms:.3f}ms SDPA={sdpa_ms:.3f}ms " + f"ARK={vanilla_ms:.3f}ms ARK-Opt={opt_ms:.3f}ms " + f"(Opt/Flash={opt_ms/flash_ms:.2f}x)" + ) + return flash_ms, sdpa_ms, vanilla_ms, opt_ms + + +if __name__ == "__main__": + print("=" * 70) + print("Correctness: ARK MHA vs FlashAttention-2") + print("=" * 70) + all_pass = True + for B, H, N, D in [ + (1, 1, 256, 128), + (1, 4, 256, 128), + (2, 8, 256, 128), + (1, 1, 512, 128), + ]: + all_pass &= test_correctness(B, H, N, D) + + if not all_pass: + print("\nSome tests FAILED!") + sys.exit(1) + print("\nAll correctness tests PASSED!") + + print() + print("=" * 70) + print("Performance: ARK vs FlashAttention-2 vs PyTorch SDPA") + print("=" * 70) + for B, H, N, D in [ + (1, 1, 256, 128), + (1, 4, 256, 128), + (1, 8, 256, 128), + (1, 1, 512, 128), + (1, 4, 512, 128), + (1, 8, 512, 128), + ]: + run_benchmark(B, H, N, D) diff --git a/examples/tutorial/module_tutorial.py b/examples/tutorial/module_tutorial.py index b3bac67ea..af395869e 100644 --- a/examples/tutorial/module_tutorial.py +++ b/examples/tutorial/module_tutorial.py @@ -2,52 +2,32 @@ # Licensed under the MIT license. 
import torch -import numpy as np import torch.nn as nn import ark # Define the parameters of the model batch_size = 1 -seq_len = 64 +seq_len = 128 d_model = 512 d_ff = 2048 -def convert_state_dict(state_dict: dict, type="numpy"): - """ - Convert the state_dict of a module to np.ndarray or torch.Tensor type - """ - new_state_dict = {} - for key in state_dict: - if type == "torch": - new_state_dict[key] = torch.from_numpy(state_dict[key]) - elif type == "numpy": - new_state_dict[key] = state_dict[key].numpy() - return new_state_dict - - class SubModuleARK(ark.Module): - def __init__(self): + def __init__(self, weight_2): super(SubModuleARK, self).__init__() - # Define the parameters of the submodule - self.weight_2 = ark.parameter([d_ff, d_model], ark.fp32) + self.weight_2 = ark.Tensor.from_torch(weight_2) def forward(self, inputs): - # Perform the forward pass of the submodule - middle_result1 = ark.matmul(inputs, self.weight_2) - return middle_result1 + return ark.matmul(inputs, self.weight_2) class TestModelARK(ark.Module): - def __init__(self): + def __init__(self, weight_1, weight_2): super(TestModelARK, self).__init__() - # Define the parameters of the module - self.weight_1 = ark.parameter([d_model, d_ff], ark.fp32) - # Create a submodule of the module - self.submodule = SubModuleARK() + self.weight_1 = ark.Tensor.from_torch(weight_1) + self.submodule = SubModuleARK(weight_2) def forward(self, inputs): - # Perform the forward pass of the model output = ark.matmul(inputs, self.weight_1) output = ark.relu(output) output = self.submodule(output) @@ -56,103 +36,53 @@ def forward(self, inputs): return output -# Use pytorch to define the same model -class SubModulePytorch(nn.Module): - def __init__(self): - super(SubModulePytorch, self).__init__() - self.weight_2 = nn.Parameter(torch.ones(d_ff, d_model)) - - def forward(self, inputs): - middle_result1 = torch.matmul(inputs, self.weight_2) - return middle_result1 - - class TestModelPytorch(nn.Module): def 
__init__(self): super(TestModelPytorch, self).__init__() - # Define the parameters of the module - self.weight_1 = nn.Parameter(torch.ones(d_model, d_ff)) - # Create a submodule of the module - self.submodule = SubModulePytorch() + self.weight_1 = nn.Parameter(torch.ones(d_model, d_ff, device="cuda:0")) + self.submodule_weight_2 = nn.Parameter( + torch.ones(d_ff, d_model, device="cuda:0") + ) + self.layernorm = nn.LayerNorm(d_model, device="cuda:0") def forward(self, inputs): - # Perform the forward pass of the model output = torch.matmul(inputs, self.weight_1) output = nn.ReLU()(output) - output = self.submodule(output) - output = nn.LayerNorm(d_model)(output + inputs) + output = torch.matmul(output, self.submodule_weight_2) + output = self.layernorm(output + inputs) return output -# An example of using the ARK module def module_test(): - # Create an input tensor - input_tensor = ark.tensor([batch_size, seq_len, d_model], ark.fp32) - - # Create an ARK module - ark_model = TestModelARK() - - # Perform the forward pass - output_tensor = ark_model(input_tensor) - - # Initialize the ARK runtime - runtime = ark.Runtime() - - # Launch the ARK runtime - runtime.launch() - - # Initialize the input tensor - input_tensor_host = ( - (np.random.rand(batch_size, seq_len, d_model) - 0.5) * 0.1 - ).astype(np.float32) - input_tensor.from_numpy(input_tensor_host) - - # Initialize the parameters of the ARK module using numpy state_dict - weight_1_host = ((np.random.rand(d_model, d_ff) - 0.5) * 0.1).astype( - np.float32 + # Create torch tensors for input and weights + input_tensor = ( + torch.randn( + batch_size, seq_len, d_model, dtype=torch.float32, device="cuda:0" + ) + * 0.1 ) - weight_2_host = ((np.random.rand(d_ff, d_model) - 0.5) * 0.1).astype( - np.float32 + weight_1 = ( + torch.randn(d_model, d_ff, dtype=torch.float32, device="cuda:0") * 0.1 + ) + weight_2 = ( + torch.randn(d_ff, d_model, dtype=torch.float32, device="cuda:0") * 0.1 ) - state_dict = { - "weight_1": 
weight_1_host, - "submodule.weight_2": weight_2_host, - } - - # Load model parameters - ark_model.load_state_dict(state_dict) - - # Run the ARK model - runtime.run() - - # Copy the ARK module output tensor from device to host - output_tensor_host = output_tensor.to_numpy() - # For simplicity, we use float32 to compute the ground truth using pytorch - input_tensor_host_float32 = input_tensor_host.astype(np.float32) - torch_input = torch.from_numpy(input_tensor_host_float32) + # Build and evaluate the ARK model + ark_model = TestModelARK(weight_1, weight_2) + output = ark_model(input_tensor).eval() + # Compute PyTorch ground truth torch_model = TestModelPytorch() + torch_model.load_state_dict( + {"weight_1": weight_1, "submodule_weight_2": weight_2}, + strict=False, + ) + gt = torch_model(input_tensor) - # Convert the numpy.ndarray type state_dict to torch.Tensor type state_dict - torch_state_dict = convert_state_dict(state_dict, "torch") - # Load model parameters - torch_model.load_state_dict(torch_state_dict) - - # Run the pytorch model to compute the ground truth - gt = torch_model(torch_input).detach().numpy() - - # Test if the result is correct - max_error = np.max(np.abs(output_tensor_host - gt)) - avg_error = np.mean(np.abs(output_tensor_host - gt)) - - # Use ark_model.state_dict() to get the state_dict of the ARK module - # Note that the state_dict of the ARK module might be modified at the ARK kernel launch time - ark_state_dict = ark_model.state_dict() - - # Test if the parameters are the same - for k, v in state_dict.items(): - np.testing.assert_allclose(v, ark_state_dict[k]) + # Compare results + max_error = (output - gt).abs().max().item() + avg_error = (output - gt).abs().mean().item() print("ARK module test") print( @@ -165,7 +95,7 @@ def module_test(): "d_ff:", d_ff, ) - print("max error: ", max_error, "avg error: ", avg_error) + print("max error:", max_error, "avg error:", avg_error) if __name__ == "__main__": diff --git 
a/examples/tutorial/planner_tutorial.py b/examples/tutorial/planner_tutorial.py index 8702f8929..a0a88462a 100644 --- a/examples/tutorial/planner_tutorial.py +++ b/examples/tutorial/planner_tutorial.py @@ -35,25 +35,22 @@ def forward(self, input): "NumTasks": 65536, }, ): - with ark.PlannerContext(config={"ImplType": "WarpWise"}): + with ark.PlannerContext( + config={"ImplType": "WarpWise", "Tile": [1, 1]} + ): max = ark.reduce_max(input, axis=-1) with ark.PlannerContext(config={"Tile": [1, 2048]}): output = ark.sub(input, max) output = ark.exp(output) - with ark.PlannerContext(config={"ImplType": "WarpWise"}): + with ark.PlannerContext( + config={"ImplType": "WarpWise", "Tile": [1, 1]} + ): sum = ark.reduce_sum(output, axis=-1) with ark.PlannerContext(config={"Tile": [1, 2048]}): output = ark.div(output, sum) return output -def eval(tensor: ark.Tensor): - with ark.Runtime() as rt: - rt.launch() - rt.run() - return tensor.to_torch() - - def perf(num_iter: int = 1000): with ark.Runtime() as rt: rt.launch() @@ -73,7 +70,7 @@ def perf(num_iter: int = 1000): output = Softmax()(ark.Tensor.from_torch(input)) - if torch.allclose(eval(output), F.softmax(input, dim=-1), atol=1e-5): + if torch.allclose(output.eval(), F.softmax(input, dim=-1), atol=1e-5): print("Correct result") else: print("Incorrect result") diff --git a/examples/tutorial/planner_tutorial_2.py b/examples/tutorial/planner_tutorial_2.py index eb9998541..e949eeb56 100644 --- a/examples/tutorial/planner_tutorial_2.py +++ b/examples/tutorial/planner_tutorial_2.py @@ -1,8 +1,8 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. 
-import numpy as np import ark +import torch def quickstart_tutorial(): @@ -10,9 +10,8 @@ def quickstart_tutorial(): ark.init() M, N, K = 1024, 1024, 1024 - m0 = ark.tensor([M, K], ark.fp16) - m1 = ark.tensor([N, K], ark.fp16) - m2 = ark.tensor([M, K], ark.fp16) + m0 = torch.randn(M, K, dtype=torch.float16, device="cuda:0") * 0.01 + m1 = torch.randn(N, K, dtype=torch.float16, device="cuda:0") * 0.01 # stage 1: matmul with ark.PlannerContext(processor_range=[0, 108]): @@ -20,6 +19,7 @@ def quickstart_tutorial(): t0 = ark.matmul(m0, m1, transpose_other=True) # stage 2: parallel copy and matmul + m2 = ark.tensor([M, K], ark.fp16) with ark.PlannerContext(processor_range=[0, 54]): # Use SMs 0~53 t1 = ark.matmul(t0, m1) @@ -27,27 +27,20 @@ def quickstart_tutorial(): # Use SMs 54~107 t2 = ark.copy(input=t0, output=m2) - # Initialize the ARK runtime - runtime = ark.Runtime() - - # Launch the ARK runtime - runtime.launch() - - # Initialize - m0_host = np.random.rand(M, K).astype(np.float16) * 0.01 - m0.from_numpy(m0_host) - m1_host = np.random.rand(N, K).astype(np.float16) * 0.01 - m1.from_numpy(m1_host) - - # Run the ARK program - runtime.run() + # Evaluate and check results + with ark.Runtime() as rt: + rt.launch() + rt.run() + t0_result = t0.to_torch() + t1_result = t1.to_torch() + t2_result = t2.to_torch() # Check the matmul result - res_host = np.matmul(np.matmul(m0_host, m1_host.T), m1_host) - np.testing.assert_allclose(t1.to_numpy(), res_host, rtol=1e-3, atol=1e-3) + expected = torch.matmul(torch.matmul(m0, m1.T), m1) + torch.testing.assert_close(t1_result, expected, rtol=1e-3, atol=1e-3) # Check the copy result - np.testing.assert_equal(t2.to_numpy(), t0.to_numpy()) + torch.testing.assert_close(t2_result, t0_result, atol=0, rtol=0) print("Successful!") diff --git a/examples/tutorial/quickstart_tutorial.py b/examples/tutorial/quickstart_tutorial.py index 1fce51452..f36d31498 100644 --- a/examples/tutorial/quickstart_tutorial.py +++ 
b/examples/tutorial/quickstart_tutorial.py @@ -1,7 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -import numpy as np +import torch import ark @@ -10,35 +10,16 @@ def quickstart_tutorial(): ark.init() M, N = 64, 64 - # Create an input tensor - input_tensor = ark.tensor([M, N], ark.fp16) - # Create another tensor - other_tensor = ark.tensor([M, N], ark.fp16) + # Create input tensors on GPU + input_tensor = torch.randn(M, N, dtype=torch.float16, device="cuda:0") + other_tensor = torch.randn(M, N, dtype=torch.float16, device="cuda:0") - # Add the two tensors - output_tensor = ark.add(input_tensor, other_tensor) + # Add the two tensors using ARK and evaluate + output = ark.add(input_tensor, other_tensor).eval() - # Initialize the ARK runtime - runtime = ark.Runtime() - - # Launch the ARK runtime - runtime.launch() - - # Initialize the input and other tensor with random values - input_tensor_host = np.random.rand(M, N).astype(np.float16) - input_tensor.from_numpy(input_tensor_host) - other_tensor_host = np.random.rand(M, N).astype(np.float16) - other_tensor.from_numpy(other_tensor_host) - - # Run the ARK program - runtime.run() - - # Copy the output tensor from device memory to host memory, if dst is - # None, a new numpy array of the same shape as the src tensor will be returned - output_tensor_host = output_tensor.to_numpy() # Check if the output tensor is equal to the sum of the input and other tensor - np.testing.assert_allclose( - output_tensor_host, input_tensor_host + other_tensor_host + torch.testing.assert_close( + output, input_tensor + other_tensor, atol=0, rtol=0 ) print("Quickstart tutorial is successful!") diff --git a/examples/tutorial/torch_tutorial.py b/examples/tutorial/torch_tutorial.py new file mode 100644 index 000000000..3f1f89d1b --- /dev/null +++ b/examples/tutorial/torch_tutorial.py @@ -0,0 +1,28 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
+ +""" +Tutorial: Using ARK with PyTorch tensors. + +Shows how to use eval() to run ARK computation on torch tensors +and get torch tensor results directly. +""" + +import ark +import torch + +ark.init() + +# Create torch tensors on GPU +x = torch.ones(64, dtype=torch.float32, device="cuda:0") * 2 +y = torch.ones(64, dtype=torch.float32, device="cuda:0") * 3 + +# Run ARK computation and get result as a torch tensor +result = ark.add(x, y).eval() +print(f"x + y = {result}") # tensor([5., 5., ...]) + +# Run again with different values +x = torch.ones(64, dtype=torch.float32, device="cuda:0") * 10 +y = torch.ones(64, dtype=torch.float32, device="cuda:0") * 20 +result = ark.add(x, y).eval() +print(f"10 + 20 = {result}") # tensor([30., 30., ...]) diff --git a/python/ark/__init__.py b/python/ark/__init__.py index 63480262c..31aaa5cef 100644 --- a/python/ark/__init__.py +++ b/python/ark/__init__.py @@ -7,7 +7,7 @@ os.environ["ARK_ROOT"] = os.path.abspath(os.path.dirname(__file__)) from .core import version -from .model import Model +from .model import Model, set_model, current_model, use_model __version__ = version() diff --git a/python/ark/data_type.py b/python/ark/data_type.py index fa2b2c064..0caac294a 100644 --- a/python/ark/data_type.py +++ b/python/ark/data_type.py @@ -10,6 +10,7 @@ "DataType", "fp16", "fp32", + "bf16", "int32", "uint32", "int8", diff --git a/python/ark/model.py b/python/ark/model.py index e103d4083..a1fd37c49 100644 --- a/python/ark/model.py +++ b/python/ark/model.py @@ -2,18 +2,19 @@ # Licensed under the MIT license. from typing import NewType +from contextlib import contextmanager from . import log from .core import CoreModel -__all__ = ["Model"] +__all__ = ["Model", "set_model", "current_model", "use_model"] ModelState = NewType("ModelState", None) class Model(CoreModel): @staticmethod - def get_model(): + def get_model() -> "Model": """ Get the underlying model. 
""" @@ -115,3 +116,54 @@ class ModelState: rank: int = 0 world_size: int = 1 device_id: int = 0 + + +def set_model(model: Model) -> None: + """Set the current active model. All subsequent ARK ops will be added to this model. + + Similar to ``torch.cuda.set_stream()``. + + Args: + model: The model to set as the current active model. + """ + ModelState.model = model + + +def current_model() -> Model: + """Return the current active model, creating one if none exists. + + Similar to ``torch.cuda.current_stream()``. + + Returns: + The current active model. + """ + return Model.get_model() + + +@contextmanager +def use_model(model: Model): + """Context manager to temporarily switch the active model. + + All ARK ops within the ``with`` block are added to the given model. + On exit, the previous model is restored. + + Similar to ``torch.cuda.stream()``. + + Example:: + + m1 = ark.Model() + m2 = ark.Model() + with ark.use_model(m1): + a = ark.add(x, y) # added to m1 + with ark.use_model(m2): + b = ark.mul(x, y) # added to m2 + + Args: + model: The model to use within the context. + """ + prev = ModelState.model + ModelState.model = model + try: + yield model + finally: + ModelState.model = prev diff --git a/python/ark/ops.py b/python/ark/ops.py index c0eefa2e0..68c3846c1 100644 --- a/python/ark/ops.py +++ b/python/ark/ops.py @@ -10,6 +10,13 @@ from . 
import log +def _ensure_ark(x): + """Convert a torch.Tensor to an ARK Tensor if needed.""" + if not _no_torch and isinstance(x, torch.Tensor): + return Tensor.from_torch(x) + return x + + __all__ = [ "tensor", "parameter", @@ -54,12 +61,14 @@ def is_list_or_tuple(obj): def add( - input: Union[Tensor, float], - other: Union[Tensor, float], + input: Union[Tensor, float, "torch.Tensor"], + other: Union[Tensor, float, "torch.Tensor"], output: Tensor = NullTensor, name: str = "add", ) -> Union[Tensor, float]: """ """ + input = _ensure_ark(input) + other = _ensure_ark(other) if isinstance(input, Tensor) and isinstance(other, Tensor): a = input._tensor b = other._tensor @@ -81,12 +90,13 @@ def add( def cast( - input: Tensor, + input: Union[Tensor, "torch.Tensor"], dtype: DataType, output: Tensor = NullTensor, name: str = "cast", ) -> Tensor: """ """ + input = _ensure_ark(input) if output is not NullTensor: output = output._tensor return Tensor( @@ -107,11 +117,12 @@ def constant( def copy( - input: Union[Tensor, float], + input: Union[Tensor, float, "torch.Tensor"], output: Tensor = NullTensor, name: str = "copy", ) -> Tensor: """ """ + input = _ensure_ark(input) if output is not NullTensor: output = output._tensor if isinstance(input, Tensor): @@ -120,12 +131,14 @@ def copy( def div( - input: Tensor, - other: Union[Tensor, float], + input: Union[Tensor, "torch.Tensor"], + other: Union[Tensor, float, "torch.Tensor"], output: Tensor = NullTensor, name: str = "div", ) -> Tensor: """ """ + input = _ensure_ark(input) + other = _ensure_ark(other) if output is not NullTensor: output = output._tensor if isinstance(other, Tensor): @@ -134,12 +147,14 @@ def div( def embedding( - input: Tensor, - weight: Tensor, + input: Union[Tensor, "torch.Tensor"], + weight: Union[Tensor, "torch.Tensor"], output: Tensor = NullTensor, name: str = "embedding", ) -> Tensor: """ """ + input = _ensure_ark(input) + weight = _ensure_ark(weight) if output is not NullTensor: output = output._tensor 
return Tensor( @@ -148,31 +163,36 @@ def embedding( def exp( - input: Tensor, + input: Union[Tensor, "torch.Tensor"], output: Tensor = NullTensor, name: str = "exp", ) -> Tensor: """ """ + input = _ensure_ark(input) if output is not NullTensor: output = output._tensor return Tensor(Model.get_model().exp(input._tensor, output, name)) def gelu( - input: Tensor, + input: Union[Tensor, "torch.Tensor"], output: Tensor = NullTensor, name: str = "gelu", ) -> Tensor: """ """ + input = _ensure_ark(input) if output is not NullTensor: output = output._tensor return Tensor(Model.get_model().gelu(input._tensor, output, name)) def identity( - input: Tensor, deps: List[Tensor] = [], name: str = "identity" + input: Union[Tensor, "torch.Tensor"], + deps: List[Tensor] = [], + name: str = "identity", ) -> Tensor: """ """ + input = _ensure_ark(input) dep_tensors = [] for dep in deps: if not isinstance(dep, Tensor): @@ -182,14 +202,16 @@ def identity( def matmul( - input: Tensor, - other: Tensor, + input: Union[Tensor, "torch.Tensor"], + other: Union[Tensor, "torch.Tensor"], output: Tensor = NullTensor, transpose_input: bool = False, transpose_other: bool = False, name: str = "matmul", ) -> Tensor: """ """ + input = _ensure_ark(input) + other = _ensure_ark(other) if output is not NullTensor: output = output._tensor return Tensor( @@ -205,12 +227,14 @@ def matmul( def mul( - input: Tensor, - other: Union[Tensor, float], + input: Union[Tensor, "torch.Tensor"], + other: Union[Tensor, float, "torch.Tensor"], output: Tensor = NullTensor, name: str = "mul", ) -> Tensor: """ """ + input = _ensure_ark(input) + other = _ensure_ark(other) if output is not NullTensor: output = output._tensor if isinstance(other, Tensor): @@ -218,8 +242,9 @@ def mul( return Tensor(Model.get_model().mul(input._tensor, other, output, name)) -def noop(input: Tensor, name: str = "noop"): +def noop(input: Union[Tensor, "torch.Tensor"], name: str = "noop"): """ """ + input = _ensure_ark(input) 
Model.get_model().noop(input._tensor, name) @@ -253,13 +278,14 @@ def placeholder( def reduce_max( - input: Tensor, + input: Union[Tensor, "torch.Tensor"], axis: int, keepdims: bool = True, output: Tensor = NullTensor, name: str = "reduce_max", ) -> Tensor: """ """ + input = _ensure_ark(input) if output is not NullTensor: output = output._tensor return Tensor( @@ -270,13 +296,14 @@ def reduce_max( def reduce_mean( - input: Tensor, + input: Union[Tensor, "torch.Tensor"], axis: int, keepdims: bool = True, output: Tensor = NullTensor, name: str = "reduce_mean", ) -> Tensor: """ """ + input = _ensure_ark(input) if output is not NullTensor: output = output._tensor return Tensor( @@ -287,13 +314,14 @@ def reduce_mean( def reduce_sum( - input: Tensor, + input: Union[Tensor, "torch.Tensor"], axis: int, keepdims: bool = True, output: Tensor = NullTensor, name: str = "reduce_sum", ) -> Tensor: """ """ + input = _ensure_ark(input) if output is not NullTensor: output = output._tensor return Tensor( @@ -304,18 +332,19 @@ def reduce_sum( def relu( - input: Tensor, + input: Union[Tensor, "torch.Tensor"], output: Tensor = NullTensor, name: str = "relu", ) -> Tensor: """ """ + input = _ensure_ark(input) if output is not NullTensor: output = output._tensor return Tensor(Model.get_model().relu(input._tensor, output, name)) def reshape( - input: Tensor, + input: Union[Tensor, "torch.Tensor"], shape: Iterable[int], allowzero: bool = False, name: str = "reshape", @@ -338,6 +367,7 @@ def reshape( "shape should be a list or tuple of integers" ) # only support tensors with up to 4 dimensions + input = _ensure_ark(input) if len(shape) > 4: raise log.InvalidUsageError( "Only support tensors with up to 4 dimensions" @@ -348,12 +378,14 @@ def reshape( def rope( - input: Tensor, - other: Tensor, + input: Union[Tensor, "torch.Tensor"], + other: Union[Tensor, "torch.Tensor"], output: Tensor = NullTensor, name: str = "rope", ) -> Tensor: """ """ + input = _ensure_ark(input) + other = 
_ensure_ark(other) if output is not NullTensor: output = output._tensor return Tensor( @@ -362,20 +394,25 @@ def rope( def rsqrt( - input: Tensor, + input: Union[Tensor, "torch.Tensor"], output: Tensor = NullTensor, name: str = "rsqrt", ) -> Tensor: """ """ + input = _ensure_ark(input) if output is not NullTensor: output = output._tensor return Tensor(Model.get_model().rsqrt(input._tensor, output, name)) def sharding( - input: Tensor, axis: int, dim_per_shard: int, name: str = "sharding" + input: Union[Tensor, "torch.Tensor"], + axis: int, + dim_per_shard: int, + name: str = "sharding", ) -> List[Tensor]: """ """ + input = _ensure_ark(input) _tensor_list = Model.get_model().sharding( input._tensor, axis, dim_per_shard, name ) @@ -383,34 +420,38 @@ def sharding( def sigmoid( - input: Tensor, + input: Union[Tensor, "torch.Tensor"], output: Tensor = NullTensor, name: str = "sigmoid", ) -> Tensor: """ """ + input = _ensure_ark(input) if output is not NullTensor: output = output._tensor return Tensor(Model.get_model().sigmoid(input._tensor, output, name)) def sqrt( - input: Tensor, + input: Union[Tensor, "torch.Tensor"], output: Tensor = NullTensor, name: str = "sqrt", ) -> Tensor: """ """ + input = _ensure_ark(input) if output is not NullTensor: output = output._tensor return Tensor(Model.get_model().sqrt(input._tensor, output, name)) def sub( - input: Tensor, - other: Union[Tensor, float], + input: Union[Tensor, "torch.Tensor"], + other: Union[Tensor, float, "torch.Tensor"], output: Tensor = NullTensor, name: str = "sub", ) -> Tensor: """ """ + input = _ensure_ark(input) + other = _ensure_ark(other) if output is not NullTensor: output = output._tensor if isinstance(other, Tensor): @@ -436,12 +477,13 @@ def tensor( def transpose( - input: Tensor, + input: Union[Tensor, "torch.Tensor"], perm: Iterable[int], output: Tensor = NullTensor, name: str = "transpose", ) -> Tensor: """ """ + input = _ensure_ark(input) if output is not NullTensor: output = output._tensor if not 
is_list_or_tuple(perm): @@ -493,6 +535,7 @@ def recv( Model.get_model().recv(output._tensor, remote_rank, tag, name) ) + ################################################################################ diff --git a/python/ark/planner.py b/python/ark/planner.py index 0ed9113e1..609282182 100644 --- a/python/ark/planner.py +++ b/python/ark/planner.py @@ -233,8 +233,10 @@ def __exit__(self, exc_type, exc_value, exc_tb): class Planner(CorePlanner): - def __init__(self, device_id: int = 0): - compressed = Model.get_model().compress() + def __init__(self, device_id: int = 0, model: "Model" = None): + if model is None: + model = Model.get_model() + compressed = model.compress() super().__init__(compressed, device_id) def install_config_rule(self, rule: Callable[[str, str], str]): diff --git a/python/ark/tensor.py b/python/ark/tensor.py index 216318b27..cd83429fc 100644 --- a/python/ark/tensor.py +++ b/python/ark/tensor.py @@ -38,6 +38,8 @@ def __init__( self._tensor: CoreTensor = _tensor self.initializer: Initializer = initializer self.requires_grad: bool = requires_grad + # Track which model this tensor belongs to for eval() + self._model: Model = Model.get_model() def __hash__(self): return self._tensor.id() @@ -283,6 +285,37 @@ def from_torch(tensor: torch.Tensor) -> "Tensor": ark_tensor.__torch_buffer__ = tensor return ark_tensor + def eval(self, stream: "torch.cuda.Stream" = None) -> torch.Tensor: + """ + Evaluate the ARK graph that produces this tensor and return the result + as a torch tensor. Creates a runtime, compiles the graph, runs it, + and returns the output via DLPack (zero-copy). + + Multiple independent ARK graphs can coexist — each tensor tracks + which model it belongs to, and eval() only runs that model. + The executor skips GPU recompilation if the plan hasn't changed. + + Args: + stream: Optional torch CUDA stream to run on. + + Returns: + torch.Tensor: The result tensor on the same device. 
+ """ + if _no_torch: + raise log.SystemError("torch is not available") + from .runtime import Runtime + from .planner import Planner + + plan = Planner(model=self._model).plan() + cuda_stream = stream.cuda_stream if stream is not None else 0 + + with Runtime() as rt: + rt.launch(plan=plan, stream=cuda_stream, loop_mode=False) + rt.run() + result = self.to_torch() + + return result + def copy( self, data: Union[np.ndarray, torch.Tensor], stream: int = 0 ) -> "Tensor": diff --git a/python/unittest/common.py b/python/unittest/common.py index 0c385e89a..0bb866d3b 100644 --- a/python/unittest/common.py +++ b/python/unittest/common.py @@ -2,6 +2,7 @@ # Licensed under the MIT license. import pytest +import functools import ark @@ -19,6 +20,7 @@ def decorator(test_func): test_func ) + @functools.wraps(test_func) def wrapper(*args, **kwargs): ark.init() test_func(*args, **kwargs) diff --git a/python/unittest/ops/conftest.py b/python/unittest/ops/conftest.py new file mode 100644 index 000000000..5073ab08c --- /dev/null +++ b/python/unittest/ops/conftest.py @@ -0,0 +1,32 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Shared fixtures and helpers for ARK op numerical tests. 
+""" + +import sys +import os +import pytest + +# Add parent directory to path so `common` is importable +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +from common import ark + +try: + import torch + + _no_torch = False +except ImportError: + _no_torch = True + +# Skip entire ops/ directory if torch is unavailable +pytestmark = pytest.mark.skipif(_no_torch, reason="torch not available") + +DEVICE = "cuda:0" + + +@pytest.fixture(autouse=True) +def _ark_init(): + """Reset ARK state before each test so tests don't share models.""" + ark.init() diff --git a/python/unittest/ops/test_arithmetic.py b/python/unittest/ops/test_arithmetic.py new file mode 100644 index 000000000..ac917747f --- /dev/null +++ b/python/unittest/ops/test_arithmetic.py @@ -0,0 +1,119 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +"""Numerical tests for arithmetic ops: add, sub, mul, div (tensor and scalar).""" + +import pytest +import torch +from conftest import ark, DEVICE + + +@pytest.mark.parametrize( + "dtype", [torch.float32, torch.float16, torch.bfloat16] +) +def test_add(dtype): + a = torch.randn(8192, dtype=dtype, device=DEVICE) + b = torch.randn(8192, dtype=dtype, device=DEVICE) + assert torch.allclose(ark.add(a, b).eval(), a + b, atol=0, rtol=0) + + +def test_add_broadcast(): + a = torch.randn(4, 1024, dtype=torch.float16, device=DEVICE) + b = torch.randn(1, 1024, dtype=torch.float16, device=DEVICE) + assert torch.allclose(ark.add(a, b).eval(), a + b, atol=0, rtol=0) + + +def test_add_broadcast_3d(): + a = torch.randn(3, 1, 1024, dtype=torch.float16, device=DEVICE) + b = torch.randn(1, 4, 1, dtype=torch.float16, device=DEVICE) + assert torch.allclose(ark.add(a, b).eval(), a + b, atol=0, rtol=0) + + +@pytest.mark.parametrize("dtype", [torch.float32]) +def test_sub(dtype): + a = torch.randn(8192, dtype=dtype, device=DEVICE) + b = torch.randn(8192, dtype=dtype, device=DEVICE) + assert torch.allclose(ark.sub(a, b).eval(), a 
- b, atol=0, rtol=0) + + +@pytest.mark.parametrize("dtype", [torch.float32, torch.float16]) +def test_mul(dtype): + a = torch.randn(8192, dtype=dtype, device=DEVICE) + b = torch.randn(8192, dtype=dtype, device=DEVICE) + assert torch.allclose(ark.mul(a, b).eval(), a * b, atol=0, rtol=0) + + +def test_mul_broadcast(): + a = torch.randn(4, 1024, dtype=torch.float16, device=DEVICE) + b = torch.randn(1, 1024, dtype=torch.float16, device=DEVICE) + assert torch.allclose(ark.mul(a, b).eval(), a * b, atol=0, rtol=0) + + +def test_mul_broadcast_3d(): + a = torch.randn(3, 1, 1024, dtype=torch.float16, device=DEVICE) + b = torch.randn(1, 4, 1, dtype=torch.float16, device=DEVICE) + assert torch.allclose(ark.mul(a, b).eval(), a * b, atol=0, rtol=0) + + +def test_div_fp32(): + a = torch.randn(8192, dtype=torch.float32, device=DEVICE) + b = torch.randn(8192, dtype=torch.float32, device=DEVICE).abs() + 0.01 + assert torch.allclose(ark.div(a, b).eval(), a / b, atol=0, rtol=0) + + +# ─── Scalar operations ────────────────────────────────────────────────────── + +FACTOR = 0.75 + + +@pytest.mark.parametrize( + "dtype", [torch.float32, torch.float16, torch.bfloat16] +) +@pytest.mark.parametrize("shape", [(4, 2, 1), (4, 2, 1024)]) +def test_scalar_mul(dtype, shape): + a = torch.randn(shape, dtype=dtype, device=DEVICE) + assert torch.allclose(ark.mul(a, FACTOR).eval(), a * FACTOR, atol=0, rtol=0) + + +@pytest.mark.parametrize("shape", [(4, 2, 1), (4, 2, 1024)]) +def test_scalar_add(shape): + a = torch.randn(shape, dtype=torch.float16, device=DEVICE) + assert torch.allclose(ark.add(a, FACTOR).eval(), a + FACTOR, atol=0, rtol=0) + + +@pytest.mark.parametrize("shape", [(4, 2, 1), (4, 2, 1024)]) +def test_scalar_sub(shape): + a = torch.randn(shape, dtype=torch.float16, device=DEVICE) + assert torch.allclose(ark.sub(a, FACTOR).eval(), a - FACTOR, atol=0, rtol=0) + + +@pytest.mark.parametrize("shape", [(4, 2, 1), (4, 2, 1024)]) +def test_scalar_div(shape): + a = torch.randn(shape, 
dtype=torch.float16, device=DEVICE) + assert torch.allclose( + ark.div(a, FACTOR).eval(), a / FACTOR, atol=1e-3, rtol=1e-3 + ) + + +# ─── Constant & scalar copy ───────────────────────────────────────────────── + + +def test_constant_fp16(): + out = ark.constant(7, (4, 2, 50), ark.fp16).eval() + assert (out == 7).all() + + +def test_constant_fp32(): + out = ark.constant(7, (1,), ark.fp32).eval() + assert out.item() == 7.0 + + +def test_copy_scalar_fp16(): + t = torch.zeros(4, 2, 50, dtype=torch.float16, device=DEVICE) + out = ark.copy(7.0, ark.Tensor.from_torch(t)).eval() + assert (out == 7).all() + + +def test_copy_scalar_fp32(): + out = ark.copy(7.0).eval() + assert out.item() == 7.0 diff --git a/python/unittest/ops/test_cast.py b/python/unittest/ops/test_cast.py new file mode 100644 index 000000000..8587cdd2f --- /dev/null +++ b/python/unittest/ops/test_cast.py @@ -0,0 +1,29 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +"""Numerical tests for cast op.""" + +import pytest +import torch +from conftest import ark, DEVICE + + +@pytest.mark.parametrize( + "src_dtype, dst_dtype, ark_dst", + [ + (torch.float16, torch.float32, ark.fp32), + (torch.float32, torch.float16, ark.fp16), + (torch.float32, torch.int32, ark.int32), + (torch.int32, torch.float32, ark.fp32), + (torch.bfloat16, torch.float32, ark.fp32), + (torch.float32, torch.bfloat16, ark.bf16), + ], +) +def test_cast(src_dtype, dst_dtype, ark_dst): + a = torch.randn(4, 2, 1024, dtype=torch.float32, device=DEVICE).to( + src_dtype + ) + result = ark.cast(a, ark_dst).eval() + expected = a.to(dst_dtype) + assert result.dtype == dst_dtype + assert torch.allclose(result, expected, atol=0, rtol=0) diff --git a/python/unittest/ops/test_composite.py b/python/unittest/ops/test_composite.py new file mode 100644 index 000000000..f12194a56 --- /dev/null +++ b/python/unittest/ops/test_composite.py @@ -0,0 +1,33 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
+ +"""Numerical tests for composite ops: softmax, layernorm.""" + +import pytest +import torch +import torch.nn.functional as F +from conftest import ark, DEVICE + + +@pytest.mark.parametrize("dtype", [torch.float32, torch.float16]) +def test_softmax(dtype): + shape = (4, 8, 256) + a = torch.randn(shape, dtype=dtype, device=DEVICE) + result = ark.softmax(a).eval() + expected = F.softmax(a, dim=-1) + atol = 1e-5 if dtype == torch.float32 else 1e-3 + assert torch.allclose( + result, expected, atol=atol, rtol=1e-3 + ), f"max_diff={(result - expected).abs().max()}" + + +def test_layernorm(): + shape = (4, 8, 256) + a = torch.randn(shape, dtype=torch.float32, device=DEVICE) + result = ark.layernorm(a, eps=1e-6).eval() + mean = a.mean(dim=-1, keepdim=True) + var = ((a - mean) ** 2).mean(dim=-1, keepdim=True) + expected = (a - mean) / torch.sqrt(var + 1e-6) + assert torch.allclose( + result, expected, atol=1e-4, rtol=1e-4 + ), f"max_diff={(result - expected).abs().max()}" diff --git a/python/unittest/ops/test_embedding_rope.py b/python/unittest/ops/test_embedding_rope.py new file mode 100644 index 000000000..6c2ede3e4 --- /dev/null +++ b/python/unittest/ops/test_embedding_rope.py @@ -0,0 +1,54 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
+
+"""Numerical tests for embedding and rope ops."""
+
+import pytest
+import torch
+import torch.nn.functional as F
+from conftest import ark, DEVICE
+
+
+@pytest.mark.parametrize(
+    "dtype", [torch.float32, torch.float16, torch.bfloat16]
+)
+def test_embedding(dtype):  # lookup must match F.embedding exactly
+    vocab_size, embed_dim = 100, 64
+    indices = torch.randint(0, vocab_size, (4, 8), device=DEVICE).to(
+        torch.int32
+    )  # int32 indices in [0, vocab_size)
+    weight = torch.randn(vocab_size, embed_dim, dtype=dtype, device=DEVICE)
+    result = ark.embedding(indices, weight).eval()
+    expected = F.embedding(indices, weight)
+    assert torch.allclose(
+        result, expected, atol=0, rtol=0
+    ), f"max_diff={(result - expected).abs().max()}"
+
+
+@pytest.mark.parametrize(
+    "dtype", [torch.float32, torch.float16, torch.bfloat16]
+)
+def test_rope(dtype):
+    """Check ark.rope against a PyTorch complex-multiply reference.
+    Consecutive pairs (2k, 2k+1) of the last axis are multiplied as complex:
+        c[2k] = a[2k]*b[2k] - a[2k+1]*b[2k+1]
+        c[2k+1] = a[2k]*b[2k+1] + a[2k+1]*b[2k]
+    """
+    shape = (1, 1, 8, 64)
+    x = torch.randn(shape, dtype=dtype, device=DEVICE)
+    other = torch.randn(shape, dtype=dtype, device=DEVICE)
+    result = ark.rope(x, other).eval()
+    # PyTorch reference: complex multiply on paired elements
+    a = x.reshape(*shape[:-1], -1, 2)  # last axis -> (pairs, 2)
+    b = other.reshape(*shape[:-1], -1, 2)
+    expected = torch.stack(
+        [
+            a[..., 0] * b[..., 0] - a[..., 1] * b[..., 1],
+            a[..., 0] * b[..., 1] + a[..., 1] * b[..., 0],
+        ],
+        dim=-1,
+    ).reshape(shape)  # interleave real/imag parts back to the input shape
+    atol = 1e-5 if dtype == torch.float32 else 5e-2  # looser for fp16/bf16
+    assert torch.allclose(
+        result, expected, atol=atol, rtol=1e-3
+    ), f"max_diff={(result - expected).abs().max()}"
diff --git a/python/unittest/ops/test_math.py b/python/unittest/ops/test_math.py
new file mode 100644
index 000000000..835a1b15c
--- /dev/null
+++ b/python/unittest/ops/test_math.py
@@ -0,0 +1,56 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+"""Numerical tests for unary math ops: exp, gelu, relu, sigmoid, sqrt, rsqrt."""
+
+import pytest
+import torch
+import torch.nn.functional as F
+from conftest import ark, DEVICE
+
+
+@pytest.mark.parametrize(
+    "dtype", [torch.float32, torch.float16, torch.bfloat16]
+)
+def test_exp(dtype):  # NOTE(review): rtol=0 although exp(a) can be large
+    a = torch.randn(4, 2, 1024, dtype=dtype, device=DEVICE)
+    atol = 1e-5 if dtype == torch.float32 else 1e-2  # looser for fp16/bf16
+    assert torch.allclose(ark.exp(a).eval(), torch.exp(a), atol=atol, rtol=0)
+
+
+@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16])
+def test_gelu(dtype):  # reference is PyTorch's tanh-approximated GELU
+    a = torch.randn(4, 2, 1024, dtype=dtype, device=DEVICE)
+    atol = 1e-5 if dtype == torch.float32 else 1e-2  # looser for bf16
+    assert torch.allclose(
+        ark.gelu(a).eval(), F.gelu(a, approximate="tanh"), atol=atol, rtol=0
+    )
+
+
+@pytest.mark.parametrize(
+    "dtype", [torch.float32, torch.float16, torch.bfloat16]
+)
+def test_relu(dtype):  # exact match required (atol=0, rtol=0)
+    a = torch.randn(4, 2, 1024, dtype=dtype, device=DEVICE)
+    assert torch.allclose(ark.relu(a).eval(), F.relu(a), atol=0, rtol=0)
+
+
+@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16])
+def test_sigmoid(dtype):  # NOTE(review): fp16 not parametrized - intended?
+    a = torch.randn(4, 2, 1024, dtype=dtype, device=DEVICE)
+    atol = 1e-5 if dtype == torch.float32 else 1e-2  # looser for bf16
+    assert torch.allclose(
+        ark.sigmoid(a).eval(), torch.sigmoid(a), atol=atol, rtol=0
+    )
+
+
+def test_sqrt_fp32():  # inputs lie in [0.01, 1.01), i.e. away from zero
+    a = torch.rand(4, 2, 1024, dtype=torch.float32, device=DEVICE) + 0.01
+    assert torch.allclose(ark.sqrt(a).eval(), torch.sqrt(a), atol=1e-6, rtol=0)
+
+
+def test_rsqrt_fp32():  # inputs >= 0.01, so the reference stays <= 10
+    a = torch.rand(4, 2, 1024, dtype=torch.float32, device=DEVICE) + 0.01
+    assert torch.allclose(
+        ark.rsqrt(a).eval(), torch.rsqrt(a), atol=1e-4, rtol=0
+    )
diff --git a/python/unittest/ops/test_matmul.py b/python/unittest/ops/test_matmul.py
new file mode 100644
index 000000000..dfa26e988
--- /dev/null
+++ b/python/unittest/ops/test_matmul.py
@@ -0,0 +1,67 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+"""Numerical tests for matmul: NN, NT, TN, TT, batched."""
+
+import pytest
+import torch
+from conftest import ark, DEVICE
+
+
+@pytest.mark.parametrize(
+    "dtype", [torch.float32, torch.float16, torch.bfloat16]
+)
+def test_matmul_nn(dtype):  # (M, K) @ (K, N) with no transposition
+    M, N, K = 256, 256, 512
+    a = torch.randn(M, K, dtype=dtype, device=DEVICE)
+    b = torch.randn(K, N, dtype=dtype, device=DEVICE)
+    result = ark.matmul(a, b).eval()
+    expected = a @ b
+    atol = 1e-3 if dtype == torch.float32 else 1e-1  # looser for fp16/bf16
+    assert torch.allclose(
+        result, expected, atol=atol, rtol=1e-2
+    ), f"max_diff={(result - expected).abs().max()}"
+
+
+def test_matmul_nt():  # b is stored as (N, K); reference uses b.t()
+    M, N, K = 256, 256, 512
+    a = torch.randn(M, K, dtype=torch.float16, device=DEVICE)
+    b = torch.randn(N, K, dtype=torch.float16, device=DEVICE)
+    result = ark.matmul(a, b, transpose_other=True).eval()
+    expected = a @ b.t()
+    assert torch.allclose(
+        result, expected, atol=1e-1, rtol=1e-2
+    ), f"max_diff={(result - expected).abs().max()}"
+
+
+def test_matmul_tn():  # a is stored as (K, M); reference uses a.t()
+    M, N, K = 256, 256, 512
+    a = torch.randn(K, M, dtype=torch.float16, device=DEVICE)
+    b = torch.randn(K, N, dtype=torch.float16, device=DEVICE)
+    result = ark.matmul(a, b, transpose_input=True).eval()
+    expected = a.t() @ b
+    assert torch.allclose(
+        result, expected, atol=1e-1, rtol=1e-2
+    ), f"max_diff={(result - expected).abs().max()}"
+
+
+def test_matmul_tt():  # both operands are stored transposed
+    M, N, K = 256, 256, 512
+    a = torch.randn(K, M, dtype=torch.float16, device=DEVICE)
+    b = torch.randn(N, K, dtype=torch.float16, device=DEVICE)
+    result = ark.matmul(a, b, transpose_input=True, transpose_other=True).eval()
+    expected = a.t() @ b.t()
+    assert torch.allclose(
+        result, expected, atol=1e-1, rtol=1e-2
+    ), f"max_diff={(result - expected).abs().max()}"
+
+
+def test_matmul_batched():  # NOTE(review): atol 3e-1 vs 1e-1 in 2-D tests
+    B, M, N, K = 4, 256, 256, 512  # B independent (M, K) @ (K, N) products
+    a = torch.randn(B, M, K, dtype=torch.float16, device=DEVICE)
+    b = torch.randn(B, K, N, dtype=torch.float16, device=DEVICE)
+    result = ark.matmul(a, b).eval()
+    expected = a @ b
+    assert torch.allclose(
+        result, expected, atol=3e-1, rtol=1e-2
+    ), f"max_diff={(result - expected).abs().max()}"
diff --git a/python/unittest/ops/test_reduce.py b/python/unittest/ops/test_reduce.py
new file mode 100644
index 000000000..e1b4f9ee6
--- /dev/null
+++ b/python/unittest/ops/test_reduce.py
@@ -0,0 +1,46 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+"""Numerical tests for reduce ops: reduce_sum, reduce_max, reduce_mean."""
+
+import pytest
+import torch
+from conftest import ark, DEVICE
+
+
+@pytest.mark.parametrize("axis", [0, 1, 2, 3])
+def test_reduce_sum_fp32(axis):  # every axis of a 4-D input is reduced
+    shape = [7, 2, 4, 1024]
+    a = torch.randn(shape, dtype=torch.float32, device=DEVICE) * 0.1
+    result = ark.reduce_sum(a, axis=axis).eval()
+    expected = torch.sum(a, dim=axis, keepdim=True)  # keeps reduced dim
+    atol = shape[axis] * 1e-5  # scales with the number of summed elements
+    assert torch.allclose(
+        result, expected, atol=atol, rtol=1e-4
+    ), f"axis={axis}, max_diff={(result - expected).abs().max()}"
+
+
+@pytest.mark.parametrize("axis", [0, 3])
+def test_reduce_sum_fp16(axis):  # only first and last axis are covered
+    shape = [7, 2, 4, 1024]
+    a = torch.randn(shape, dtype=torch.float16, device=DEVICE) * 0.1
+    result = ark.reduce_sum(a, axis=axis).eval()
+    expected = torch.sum(a, dim=axis, keepdim=True)
+    atol = shape[axis] * 2e-2  # NOTE(review): 20.48 for axis=3 - very loose
+    assert torch.allclose(
+        result, expected, atol=atol, rtol=1e-2
+    ), f"axis={axis}, max_diff={(result - expected).abs().max()}"
+
+
+def test_reduce_max_fp32():  # max returns an input element: exact match
+    a = torch.randn(1, 1, 2, 8192, dtype=torch.float32, device=DEVICE)
+    result = ark.reduce_max(a, axis=-1).eval()
+    expected = torch.max(a, dim=-1, keepdim=True).values
+    assert torch.allclose(result, expected, atol=0, rtol=0)
+
+
+def test_reduce_mean_fp32():  # mean over the last axis of 8192 values
+    a = torch.randn(1, 1, 2, 8192, dtype=torch.float32, device=DEVICE) * 0.1
+    result = ark.reduce_mean(a, axis=-1).eval()
+    expected = torch.mean(a, dim=-1, keepdim=True)
+    assert torch.allclose(result, expected, atol=1e-4, rtol=1e-4)
diff --git a/python/unittest/ops/test_transpose.py
b/python/unittest/ops/test_transpose.py
new file mode 100644
index 000000000..d042b67de
--- /dev/null
+++ b/python/unittest/ops/test_transpose.py
@@ -0,0 +1,28 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+"""Numerical tests for transpose op."""
+
+import pytest
+import torch
+from conftest import ark, DEVICE
+
+
+@pytest.mark.parametrize(
+    "perm, shape",
+    [
+        ([0, 1, 3, 2], [2, 3, 64, 128]),  # swap the last two axes
+        ([0, 2, 3, 1], [2, 3, 64, 128]),  # move axis 1 to the end
+        ([0, 2, 1, 3], [2, 3, 64, 128]),  # swap the two middle axes
+    ],
+)
+@pytest.mark.parametrize(
+    "dtype", [torch.float32, torch.float16, torch.bfloat16]
+)
+def test_transpose(perm, shape, dtype):  # 3 perms x 3 dtypes, exact match
+    a = torch.randn(shape, dtype=dtype, device=DEVICE)
+    result = ark.transpose(a, perm).eval()
+    expected = a.permute(perm).contiguous()  # copy into permuted layout
+    assert torch.allclose(
+        result, expected, atol=0, rtol=0
+    ), f"max_diff={(result - expected).abs().max()}"
diff --git a/python/unittest/test.py b/python/unittest/test.py
index 822fb1f78..716bb5460 100644
--- a/python/unittest/test.py
+++ b/python/unittest/test.py
@@ -10,3 +10,4 @@
 from test_profiler import *
 from test_runtime import *
 from test_tensor import *
+from test_conversion import *
diff --git a/python/unittest/test_conversion.py b/python/unittest/test_conversion.py
new file mode 100644
index 000000000..4e9e92292
--- /dev/null
+++ b/python/unittest/test_conversion.py
@@ -0,0 +1,209 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+ +from common import ark, pytest_ark +import numpy as np +import pytest +from typing import Callable + +try: + import torch + + _no_torch = False +except ImportError: + _no_torch = True + +# ARK to Torch tests + + +def initialize_tensor(dimensions, dtype): + tensor = ark.tensor(dimensions, dtype) + tensor_host = np.random.rand(*dimensions).astype(dtype.to_numpy()) + return tensor, tensor_host + + +# Test function to validate the integrity of the PyTorch view of the ARK tensor, +# including its data and attributes such as shape and data type. +@pytest_ark(need_torch=True) +@pytest.mark.parametrize("num_dims,size", [(1, 5), (1, 1024), (2, 5), (2, 32)]) +@pytest.mark.parametrize("dtype", [ark.fp16, ark.fp32]) +def test_values_fixed_dims(num_dims: int, size: int, dtype: ark.DataType): + import torch + + dimensions = [size] * num_dims + + input_tensor, input_tensor_host = initialize_tensor(dimensions, dtype) + other_tensor, other_tensor_host = initialize_tensor(dimensions, dtype) + output_tensor = ark.add(input_tensor, other_tensor) + + with ark.Runtime() as rt: + rt.launch() + + input_tensor.from_numpy(input_tensor_host) + other_tensor.from_numpy(other_tensor_host) + + input_view = input_tensor.to_torch() + other_view = other_tensor.to_torch() + output_view = output_tensor.to_torch() + + rt.run() + + input_view_numpy = input_view.cpu().numpy() + other_view_numpy = other_view.cpu().numpy() + output_view_numpy = output_view.cpu().numpy() + + output_tensor_host = output_tensor.to_numpy() + + assert np.allclose(input_tensor_host, input_view_numpy) + assert np.allclose(other_tensor_host, other_view_numpy) + assert np.allclose(output_tensor_host, output_view_numpy) + + +# Function to check if there is a difference between two arrays at a specific index +def check_diff(input_tensor_host, input_view_numpy, value, index): + mask = np.ones(input_tensor_host.shape, dtype=bool) + mask[index] = False + if not np.allclose(input_tensor_host[mask], input_view_numpy[mask]): + 
print("Difference found at index: ", index) + return False + if input_view_numpy[index] != value: + print(input_view_numpy[index], value) + return False + return True + + +# Test function to check if changes to the torch views are reflected in the original tensors +@pytest_ark(need_torch=True) +@pytest.mark.parametrize("dtype", [ark.fp16, ark.fp32]) +def test_ark_to_torch_aliasing(dtype: ark.DataType): + import torch + + dimensions = [4, 4] + input_tensor, input_tensor_host = initialize_tensor(dimensions, dtype) + other_tensor, other_tensor_host = initialize_tensor(dimensions, dtype) + output_tensor = ark.mul(input_tensor, other_tensor) + + with ark.Runtime() as rt: + rt.launch() + input_tensor.from_numpy(input_tensor_host) + other_tensor.from_numpy(other_tensor_host) + + input_view = input_tensor.to_torch() + other_view = other_tensor.to_torch() + output_view = output_tensor.to_torch() + # make changes to the views + input_view[1, 1] = 20 + other_view[0, 0] = 30 + rt.run() + output_view[3, 0] = 40 + + output_tensor_host = output_tensor.to_numpy() + input_view_numpy = input_view.cpu().numpy() + other_view_numpy = other_view.cpu().numpy() + output_view_numpy = output_view.cpu().numpy() + + # Check if changes to the views are reflected in the original tensors + assert check_diff(input_tensor_host, input_view_numpy, 20, (1, 1)) + assert check_diff(other_tensor_host, other_view_numpy, 30, (0, 0)) + assert check_diff(output_tensor_host, output_view_numpy, 40, (3, 0)) + + +@pytest_ark(need_torch=True) +def test_conversion_torch(): + import torch + + dimensions = [4, 4] + t = ark.constant(7, dimensions) + + with ark.Runtime() as rt: + rt.launch() + + torch_tensor = t.to_torch() + + assert torch_tensor.shape == (4, 4) + assert torch_tensor.dtype == torch.float32 + assert torch_tensor.device.type == "cuda" + assert torch.all(torch_tensor == 0) + + rt.run() + + torch_tensor = t.to_torch() + assert torch.all(torch_tensor == 7) + + +# Torch to ARK tests + +ArkBinOp = 
Callable[[ark.Tensor, ark.Tensor], ark.Tensor] +TorchBinOp = Callable[..., "torch.Tensor"] +ArkUnOp = Callable[[ark.Tensor], ark.Tensor] +TorchUnOp = Callable[..., "torch.Tensor"] + + +# Verify the accuracy of binary operations involving ARK view tensors +@pytest_ark(need_torch=True) +def test_bin_op(): + import torch + + dtype = torch.float16 + tensor_dims = (2, 3) + input_tensor = torch.randn(tensor_dims, dtype=dtype, device="cuda:0") + other_tensor = torch.randn(tensor_dims, dtype=dtype, device="cuda:0") + expected_output = torch.add(input_tensor, other_tensor).cpu().numpy() + input_ark_view = ark.Tensor.from_torch(input_tensor) + other_ark_view = ark.Tensor.from_torch(other_tensor) + output = ark.add(input_ark_view, other_ark_view) + + with ark.Runtime() as rt: + rt.launch() + rt.run() + output_host = output.to_numpy() + + assert np.allclose(output_host, expected_output) + + +# Verify the accuracy of unary operations involving ARK view tensors +@pytest_ark(need_torch=True) +def test_unary_op(): + import torch + + dtype = torch.float16 + tensor_dims = (3, 3) + input_tensor = torch.randn(tensor_dims, dtype=dtype, device="cuda:0") + expected_output = torch.exp(input_tensor).cpu().numpy() + input_ark_view = ark.Tensor.from_torch(input_tensor) + output = ark.exp(input_ark_view) + + with ark.Runtime() as rt: + rt.launch() + rt.run() + output_host = output.to_numpy() + + assert np.allclose(output_host, expected_output) + + +# Test function to check if changes in torch tensors are reflected in ARK views +@pytest_ark(need_torch=True) +def test_torch_to_ark_aliasing(): + import torch + + dtype = torch.float16 + tensor_dims = (64, 64) + # Initialize a PyTorch tensor + input_tensor = torch.randn(tensor_dims, dtype=dtype, device="cuda:0") + other_tensor = torch.randn(tensor_dims, dtype=dtype, device="cuda:0") + + input_ark_view = ark.Tensor.from_torch(input_tensor) + other_ark_view = ark.Tensor.from_torch(other_tensor) + + output = ark.add(input_ark_view, other_ark_view) + 
# Perform in place operations + input_tensor += other_tensor + other_tensor += input_tensor + expected_output = (input_tensor + other_tensor).cpu().numpy() + + with ark.Runtime() as rt: + rt.launch() + rt.run() + output_host = output.to_numpy() + + assert np.allclose(output_host, expected_output) diff --git a/python/unittest/test_eval.py b/python/unittest/test_eval.py new file mode 100644 index 000000000..6986e03a0 --- /dev/null +++ b/python/unittest/test_eval.py @@ -0,0 +1,131 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +"""Test that Tensor.eval() correctly reuses compiled plans and recompiles +when the model graph changes.""" + +import pytest +import torch +import ark + +DEVICE = "cuda:0" + + +@pytest.fixture(autouse=True) +def _ark_init(): + """Reset ARK state before each test.""" + ark.init() + + +def _get_compiled_plan(): + """Return the plan string currently compiled in the executor.""" + from ark.executor import Executor + + return Executor.get().plan() + + +def test_eval_same_structure_produces_correct_results(): + """Two eval() calls on same-shaped graphs should both produce correct results. 
+ Note: the plan strings may differ (different tensor IDs), but the executor's + file-level compile cache avoids redundant nvcc invocations.""" + a = torch.ones(64, dtype=torch.float32, device=DEVICE) * 3.0 + b = torch.ones(64, dtype=torch.float32, device=DEVICE) * 4.0 + + r1 = ark.add(a, b).eval() + assert torch.allclose(r1, a + b) + + r2 = ark.add(a, b).eval() + assert torch.allclose(r2, a + b) + + +def test_eval_recompile_on_different_graph(): + """A different graph should produce a different plan → triggers recompile.""" + a = torch.ones(64, dtype=torch.float32, device=DEVICE) * 2.0 + b = torch.ones(64, dtype=torch.float32, device=DEVICE) * 3.0 + + # Graph 1: add + r1 = ark.add(a, b).eval() + plan1 = _get_compiled_plan() + assert torch.allclose(r1, a + b) + + # Graph 2: mul (different op → different plan) + r2 = ark.mul(a, b).eval() + plan2 = _get_compiled_plan() + assert torch.allclose(r2, a * b) + + assert ( + plan1 != plan2 + ), "Different graph structure should produce a different plan" + + +def test_eval_recompile_on_graph_update(): + """Building more ops on top of a previously eval'd graph should + recompile and produce correct results.""" + ark.init() + a = torch.ones(64, dtype=torch.float32, device=DEVICE) * 2.0 + b = torch.ones(64, dtype=torch.float32, device=DEVICE) * 3.0 + + # Step 1: build c = a + b, eval + c = ark.add(a, b) + r1 = c.eval() + plan1 = _get_compiled_plan() + assert torch.allclose(r1, a + b) + + # Step 2: extend the SAME graph with d = c + a, eval + # c is still a valid ARK tensor in the same model + d = ark.add(c, a) + r2 = d.eval() + plan2 = _get_compiled_plan() + assert torch.allclose(r2, (a + b) + a) + + # The plan must have changed (graph grew from 1 op to 2 ops) + assert ( + plan1 != plan2 + ), "Extending the graph should produce a different plan and recompile" + + +def test_eval_with_torch_stream(): + """eval() with a torch.cuda.Stream should correctly interleave with + torch operations on the same stream across multiple 
iterations.""" + s = torch.cuda.Stream() + x = torch.ones(64, dtype=torch.float32, device=DEVICE) + + for i in range(5): + # Reset ARK model each iteration so eval() only runs the single add op + ark.init() + # torch op on the stream: x = x * 2 + with torch.cuda.stream(s): + x = x * 2 + # ARK op on the same stream: x = x + 1 + x = ark.add(x, 1.0).eval(stream=s) + + s.synchronize() + # Each iteration: x = x * 2 + 1 + # Starting from 1: 3, 7, 15, 31, 63 + expected = torch.full((64,), 63.0, dtype=torch.float32, device=DEVICE) + assert torch.allclose(x, expected) + + +def test_eval_chain_with_intermediate_read(): + """Build a chain of dependent ARK ops, eval() the final tensor, + then verify an intermediate tensor also has the correct value.""" + a = torch.ones(64, dtype=torch.float32, device=DEVICE) * 2.0 + + # Chain: b = a + 3 -> c = b * 4 -> d = c - 1 + b = ark.add(a, 3.0) + c = ark.mul(b, 4.0) + d = ark.sub(c, 1.0) + + # Only eval the final tensor + result = d.eval() + + # Final: (2+3)*4 - 1 = 19 + assert torch.allclose(result, torch.full((64,), 19.0, device=DEVICE)) + + # Intermediate b should also be materialized: 2+3 = 5 + b_val = b.to_torch() + assert torch.allclose(b_val, torch.full((64,), 5.0, device=DEVICE)) + + # Intermediate c should also be materialized: 5*4 = 20 + c_val = c.to_torch() + assert torch.allclose(c_val, torch.full((64,), 20.0, device=DEVICE)) diff --git a/python/unittest/test_placeholder.py b/python/unittest/test_placeholder.py index 640cc0e3c..74744853e 100644 --- a/python/unittest/test_placeholder.py +++ b/python/unittest/test_placeholder.py @@ -13,9 +13,7 @@ def test_placeholder_is_external(): assert t_placeholder.is_external(), "Placeholder tensor should be external" t_regular = ark.tensor([64], ark.fp32) - assert not t_regular.is_external(), ( - "Regular tensor should not be external" - ) + assert not t_regular.is_external(), "Regular tensor should not be external" @pytest_ark(need_torch=True) @@ -34,9 +32,9 @@ def 
test_placeholder_immediate_binding(): result = out.to_numpy() expected = torch_data.cpu().numpy() + 1.0 - assert np.allclose(result, expected), ( - f"max diff: {np.max(np.abs(result - expected))}" - ) + assert np.allclose( + result, expected + ), f"max diff: {np.max(np.abs(result - expected))}" @pytest_ark(need_torch=True) @@ -44,7 +42,9 @@ def test_placeholder_scalar_add(): """Test placeholder with scalar addition on non-aligned shape.""" import torch - torch_data = torch.arange(10, dtype=torch.float32, device="cuda:0").reshape(10, 1) + torch_data = torch.arange(10, dtype=torch.float32, device="cuda:0").reshape( + 10, 1 + ) t = ark.placeholder([10, 1], ark.fp32, data=torch_data) out = ark.add(t, 5.0) @@ -55,9 +55,9 @@ def test_placeholder_scalar_add(): result = out.to_numpy() expected = torch_data.cpu().numpy() + 5.0 - assert np.allclose(result, expected), ( - f"max diff: {np.max(np.abs(result - expected))}" - ) + assert np.allclose( + result, expected + ), f"max diff: {np.max(np.abs(result - expected))}" @pytest_ark(need_torch=True) @@ -79,9 +79,9 @@ def test_placeholder_multiple(): result = out.to_numpy() expected = torch_a.cpu().numpy() + torch_b.cpu().numpy() - assert np.allclose(result, expected), ( - f"max diff: {np.max(np.abs(result - expected))}" - ) + assert np.allclose( + result, expected + ), f"max diff: {np.max(np.abs(result - expected))}" @pytest_ark(need_torch=True) @@ -100,9 +100,9 @@ def test_placeholder_fp16(): result = out.to_numpy() expected = torch_data.cpu().numpy() * 0.5 - assert np.allclose(result, expected, atol=1e-2), ( - f"max diff: {np.max(np.abs(result - expected))}" - ) + assert np.allclose( + result, expected, atol=1e-2 + ), f"max diff: {np.max(np.abs(result - expected))}" @pytest_ark(need_torch=True) @@ -122,9 +122,9 @@ def test_placeholder_from_torch(): result = out.to_numpy() expected = torch_tensor.cpu().numpy() + 10.0 - assert np.allclose(result, expected), ( - f"max diff: {np.max(np.abs(result - expected))}" - ) + assert 
np.allclose( + result, expected + ), f"max diff: {np.max(np.abs(result - expected))}" @pytest_ark(need_torch=True) @@ -144,9 +144,9 @@ def test_placeholder_tensor_mappings_launch(): result = out.to_numpy() expected = torch_input.cpu().numpy() * 3.0 - assert np.allclose(result, expected), ( - f"max diff: {np.max(np.abs(result - expected))}" - ) + assert np.allclose( + result, expected + ), f"max diff: {np.max(np.abs(result - expected))}" @pytest_ark(need_torch=True) @@ -169,9 +169,9 @@ def test_placeholder_runtime_rebinding(): result2 = out.to_numpy() assert np.allclose(result1, 6.0), f"Run 1: expected 6.0, got {result1[:5]}" - assert np.allclose(result2, 11.0), ( - f"Run 2: expected 11.0, got {result2[:5]}" - ) + assert np.allclose( + result2, 11.0 + ), f"Run 2: expected 11.0, got {result2[:5]}" @pytest_ark(need_torch=True) diff --git a/tools/lint.sh b/tools/lint.sh new file mode 100755 index 000000000..5c97626b0 --- /dev/null +++ b/tools/lint.sh @@ -0,0 +1,74 @@ +#!/usr/bin/env bash + +PROJECT_ROOT=$(dirname "$(realpath "$0")")/.. +LINT_CPP=false +LINT_PYTHON=false +DRY_RUN=false +EXIT_CODE=0 + +usage() { + echo "Usage: $0 [cpp] [py] [dry]" + echo " cpp Lint C++ code" + echo " py Lint Python code" + echo " dry Dry run mode (no changes made)" +} + +# Parse arguments +for arg in "$@"; do + case "$arg" in + cpp) + LINT_CPP=true + ;; + py) + LINT_PYTHON=true + ;; + dry) + DRY_RUN=true + ;; + *) + echo "Error: Unknown argument '$arg'" + usage + exit 1 + ;; + esac +done + +# If no cpp or py specified, default to both +if [ "$LINT_CPP" = false ] && [ "$LINT_PYTHON" = false ]; then + LINT_CPP=true + LINT_PYTHON=true +fi + +if $LINT_CPP; then + echo "Linting C++ code..." 
+    # Collect tracked C/C++/CUDA sources into an array (NUL-delimited) so paths with spaces stay intact.
+    files=()
+    while IFS= read -r -d '' f; do
+        files+=("$PROJECT_ROOT/$f")
+    done < <(git -C "$PROJECT_ROOT" ls-files -z --cached -- '*.c' '*.h' '*.cpp' '*.hpp' '*.cc' '*.cu' '*.cuh')
+    if [ "${#files[@]}" -gt 0 ]; then
+        if $DRY_RUN; then
+            clang-format -style=file --dry-run --Werror "${files[@]}" || EXIT_CODE=1
+        else
+            clang-format -style=file -i "${files[@]}" || EXIT_CODE=1
+        fi
+    fi
+fi
+
+if $LINT_PYTHON; then
+    echo "Linting Python code..."
+    # Collect tracked .py files into an array (NUL-delimited); each path is passed as one argument.
+    files=()
+    while IFS= read -r -d '' f; do
+        files+=("$PROJECT_ROOT/$f")
+    done < <(git -C "$PROJECT_ROOT" ls-files -z --cached -- '*.py')
+    if [ "${#files[@]}" -gt 0 ]; then
+        if $DRY_RUN; then
+            python3 -m black --check --diff "${files[@]}" || EXIT_CODE=1
+        else
+            python3 -m black "${files[@]}" || EXIT_CODE=1
+        fi
+    fi
+fi
+
+exit $EXIT_CODE