From 4a73bf94b805c493a33b1711991ef5d084275761 Mon Sep 17 00:00:00 2001 From: Jonathan Clohessy Date: Tue, 17 Feb 2026 17:32:03 +0000 Subject: [PATCH 1/2] Optimize KLEIDIAI LHS packing in convolution MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit previously, the convolution kernel for KLEIDIAI would allocate a large contiguous buffer for the LHS (left-hand-side) matrix packing, which could consume excessive memory and reduce cache efficiency. This patch modifies the packing strategy to use a chunked approach: - Introduce a compile-time upper bound for temporary LHS packing buffers - Allocate a moderate-sized temporary buffer once. - Pack LHS rows in chunks, perform computation, then reuse the buffer for the next chunk. Benefits: - Significantly reduces peak memory usage. - Improves cache utilization and overall computation efficiency. - Avoids potential memory allocation failures for large convolutions. Performance improvement: - Test with model https://huggingface.co/garavv/arcface-onnx on MTK D9500 Before this patch ``` ./build/RelWithDebInfo/onnxruntime_perf_test -x 1 -r 1000 arc.onnx Number of inferences per second: 4.25327 ``` After this patch ``` ./build/RelWithDebInfo/onnxruntime_perf_test -x 1 -r 1000 arc.onnx Number of inferences per second: 5.03257 ``` ---------------------------------------------------------------------------------------------------------------- sme_with_patch | sme_without_patch | Diff (μs) | Change % | Node Name ---------------------------------------------------------------------------------------------------------------- 11975.10 | 31235.30 | 19260.20 | 160.84% | StatefulPartitionedCall/ResNet34/conv2_block1_1_conv/Conv2D 4514.80 | 7691.70 | 3176.90 | 70.37% | StatefulPartitionedCall/ResNet34/conv2_block2_1_conv/Conv2D 4220.20 | 7120.70 | 2900.50 | 68.73% | StatefulPartitionedCall/ResNet34/conv2_block3_1_conv/Conv2D 5429.20 | 8279.60 | 2850.40 | 52.50% | StatefulPartitionedCall/ResNet34/conv3_block1_1_conv/Conv2D 4497.80 | 5478.40 | 980.60 | 21.80% | StatefulPartitionedCall/ResNet34/conv4_block1_1_conv/Conv2D 3474.30 | 4351.80 | 877.50 | 25.26% | StatefulPartitionedCall/ResNet34/conv3_block3_1_conv/Conv2D 3627.30 | 4504.00 | 876.70 | 24.17% | StatefulPartitionedCall/ResNet34/conv3_block4_1_conv/Conv2D 5244.20 | 5961.10 | 716.90 | 13.67% | StatefulPartitionedCall/ResNet34/conv1_conv/Conv2D 3439.80 | 4050.90 | 611.10 | 17.77% | StatefulPartitionedCall/ResNet34/conv3_block2_1_conv/Conv2D 9749.80 | 10195.50 | 445.70 | 4.57% | StatefulPartitionedCall/ResNet34/conv2_block2_2_conv/Conv2D 3814.00 | 4209.80 | 395.80 | 10.38% | StatefulPartitionedCall/ResNet34/conv5_block2_2_conv/Conv2D 2715.90 | 3034.70 | 318.80 | 11.74% | StatefulPartitionedCall/ResNet34/conv4_block6_1_conv/Conv2D 4089.10 | 4367.80 | 278.70 | 6.82% | StatefulPartitionedCall/ResNet34/conv5_block1_1_conv/Conv2D 2698.00 | 2959.50 | 261.50 | 9.69% | StatefulPartitionedCall/ResNet34/conv4_block5_1_conv/Conv2D 3869.20 | 4102.80 | 233.60 | 6.04% | StatefulPartitionedCall/ResNet34/conv5_block3_2_conv/Conv2D 2767.90 | 2966.80 | 198.90 | 7.19% | StatefulPartitionedCall/ResNet34/conv4_block4_1_conv/Conv2D 9652.10 | 9816.60 | 164.50 | 1.70% | StatefulPartitionedCall/ResNet34/conv2_block3_2_conv/Conv2D 2897.50 | 3054.60 | 157.10 | 5.42% | StatefulPartitionedCall/ResNet34/conv4_block3_1_conv/Conv2D 4601.20 | 4748.60 | 147.40 | 3.20% | StatefulPartitionedCall/ResNet34/conv5_block1_2_conv/Conv2D 3134.00 | 3246.10 | 112.10 | 3.58% | StatefulPartitionedCall/ResNet34/conv4_block2_1_conv/Conv2D Signed-off-by: Qxiang Xu Signed-off-by: Jonathan Clohessy --- .../mlas/lib/kleidiai/convolve_kleidiai.cpp | 207 ++++++++---------- 1 file changed, 97 insertions(+), 110 deletions(-) diff --git a/onnxruntime/core/mlas/lib/kleidiai/convolve_kleidiai.cpp b/onnxruntime/core/mlas/lib/kleidiai/convolve_kleidiai.cpp index cca4f5a19c417..c9c15ddad012f 100644 --- a/onnxruntime/core/mlas/lib/kleidiai/convolve_kleidiai.cpp +++ b/onnxruntime/core/mlas/lib/kleidiai/convolve_kleidiai.cpp @@ -23,6 +23,8 @@ const KaiF32IMatmulKernel& imatmul_conv = GetKleidiAIF32IMatmulUKernel(); +// Maximum temporary buffer size (in bytes) for KleidiAI LHS packing. +constexpr size_t MAX_LHS_CHUNK_BYTES = 2097152; // Left-hand-side (input indirection) cache key struct LhsCacheKey { @@ -251,37 +253,6 @@ static std::unique_ptr NChwToNhwc(const size_t n, return t; } -static void MultiThreadedLHSPackSme(MLAS_THREADPOOL* ThreadPool, const size_t ci, const size_t m, const size_t kh, - const size_t kw, const void * const* lhs_ptrs, std::byte* lhs_data, - const float* in_data, - const float* pad_ptr) { - size_t m_step = imatmul_conv.ukernel.get_m_step(); - - // Minimize the kernel call count for the number of available threads - auto RequiredTiles = MlasDivRoundup(m, m_step); - auto MaxTiles = std::min(static_cast(MlasGetMaximumThreadCount(ThreadPool)), RequiredTiles); - m_step *= MlasDivRoundup(RequiredTiles, MaxTiles); - RequiredTiles = MlasDivRoundup(m, m_step); - - MlasTrySimpleParallel(ThreadPool, static_cast(RequiredTiles), [&](ptrdiff_t tid) { - - auto m_idx = static_cast(tid) * m_step; - auto offset = kai_get_lhs_packed_offset_lhs_imatmul_pack_x32p2vlx1_x32p_sme(m_idx,kh*kw,ci); - - KLEIDIAI_KERNEL_LOG("kai_run_lhs_imatmul_pack_x32p2vlx1_x32p_sme" - << " M=" << (m < (m_idx + m_step) ? m - m_idx : m_step) - << " k_chunk_count=" << (kh * kw) - << " k_chunk_length=" << ci); - kai_run_lhs_imatmul_pack_x32p2vlx1_x32p_sme( - m < (m_idx + m_step) ? m - m_idx : m_step, kh * kw, ci, - lhs_ptrs + m_idx * kh * kw, - reinterpret_cast(in_data), - reinterpret_cast(pad_ptr), - lhs_data + offset - ); - }); -} - size_t MLASCALL ArmKleidiAI::MlasConvSymmetricChannelsLast2DFloatPackWSize( @@ -418,81 +389,55 @@ static std::shared_ptr LhsPtrFill(const size_t ci, const size_t i return lhs_ptrs; } -static std::unique_ptr LhsPackImageDataSme(const size_t ci, const size_t ih, const size_t iw, - const size_t kh, const size_t kw, const size_t sh, - const size_t sw, const size_t padding, const float* in, - bool input_is_channels_last, - MLAS_THREADPOOL* ThreadPool) -{ +static const float* GetOrCreatePadDataSme(const size_t ci) { size_t padsize = 256; - if(ci > padsize) - { - // figure out how many blocks needed to correctly fill padding - padsize = ((ci + padsize - 1) / padsize) * padsize; + if (ci > padsize) { + padsize = MlasDivRoundup(ci, padsize) * padsize; } // pad_ptr must be at least 'ci' floats for padding pixels. - // Using a thread_local grow-only buffer to avoid cross-thread interference and ensure sizing is correct. - // - // The pad buffer contents are always zero. Since the buffer is grow-only and never written with non-zero data, - // we only need to zero-initialize newly-grown elements. + // The buffer is grow-only and zero-initializes newly-grown elements. thread_local std::vector pad_ptr; - if (pad_ptr.size() < padsize) { pad_ptr.resize(padsize, 0.f); } - //create lhs in format required for imatmul - const auto m = ComputeConvOutSize(ih, kh, padding, sh) * ComputeConvOutSize(iw, kw, padding, sw); - - const auto lhs_size = kai_get_lhs_packed_size_lhs_imatmul_pack_x32p2vlx1_x32p_sme(m,kh*kw,ci); - auto lhs = std::make_unique(lhs_size); + return pad_ptr.data(); +} - std::unique_ptr nhwc_holder; - const float* activation_src = nullptr; - if (input_is_channels_last) { - activation_src = in; - } else { - nhwc_holder = NChwToNhwc(1, ci, ih, iw, in, 1, 1, false, ThreadPool); - activation_src = nhwc_holder.get(); - } +static std::shared_ptr GetOrCreateLhsPtrTableSme(const size_t ci, const size_t ih, const size_t iw, + const size_t kh, const size_t kw, const size_t sh, + const size_t sw, const size_t padding, const float* pad_ptr) { + // LhsPtrFill stores geometry offsets only; the current input base is supplied when packing. + LhsCacheKey key = { + ci, ih, iw, + padding, sh, sw, + kh, kw, + 1, 1 + }; - // Cache of computed lhs ptr offsets. thread_local to prevent interference from parallel sessions. - // - // Entries include pointers to the pad buffer for out-of-bounds pixels, so we must not reuse entries after the - // pad buffer is reallocated. To avoid clearing the entire cache, we group caches by pad buffer identity and - // invalidate only the old group when the pad buffer moves. + // Setting up LHS ptr cache and tracking value of last passed pad_ptr using LhsPtrsCache = std::unordered_map>; thread_local std::unordered_map lhs_ptrs_cache_by_pad; + thread_local const float* last_pad_ptr = nullptr; // If pad_ptr moved (vector reallocation), drop only the old group to avoid accumulating unreachable entries. - thread_local const float* last_pad_ptr = nullptr; - const float* cur_pad_ptr = pad_ptr.data(); + const float* cur_pad_ptr = pad_ptr; if (last_pad_ptr != nullptr && last_pad_ptr != cur_pad_ptr) { lhs_ptrs_cache_by_pad.erase(last_pad_ptr); } last_pad_ptr = cur_pad_ptr; - LhsCacheKey key = { - ci, ih, iw, - padding, sh, sw, - kh, kw, - 1, 1 - }; - auto& lhs_ptrs_cache = lhs_ptrs_cache_by_pad[cur_pad_ptr]; std::shared_ptr lhs_ptrs; if (auto found = lhs_ptrs_cache.find(key); found != lhs_ptrs_cache.end()) { lhs_ptrs = found->second; } else { - lhs_ptrs = LhsPtrFill(ci, ih, iw, kh, kw, sh, sw, padding, &pad_ptr[0]); + lhs_ptrs = LhsPtrFill(ci, ih, iw, kh, kw, sh, sw, padding, pad_ptr); lhs_ptrs_cache[key] = lhs_ptrs; } - - MultiThreadedLHSPackSme(ThreadPool, ci, m, kh, kw, &lhs_ptrs[0], &lhs[0], activation_src, &pad_ptr[0]); - - return lhs; + return lhs_ptrs; } static void ConvolveSme(const size_t co, //channels out @@ -517,7 +462,6 @@ static void ConvolveSme(const size_t co, //channels out bool input_is_channels_last, MLAS_THREADPOOL* ThreadPool) { - //RhsPackWeightsBiasSme() - to perform dilation increases kernel size and masks unused weights //compute corrected dimensions of dilated kernel const auto d_kh = ComputeKernelSize(dilationh, kh); const auto d_kw = ComputeKernelSize(dilationw, kw); @@ -543,12 +487,12 @@ static void ConvolveSme(const size_t co, //channels out dim[2] = MlasDivRoundup(RequiredTiles * dim[2], dim[1] * dim[2]); //compute new step sizes - m_step *= MlasDivRoundup(MlasDivRoundup(m, dim[1]), m_step); - n_step *= MlasDivRoundup(MlasDivRoundup(co, dim[2]), n_step); + size_t m_tile_step = m_step * MlasDivRoundup(MlasDivRoundup(m, dim[1]), m_step); + size_t n_tile_step = n_step * MlasDivRoundup(MlasDivRoundup(co, dim[2]), n_step); //update tile iterations - dim[1] = MlasDivRoundup(m, m_step); - dim[2] = MlasDivRoundup(co, n_step); + dim[1] = MlasDivRoundup(m, m_tile_step); + dim[2] = MlasDivRoundup(co, n_tile_step); for (size_t g = 0; g < groups; ++g) { @@ -558,7 +502,17 @@ static void ConvolveSme(const size_t co, //channels out result = tmp_mlas_aligned; } - auto lhs = LhsPackImageDataSme(ci, ih, iw, d_kh, d_kw, sh, sw, padding, in, input_is_channels_last, ThreadPool); + // LHS packing data + const float* pad_data = GetOrCreatePadDataSme(ci); + std::unique_ptr nhwc_holder; + const float* activation_src = in; + if (!input_is_channels_last) { + nhwc_holder = NChwToNhwc(1, ci, ih, iw, in, 1, 1, false, ThreadPool); + activation_src = nhwc_holder.get(); + } + auto lhs_ptrs = GetOrCreateLhsPtrTableSme(ci, ih, iw, d_kh, d_kw, sh, sw, padding, pad_data); + + // RHS packing data const std::byte* rhs_data = packed_rhs ? packed_rhs + g * packed_rhs_group_stride : nullptr; std::unique_ptr rhs_storage; if (rhs_data == nullptr) { @@ -577,40 +531,73 @@ static void ConvolveSme(const size_t co, //channels out MlasTrySimpleParallel(ThreadPool, static_cast(dim[0] * dim[1] * dim[2]), [&](ptrdiff_t tid) { //compute B,M,N index from iteration index //ptrdiff_t BIdx = tid / (dim[1] * dim[2]); - ptrdiff_t MIdx = (tid % (dim[1] * dim[2])) / dim[2]; - ptrdiff_t NIdx = (tid % (dim[1] * dim[2])) % dim[2]; + const size_t m_idx = static_cast((tid % (dim[1] * dim[2])) / dim[2]); + const size_t n_idx = static_cast((tid % (dim[1] * dim[2])) % dim[2]); // Get rhs tile, B const size_t rhs_packed_offset = - imatmul_conv.ukernel.get_rhs_packed_offset(NIdx * n_step, d_kh * d_kw, ci); + imatmul_conv.ukernel.get_rhs_packed_offset(n_idx * n_tile_step, d_kh * d_kw, ci); auto BTile = reinterpret_cast( rhs_data + rhs_packed_offset ); - // Get lhs tile, A - const size_t lhs_packed_offset = - imatmul_conv.ukernel.get_lhs_packed_offset(MIdx * m_step, d_kh * d_kw, ci); - - auto ATile = reinterpret_cast( - reinterpret_cast(lhs.get()) + lhs_packed_offset - ); - - auto TileSizeM = (MIdx + 1) * m_step > m ? (m - MIdx * m_step) : m_step; - auto TileSizeN = (NIdx + 1) * n_step > co ? (co - NIdx * n_step) : n_step; - - // Get result tile, C - auto CTile = &reinterpret_cast(result)[ - MIdx * m_step * co * sizeof(float) + - NIdx * n_step * sizeof(float)]; - - KLEIDIAI_KERNEL_LOG(imatmul_conv.name - << " M=" << TileSizeM << " N=" << TileSizeN - << " k_chunk_count=" << (d_kh * d_kw) << " k_chunk_length=" << ci); - imatmul_conv.ukernel.run_imatmul( - TileSizeM, TileSizeN, d_kh * d_kw, ci, ATile, BTile, CTile, co * sizeof(float), - -std::numeric_limits::max(), std::numeric_limits::max() - ); + // Calculate lhs tile chunks + auto tile_size_n = (n_idx + 1) * n_tile_step > co ? (co - n_idx * n_tile_step) : n_tile_step; + + // Compute the starting row index of the current M tile + size_t tile_m_start = m_idx * m_tile_step; + // Actual number of rows in this tile (may be smaller for the last tile) + size_t tile_m_size = std::min(m_tile_step, m - tile_m_start); + + // Query the packed LHS buffer size for exactly one m_step block. + const size_t bytes_per_m_step = kai_get_lhs_packed_size_lhs_imatmul_pack_x32p2vlx1_x32p_sme( + m_step, d_kh * d_kw, ci); + + // Determine how many rows we can pack in one chunk. + size_t m_chunk = std::max(m_step, MAX_LHS_CHUNK_BYTES / bytes_per_m_step * m_step); + + // Do not exceed the number of rows available in this tile + m_chunk = std::min(tile_m_size, m_chunk); + + // Compute the exact packed buffer size for m_chunk rows. + const size_t lhs_buffer_bytes = kai_get_lhs_packed_size_lhs_imatmul_pack_x32p2vlx1_x32p_sme( + m_chunk, d_kh * d_kw, ci); + + // Allocate a single reusable buffer for LHS packing + auto lhs = std::make_unique(lhs_buffer_bytes); + + // Interpret the packed LHS buffer as the input A tile + // for the matrix multiplication kernel. + auto ATile = reinterpret_cast(lhs.get()); + + for (size_t m_base = 0; m_base < tile_m_size; m_base += m_chunk) { + // Actual number of rows processed in this iteration. + // The last chunk may be smaller than m_chunk. + const size_t tile_size_m = std::min(m_chunk, tile_m_size - m_base); + + // Pack TileSizeM rows of the LHS matrix into a temporary buffer. + kai_run_lhs_imatmul_pack_x32p2vlx1_x32p_sme(tile_size_m, + d_kh * d_kw, + ci, + lhs_ptrs.get() + (tile_m_start + m_base) * d_kh * d_kw, + reinterpret_cast(activation_src), + reinterpret_cast(pad_data), + lhs.get()); + + // Get result tile, C + auto CTile = &reinterpret_cast(result)[ + (tile_m_start + m_base) * co * sizeof(float) + + n_idx * n_tile_step * sizeof(float)]; + + KLEIDIAI_KERNEL_LOG(imatmul_conv.name + << " M=" << tile_size_m << " N=" << tile_size_n + << " k_chunk_count=" << (d_kh * d_kw) << " k_chunk_length=" << ci); + + imatmul_conv.ukernel.run_imatmul(tile_size_m, tile_size_n, d_kh * d_kw, ci, ATile, BTile, + CTile, co * sizeof(float), -std::numeric_limits::max(), + std::numeric_limits::max()); + } }); if (need_transpose) { From 7b049288be0a8df775cc1413a04854f3e16f23a8 Mon Sep 17 00:00:00 2001 From: Martin Klacer Date: Mon, 11 May 2026 17:30:46 +0100 Subject: [PATCH 2/2] Add heuristic delegation to SGEMM vs IGEMM to KleidiAI convolution path - Added SelectConvRoute function to mlasi_kleidiai.h to decide between GemmFallback and Igemm based on the convolution workload parameters - Updated CheckCapabilitiesSme function in convolve_kleidiai.cpp to use the new SelectConvRoute function Co-authored-by: Damien Dooley Signed-off-by: Martin Klacer --- onnxruntime/core/mlas/lib/convolve.cpp | 20 ++++++- .../mlas/lib/kleidiai/convolve_kleidiai.cpp | 18 +++++- .../core/mlas/lib/kleidiai/mlasi_kleidiai.h | 59 +++++++++++++++++++ 3 files changed, 91 insertions(+), 6 deletions(-) diff --git a/onnxruntime/core/mlas/lib/convolve.cpp b/onnxruntime/core/mlas/lib/convolve.cpp index 4378ec1948fdb..1e3c7cc61b083 100644 --- a/onnxruntime/core/mlas/lib/convolve.cpp +++ b/onnxruntime/core/mlas/lib/convolve.cpp @@ -15,6 +15,9 @@ Module Name: --*/ #include "mlasi.h" +#if defined(USE_KLEIDIAI) +#include "kleidiai/mlasi_kleidiai.h" +#endif // // Define the number of working buffer elements required per thread. @@ -574,6 +577,11 @@ Return Value: const size_t FilterCount = Parameters->FilterCount; const size_t OutputSize = Parameters->OutputSize; const size_t K = Parameters->K; + bool use_gemm_batch_override = false; +#if defined(USE_KLEIDIAI) + use_gemm_batch_override = + ArmKleidiAI::SelectConvRoute(Parameters) == ArmKleidiAI::ConvRoute::GemmFallback; +#endif // // Compute the strides to step through slices of the local segment. @@ -637,9 +645,15 @@ Return Value: SegmentStartN + n, CountN); } - MlasSgemmOperation(CblasNoTrans, CblasNoTrans, FilterCount, CountN, - CountK, 1.0f, Filter + k, K, ColumnBuffer, CountN, beta, - SegmentOutput, OutputSize); + if (use_gemm_batch_override) { + MlasGemm(CblasNoTrans, CblasNoTrans, FilterCount, CountN, + CountK, 1.0f, Filter + k, K, ColumnBuffer, CountN, beta, + SegmentOutput, OutputSize, nullptr, Parameters->BackendKernelSelectorConfig); + } else { + MlasSgemmOperation(CblasNoTrans, CblasNoTrans, FilterCount, CountN, + CountK, 1.0f, Filter + k, K, ColumnBuffer, CountN, beta, + SegmentOutput, OutputSize); + } beta = 1.0f; } diff --git a/onnxruntime/core/mlas/lib/kleidiai/convolve_kleidiai.cpp b/onnxruntime/core/mlas/lib/kleidiai/convolve_kleidiai.cpp index c9c15ddad012f..c0872d7155069 100644 --- a/onnxruntime/core/mlas/lib/kleidiai/convolve_kleidiai.cpp +++ b/onnxruntime/core/mlas/lib/kleidiai/convolve_kleidiai.cpp @@ -23,8 +23,9 @@ const KaiF32IMatmulKernel& imatmul_conv = GetKleidiAIF32IMatmulUKernel(); -// Maximum temporary buffer size (in bytes) for KleidiAI LHS packing. -constexpr size_t MAX_LHS_CHUNK_BYTES = 2097152; +// Hard cap of 2 MiB on temporary LHS packing chunks, intended to keep tiles +// reasonably cache-friendly and to avoid oversized temporary buffers +constexpr size_t MAX_LHS_CHUNK_BYTES = 2 * 1024 * 1024; // Left-hand-side (input indirection) cache key struct LhsCacheKey { @@ -126,7 +127,18 @@ static bool CheckCapabilitiesSme(const MLAS_CONV_PARAMETERS* Parameters) { return false; } - return true; + const auto route = ArmKleidiAI::SelectConvRoute(Parameters); + + if (route == ArmKleidiAI::ConvRoute::Igemm) { + return true; + } + + if (route == ArmKleidiAI::ConvRoute::GemmFallback) { + KLEIDIAI_DEBUG_LOG("CheckCapabilitiesSme returning false to prefer SGEMM-backed conv path."); + } else { + KLEIDIAI_DEBUG_LOG("CheckCapabilitiesSme returning false on functional or optimization checks."); + } + return false; } //General purpose axis swapping diff --git a/onnxruntime/core/mlas/lib/kleidiai/mlasi_kleidiai.h b/onnxruntime/core/mlas/lib/kleidiai/mlasi_kleidiai.h index 3c9f398ece887..ed6e532f5d8ed 100644 --- a/onnxruntime/core/mlas/lib/kleidiai/mlasi_kleidiai.h +++ b/onnxruntime/core/mlas/lib/kleidiai/mlasi_kleidiai.h @@ -56,6 +56,65 @@ inline const bool UseSME2 = MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME2(); inline const bool UseSME = MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME(); inline const std::string_view vendor_name = MLAS_CPUIDINFO::GetCPUIDInfo().GetCPUVendor(); +enum class ConvRoute { + None, + Igemm, + GemmFallback, +}; + +inline constexpr size_t ConvIgemmMaxWork = 1000000ULL; + +inline constexpr size_t ComputeDilatedKernelSize(size_t dilation, size_t kernel) { + return (dilation * kernel) - (dilation - 1); +} + +inline constexpr size_t ComputeConvOutputSize(size_t input, size_t kernel, size_t padding, size_t stride) { + if (stride > 0 && (input + 2 * padding) >= kernel) { + return (((input - kernel) + (2 * padding)) / stride) + 1; + } + + return 0; +} + +inline ConvRoute SelectConvRoute(const MLAS_CONV_PARAMETERS* Parameters) { + if ((Parameters->Dimensions != 2) || + (Parameters->BatchCount != 1) || + (Parameters->Beta != 0.f) || + (Parameters->Padding[0] != Parameters->Padding[1]) || + (Parameters->Padding[0] != Parameters->Padding[2]) || + (Parameters->Padding[0] != Parameters->Padding[3])) { + return ConvRoute::None; + } + + const auto effective_kernel_h = + ComputeDilatedKernelSize(Parameters->DilationShape[0], Parameters->KernelShape[0]); + const auto effective_kernel_w = + ComputeDilatedKernelSize(Parameters->DilationShape[1], Parameters->KernelShape[1]); + const auto output_m = + ComputeConvOutputSize(Parameters->InputShape[0], effective_kernel_h, Parameters->Padding[0], Parameters->StrideShape[0]) * + ComputeConvOutputSize(Parameters->InputShape[1], effective_kernel_w, Parameters->Padding[1], Parameters->StrideShape[1]); + + if (output_m == 0) { + return ConvRoute::None; + } + + const auto filter_count = Parameters->FilterCount; + if (filter_count == 1 || Parameters->KernelShape[0] < 3 || Parameters->KernelShape[1] < 3) { + return ConvRoute::None; + } + + const auto effective_k = Parameters->InputChannels * effective_kernel_h * effective_kernel_w; + if(effective_k == 0 || filter_count == 0) { + return ConvRoute::None; + } + + const auto igemm_max_output_m = (ConvIgemmMaxWork / effective_k / filter_count); + if (output_m > igemm_max_output_m) { + return ConvRoute::GemmFallback; + } + return ConvRoute::Igemm; +} + // Buffer packing routines. // size_t