Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 17 additions & 3 deletions onnxruntime/core/mlas/lib/convolve.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ Module Name:
--*/

#include "mlasi.h"
#if defined(USE_KLEIDIAI)
#include "kleidiai/mlasi_kleidiai.h"
#endif
Comment on lines 17 to +20

//
// Define the number of working buffer elements required per thread.
Expand Down Expand Up @@ -574,6 +577,11 @@ Return Value:
const size_t FilterCount = Parameters->FilterCount;
const size_t OutputSize = Parameters->OutputSize;
const size_t K = Parameters->K;
bool use_gemm_batch_override = false;
#if defined(USE_KLEIDIAI)
use_gemm_batch_override =
ArmKleidiAI::SelectConvRoute(Parameters) == ArmKleidiAI::ConvRoute::GemmFallback;
#endif

//
// Compute the strides to step through slices of the local segment.
Expand Down Expand Up @@ -637,9 +645,15 @@ Return Value:
SegmentStartN + n, CountN);
}

MlasSgemmOperation(CblasNoTrans, CblasNoTrans, FilterCount, CountN,
CountK, 1.0f, Filter + k, K, ColumnBuffer, CountN, beta,
SegmentOutput, OutputSize);
if (use_gemm_batch_override) {
MlasGemm(CblasNoTrans, CblasNoTrans, FilterCount, CountN,
CountK, 1.0f, Filter + k, K, ColumnBuffer, CountN, beta,
SegmentOutput, OutputSize, nullptr, Parameters->BackendKernelSelectorConfig);
} else {
MlasSgemmOperation(CblasNoTrans, CblasNoTrans, FilterCount, CountN,
CountK, 1.0f, Filter + k, K, ColumnBuffer, CountN, beta,
SegmentOutput, OutputSize);
}

beta = 1.0f;
}
Expand Down
221 changes: 110 additions & 111 deletions onnxruntime/core/mlas/lib/kleidiai/convolve_kleidiai.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@

const KaiF32IMatmulKernel& imatmul_conv = GetKleidiAIF32IMatmulUKernel();

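// Target cap of 2 MiB on temporary LHS packing chunks, intended to keep tiles
// reasonably cache-friendly and to avoid oversized temporary buffers. The chunk
// size is never reduced below one m_step block of rows, so the temporary buffer
// may exceed this cap when a single block is larger than 2 MiB.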
constexpr size_t MAX_LHS_CHUNK_BYTES = 2 * 1024 * 1024;

// Left-hand-side (input indirection) cache key
struct LhsCacheKey {
Expand Down Expand Up @@ -124,7 +127,18 @@ static bool CheckCapabilitiesSme(const MLAS_CONV_PARAMETERS* Parameters) {
return false;
}

return true;
const auto route = ArmKleidiAI::SelectConvRoute(Parameters);

if (route == ArmKleidiAI::ConvRoute::Igemm) {
return true;
}

if (route == ArmKleidiAI::ConvRoute::GemmFallback) {
KLEIDIAI_DEBUG_LOG("CheckCapabilitiesSme returning false to prefer SGEMM-backed conv path.");
} else {
KLEIDIAI_DEBUG_LOG("CheckCapabilitiesSme returning false on functional or optimization checks.");
}
return false;
}

//General purpose axis swapping
Expand Down Expand Up @@ -251,37 +265,6 @@ static std::unique_ptr<float[]> NChwToNhwc(const size_t n,
return t;
}

// Packs the indirection-based LHS for the SME imatmul kernel, distributing the
// m rows over the thread pool.
//
//   ThreadPool - pool handed to MlasTrySimpleParallel
//   ci         - forwarded to the packer as k_chunk_length
//   m          - total number of LHS rows to pack
//   kh, kw     - their product is forwarded to the packer as k_chunk_count
//   lhs_ptrs   - indirection table; each row owns kh * kw consecutive entries
//   lhs_data   - destination buffer for the packed LHS
//   in_data    - input base pointer, passed to the packer as an integer value
//   pad_ptr    - padding buffer forwarded to the packer
static void MultiThreadedLHSPackSme(MLAS_THREADPOOL* ThreadPool, const size_t ci, const size_t m, const size_t kh,
                                    const size_t kw, const void * const* lhs_ptrs, std::byte* lhs_data,
                                    const float* in_data,
                                    const float* pad_ptr) {
    const size_t k_chunk_count = kh * kw;
    const size_t kernel_m_step = imatmul_conv.ukernel.get_m_step();

    // Widen the per-task row count so that no more packing calls are issued
    // than there are threads available to run them.
    const size_t kernel_tiles = MlasDivRoundup(m, kernel_m_step);
    const size_t usable_threads =
        std::min(static_cast<size_t>(MlasGetMaximumThreadCount(ThreadPool)), kernel_tiles);
    const size_t rows_per_task = kernel_m_step * MlasDivRoundup(kernel_tiles, usable_threads);
    const size_t task_count = MlasDivRoundup(m, rows_per_task);

    MlasTrySimpleParallel(ThreadPool, static_cast<ptrdiff_t>(task_count), [&](ptrdiff_t tid) {
        const size_t m_idx = static_cast<size_t>(tid) * rows_per_task;

        // The final task takes whatever rows remain; all others take a full step.
        const size_t rows = (m_idx + rows_per_task > m) ? (m - m_idx) : rows_per_task;

        const auto packed_offset =
            kai_get_lhs_packed_offset_lhs_imatmul_pack_x32p2vlx1_x32p_sme(m_idx, k_chunk_count, ci);

        KLEIDIAI_KERNEL_LOG("kai_run_lhs_imatmul_pack_x32p2vlx1_x32p_sme"
                            << " M=" << rows
                            << " k_chunk_count=" << k_chunk_count
                            << " k_chunk_length=" << ci);
        kai_run_lhs_imatmul_pack_x32p2vlx1_x32p_sme(
            rows, k_chunk_count, ci,
            lhs_ptrs + m_idx * k_chunk_count,
            reinterpret_cast<size_t>(in_data),
            reinterpret_cast<const void*>(pad_ptr),
            lhs_data + packed_offset
        );
    });
}

size_t
MLASCALL
ArmKleidiAI::MlasConvSymmetricChannelsLast2DFloatPackWSize(
Expand Down Expand Up @@ -418,81 +401,55 @@ static std::shared_ptr<const void*[]> LhsPtrFill(const size_t ci, const size_t i
return lhs_ptrs;
}

static std::unique_ptr<std::byte[]> LhsPackImageDataSme(const size_t ci, const size_t ih, const size_t iw,
const size_t kh, const size_t kw, const size_t sh,
const size_t sw, const size_t padding, const float* in,
bool input_is_channels_last,
MLAS_THREADPOOL* ThreadPool)
{
static const float* GetOrCreatePadDataSme(const size_t ci) {
size_t padsize = 256;
if(ci > padsize)
{
// figure out how many blocks needed to correctly fill padding
padsize = ((ci + padsize - 1) / padsize) * padsize;
if (ci > padsize) {
padsize = MlasDivRoundup(ci, padsize) * padsize;
}

// pad_ptr must be at least 'ci' floats for padding pixels.
// Using a thread_local grow-only buffer to avoid cross-thread interference and ensure sizing is correct.
//
// The pad buffer contents are always zero. Since the buffer is grow-only and never written with non-zero data,
// we only need to zero-initialize newly-grown elements.
// The buffer is grow-only and zero-initializes newly-grown elements.
thread_local std::vector<float> pad_ptr;

if (pad_ptr.size() < padsize) {
pad_ptr.resize(padsize, 0.f);
}

//create lhs in format required for imatmul
const auto m = ComputeConvOutSize(ih, kh, padding, sh) * ComputeConvOutSize(iw, kw, padding, sw);

const auto lhs_size = kai_get_lhs_packed_size_lhs_imatmul_pack_x32p2vlx1_x32p_sme(m,kh*kw,ci);
auto lhs = std::make_unique<std::byte[]>(lhs_size);
return pad_ptr.data();
}

std::unique_ptr<float[]> nhwc_holder;
const float* activation_src = nullptr;
if (input_is_channels_last) {
activation_src = in;
} else {
nhwc_holder = NChwToNhwc(1, ci, ih, iw, in, 1, 1, false, ThreadPool);
activation_src = nhwc_holder.get();
}
static std::shared_ptr<const void*[]> GetOrCreateLhsPtrTableSme(const size_t ci, const size_t ih, const size_t iw,
const size_t kh, const size_t kw, const size_t sh,
const size_t sw, const size_t padding, const float* pad_ptr) {
// LhsPtrFill stores geometry offsets only; the current input base is supplied when packing.
LhsCacheKey key = {
ci, ih, iw,
padding, sh, sw,
kh, kw,
1, 1
};

// Cache of computed lhs ptr offsets. thread_local to prevent interference from parallel sessions.
//
// Entries include pointers to the pad buffer for out-of-bounds pixels, so we must not reuse entries after the
// pad buffer is reallocated. To avoid clearing the entire cache, we group caches by pad buffer identity and
// invalidate only the old group when the pad buffer moves.
// Setting up LHS ptr cache and tracking value of last passed pad_ptr
using LhsPtrsCache = std::unordered_map<LhsCacheKey, std::shared_ptr<const void*[]>>;
thread_local std::unordered_map<const float*, LhsPtrsCache> lhs_ptrs_cache_by_pad;
thread_local const float* last_pad_ptr = nullptr;

// If pad_ptr moved (vector reallocation), drop only the old group to avoid accumulating unreachable entries.
thread_local const float* last_pad_ptr = nullptr;
const float* cur_pad_ptr = pad_ptr.data();
const float* cur_pad_ptr = pad_ptr;
if (last_pad_ptr != nullptr && last_pad_ptr != cur_pad_ptr) {
lhs_ptrs_cache_by_pad.erase(last_pad_ptr);
}
last_pad_ptr = cur_pad_ptr;

LhsCacheKey key = {
ci, ih, iw,
padding, sh, sw,
kh, kw,
1, 1
};

auto& lhs_ptrs_cache = lhs_ptrs_cache_by_pad[cur_pad_ptr];

std::shared_ptr<const void*[]> lhs_ptrs;
if (auto found = lhs_ptrs_cache.find(key); found != lhs_ptrs_cache.end()) {
lhs_ptrs = found->second;
} else {
lhs_ptrs = LhsPtrFill(ci, ih, iw, kh, kw, sh, sw, padding, &pad_ptr[0]);
lhs_ptrs = LhsPtrFill(ci, ih, iw, kh, kw, sh, sw, padding, pad_ptr);
lhs_ptrs_cache[key] = lhs_ptrs;
}

MultiThreadedLHSPackSme(ThreadPool, ci, m, kh, kw, &lhs_ptrs[0], &lhs[0], activation_src, &pad_ptr[0]);

return lhs;
return lhs_ptrs;
}

static void ConvolveSme(const size_t co, //channels out
Expand All @@ -517,7 +474,6 @@ static void ConvolveSme(const size_t co, //channels out
bool input_is_channels_last,
MLAS_THREADPOOL* ThreadPool) {

//RhsPackWeightsBiasSme() - to perform dilation increases kernel size and masks unused weights
//compute corrected dimensions of dilated kernel
const auto d_kh = ComputeKernelSize(dilationh, kh);
const auto d_kw = ComputeKernelSize(dilationw, kw);
Expand All @@ -543,12 +499,12 @@ static void ConvolveSme(const size_t co, //channels out
dim[2] = MlasDivRoundup(RequiredTiles * dim[2], dim[1] * dim[2]);

//compute new step sizes
m_step *= MlasDivRoundup(MlasDivRoundup(m, dim[1]), m_step);
n_step *= MlasDivRoundup(MlasDivRoundup(co, dim[2]), n_step);
size_t m_tile_step = m_step * MlasDivRoundup(MlasDivRoundup(m, dim[1]), m_step);
size_t n_tile_step = n_step * MlasDivRoundup(MlasDivRoundup(co, dim[2]), n_step);

//update tile iterations
dim[1] = MlasDivRoundup(m, m_step);
dim[2] = MlasDivRoundup(co, n_step);
dim[1] = MlasDivRoundup(m, m_tile_step);
dim[2] = MlasDivRoundup(co, n_tile_step);

for (size_t g = 0; g < groups; ++g) {

Expand All @@ -558,7 +514,17 @@ static void ConvolveSme(const size_t co, //channels out
result = tmp_mlas_aligned;
}

auto lhs = LhsPackImageDataSme(ci, ih, iw, d_kh, d_kw, sh, sw, padding, in, input_is_channels_last, ThreadPool);
// LHS packing data
const float* pad_data = GetOrCreatePadDataSme(ci);
std::unique_ptr<float[]> nhwc_holder;
const float* activation_src = in;
if (!input_is_channels_last) {
nhwc_holder = NChwToNhwc(1, ci, ih, iw, in, 1, 1, false, ThreadPool);
activation_src = nhwc_holder.get();
}
auto lhs_ptrs = GetOrCreateLhsPtrTableSme(ci, ih, iw, d_kh, d_kw, sh, sw, padding, pad_data);

// RHS packing data
const std::byte* rhs_data = packed_rhs ? packed_rhs + g * packed_rhs_group_stride : nullptr;
std::unique_ptr<std::byte[]> rhs_storage;
if (rhs_data == nullptr) {
Expand All @@ -577,40 +543,73 @@ static void ConvolveSme(const size_t co, //channels out
MlasTrySimpleParallel(ThreadPool, static_cast<ptrdiff_t>(dim[0] * dim[1] * dim[2]), [&](ptrdiff_t tid) {
//compute B,M,N index from iteration index
//ptrdiff_t BIdx = tid / (dim[1] * dim[2]);
ptrdiff_t MIdx = (tid % (dim[1] * dim[2])) / dim[2];
ptrdiff_t NIdx = (tid % (dim[1] * dim[2])) % dim[2];
const size_t m_idx = static_cast<size_t>((tid % (dim[1] * dim[2])) / dim[2]);
const size_t n_idx = static_cast<size_t>((tid % (dim[1] * dim[2])) % dim[2]);

// Get rhs tile, B
const size_t rhs_packed_offset =
imatmul_conv.ukernel.get_rhs_packed_offset(NIdx * n_step, d_kh * d_kw, ci);
imatmul_conv.ukernel.get_rhs_packed_offset(n_idx * n_tile_step, d_kh * d_kw, ci);

auto BTile = reinterpret_cast<const void*>(
rhs_data + rhs_packed_offset
);

// Get lhs tile, A
const size_t lhs_packed_offset =
imatmul_conv.ukernel.get_lhs_packed_offset(MIdx * m_step, d_kh * d_kw, ci);

auto ATile = reinterpret_cast<const float*>(
reinterpret_cast<const std::byte*>(lhs.get()) + lhs_packed_offset
);

auto TileSizeM = (MIdx + 1) * m_step > m ? (m - MIdx * m_step) : m_step;
auto TileSizeN = (NIdx + 1) * n_step > co ? (co - NIdx * n_step) : n_step;

// Get result tile, C
auto CTile = &reinterpret_cast<std::byte*>(result)[
MIdx * m_step * co * sizeof(float) +
NIdx * n_step * sizeof(float)];

KLEIDIAI_KERNEL_LOG(imatmul_conv.name
<< " M=" << TileSizeM << " N=" << TileSizeN
<< " k_chunk_count=" << (d_kh * d_kw) << " k_chunk_length=" << ci);
imatmul_conv.ukernel.run_imatmul(
TileSizeM, TileSizeN, d_kh * d_kw, ci, ATile, BTile, CTile, co * sizeof(float),
-std::numeric_limits<float>::max(), std::numeric_limits<float>::max()
);
// Calculate lhs tile chunks
auto tile_size_n = (n_idx + 1) * n_tile_step > co ? (co - n_idx * n_tile_step) : n_tile_step;

// Compute the starting row index of the current M tile
size_t tile_m_start = m_idx * m_tile_step;
// Actual number of rows in this tile (may be smaller for the last tile)
size_t tile_m_size = std::min(m_tile_step, m - tile_m_start);

// Query the packed LHS buffer size for exactly one m_step block.
const size_t bytes_per_m_step = kai_get_lhs_packed_size_lhs_imatmul_pack_x32p2vlx1_x32p_sme(
m_step, d_kh * d_kw, ci);

// Determine how many rows we can pack in one chunk.
size_t m_chunk = std::max<size_t>(m_step, MAX_LHS_CHUNK_BYTES / bytes_per_m_step * m_step);

// Do not exceed the number of rows available in this tile
m_chunk = std::min<size_t>(tile_m_size, m_chunk);

// Compute the exact packed buffer size for m_chunk rows.
const size_t lhs_buffer_bytes = kai_get_lhs_packed_size_lhs_imatmul_pack_x32p2vlx1_x32p_sme(
m_chunk, d_kh * d_kw, ci);

// Allocate a single reusable buffer for LHS packing
auto lhs = std::make_unique<std::byte[]>(lhs_buffer_bytes);

// Interpret the packed LHS buffer as the input A tile
// for the matrix multiplication kernel.
auto ATile = reinterpret_cast<const float*>(lhs.get());

for (size_t m_base = 0; m_base < tile_m_size; m_base += m_chunk) {
// Actual number of rows processed in this iteration.
// The last chunk may be smaller than m_chunk.
const size_t tile_size_m = std::min(m_chunk, tile_m_size - m_base);

// Pack tile_size_m rows of the LHS matrix into a temporary buffer.
kai_run_lhs_imatmul_pack_x32p2vlx1_x32p_sme(tile_size_m,
d_kh * d_kw,
ci,
lhs_ptrs.get() + (tile_m_start + m_base) * d_kh * d_kw,
reinterpret_cast<size_t>(activation_src),
Comment on lines +592 to +596
reinterpret_cast<const void*>(pad_data),
lhs.get());

// Get result tile, C
auto CTile = &reinterpret_cast<std::byte*>(result)[
(tile_m_start + m_base) * co * sizeof(float) +
n_idx * n_tile_step * sizeof(float)];

KLEIDIAI_KERNEL_LOG(imatmul_conv.name
<< " M=" << tile_size_m << " N=" << tile_size_n
<< " k_chunk_count=" << (d_kh * d_kw) << " k_chunk_length=" << ci);

imatmul_conv.ukernel.run_imatmul(tile_size_m, tile_size_n, d_kh * d_kw, ci, ATile, BTile,
CTile, co * sizeof(float), -std::numeric_limits<float>::max(),
std::numeric_limits<float>::max());
}
});

if (need_transpose) {
Expand Down
Loading
Loading