Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 29 additions & 14 deletions onnxruntime/core/mlas/lib/qkv_quant_kernel_avx2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,16 +42,16 @@ DequantInt4x8(const uint8_t* src, size_t col, bool per_channel, const float* sca

// Load 4 packed bytes safely without strict-aliasing / alignment UB.
// Compilers optimize memcpy of 4 bytes to a single mov instruction.
int raw_bytes;
uint32_t raw_bytes;
std::memcpy(&raw_bytes, base, sizeof(raw_bytes));
__m128i packed = _mm_cvtsi32_si128(raw_bytes);
__m128i packed = _mm_cvtsi32_si128(static_cast<int>(raw_bytes));

// Low nibbles (even columns): AND with 0x0F
__m128i lo_mask = _mm_set1_epi8(0x0F);
__m128i lo = _mm_and_si128(packed, lo_mask);

// High nibbles (odd columns): shift right 4 using 32-bit granularity
// to prevent bit bleeding across 16-bit lane boundaries, then mask.
// High nibbles (odd columns): shift right by 4 within 32-bit lanes, then mask.
// Any cross-byte bits from the shift land in the upper nibble and are discarded by the mask.
__m128i hi = _mm_and_si128(_mm_srli_epi32(packed, 4), lo_mask);

// Interleave low and high nibbles: [lo0,hi0, lo1,hi1, lo2,hi2, lo3,hi3]
Expand Down Expand Up @@ -126,27 +126,26 @@ FusedDotInt8(
acc0 = _mm256_fmadd_ps(a0, bf0, acc0);
}
} else {
__m256 scale_vec = _mm256_broadcast_ss(scales);
// Per-tensor: defer scale multiplication until after accumulation.
// sum(a[k] * b[k] * scale) = scale * sum(a[k] * b[k])
// This saves one vmulps per 8 elements in the hot loop.
for (; k < vec_end; k += 16) {
__m128i raw0 = _mm_loadl_epi64(reinterpret_cast<const __m128i*>(b_row + k));
__m256i i32_0 = _mm256_cvtepi8_epi32(raw0);
__m256 bf0 = _mm256_cvtepi32_ps(i32_0);
bf0 = _mm256_mul_ps(bf0, scale_vec);
__m256 a0 = _mm256_loadu_ps(a_row + k);
acc0 = _mm256_fmadd_ps(a0, bf0, acc0);

__m128i raw1 = _mm_loadl_epi64(reinterpret_cast<const __m128i*>(b_row + k + 8));
__m256i i32_1 = _mm256_cvtepi8_epi32(raw1);
__m256 bf1 = _mm256_cvtepi32_ps(i32_1);
bf1 = _mm256_mul_ps(bf1, scale_vec);
__m256 a1 = _mm256_loadu_ps(a_row + k + 8);
acc1 = _mm256_fmadd_ps(a1, bf1, acc1);
}
for (; k + 8 <= K; k += 8) {
__m128i raw0 = _mm_loadl_epi64(reinterpret_cast<const __m128i*>(b_row + k));
__m256i i32_0 = _mm256_cvtepi8_epi32(raw0);
__m256 bf0 = _mm256_cvtepi32_ps(i32_0);
bf0 = _mm256_mul_ps(bf0, scale_vec);
__m256 a0 = _mm256_loadu_ps(a_row + k);
acc0 = _mm256_fmadd_ps(a0, bf0, acc0);
}
Expand All @@ -161,9 +160,15 @@ FusedDotInt8(
float dot = _mm_cvtss_f32(sum4);

// Scalar tail
for (; k < K; ++k) {
float sc = per_channel ? scales[k] : scales[0];
dot += a_row[k] * static_cast<float>(b_row[k]) * sc;
if (per_channel) {
for (; k < K; ++k) {
dot += a_row[k] * static_cast<float>(b_row[k]) * scales[k];
}
} else {
for (; k < K; ++k) {
dot += a_row[k] * static_cast<float>(b_row[k]);
}
dot *= scales[0];
}
return dot;
}
Expand Down Expand Up @@ -326,7 +331,7 @@ SVGemm_Avx2(
}
}
} else {
__m256 scale_vec = _mm256_broadcast_ss(Scales);
// Per-tensor: accumulate unscaled dot products, then scale the output row once.
for (size_t k = 0; k < K; ++k) {
const int8_t* b_row = reinterpret_cast<const int8_t*>(B_bytes + k * row_bytes);
const float a_val = a_row[k];
Expand All @@ -337,15 +342,25 @@ SVGemm_Avx2(
__m128i raw = _mm_loadl_epi64(reinterpret_cast<const __m128i*>(b_row + n));
__m256i i32 = _mm256_cvtepi8_epi32(raw);
__m256 bf = _mm256_cvtepi32_ps(i32);
bf = _mm256_mul_ps(bf, scale_vec);
__m256 c_vec = _mm256_loadu_ps(c_row + n);
c_vec = _mm256_fmadd_ps(a_broadcast, bf, c_vec);
_mm256_storeu_ps(c_row + n, c_vec);
}
for (; n < N; ++n) {
c_row[n] += a_val * static_cast<float>(b_row[n]) * Scales[0];
c_row[n] += a_val * static_cast<float>(b_row[n]);
}
}

__m256 scale_vec = _mm256_broadcast_ss(Scales);
n = 0;
for (; n < vec_end_n; n += 8) {
__m256 c_vec = _mm256_loadu_ps(c_row + n);
c_vec = _mm256_mul_ps(c_vec, scale_vec);
_mm256_storeu_ps(c_row + n, c_vec);
}
for (; n < N; ++n) {
c_row[n] *= Scales[0];
}
}
} else {
// INT4 fused path
Expand Down
58 changes: 38 additions & 20 deletions onnxruntime/core/mlas/lib/qkv_quant_kernel_avx512vnni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,12 +83,13 @@ QuantizeRowToU8(const float* src, uint8_t* dst, size_t len)
i = 0;
for (; i < vec_end; i += 16) {
__m512 v = _mm512_loadu_ps(src + i);
// q = round(v * inv_scale) + 128, clamped to [0, 255]
// q = (v * inv_scale) + 128, clamped to [0, 255]
__m512 scaled = _mm512_fmadd_ps(v, inv_scale_vec, zp_vec);
scaled = _mm512_roundscale_ps(scaled, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
scaled = _mm512_max_ps(scaled, min_val);
scaled = _mm512_min_ps(scaled, max_clamp);
__m512i qi = _mm512_cvtps_epi32(scaled);
// Round-to-nearest-even and convert to int32 in a single instruction
// (AVX-512 embedded rounding eliminates a separate vrndscaleps).
__m512i qi = _mm512_cvt_roundps_epi32(scaled, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
Comment on lines 85 to +92
// Pack 16 int32 -> 16 uint8
__m128i packed = _mm512_cvtepi32_epi8(qi);
_mm_storeu_si128(reinterpret_cast<__m128i*>(dst + i), packed);
Expand Down Expand Up @@ -169,9 +170,11 @@ VnniDotInt8PerTensor(

// Correction: dpbusd computed sum(a_u8 * b_s8).
// We want sum((a_u8 - 128) * b_s8) = sum(a_u8 * b_s8) - 128 * sum(b_s8)
float corrected = static_cast<float>(dot_i32) - 128.0f * static_cast<float>(b_sum_i32);
// Perform correction in int32 to preserve precision (avoids float rounding
// when |dot_i32| or |128*b_sum_i32| exceed 2^24).
int32_t corrected = dot_i32 - (128 * b_sum_i32);

return corrected * scale_a * scale_b;
return static_cast<float>(corrected) * scale_a * scale_b;
}

//
Expand Down Expand Up @@ -221,27 +224,26 @@ FusedDotInt8_Avx512(
acc0 = _mm512_fmadd_ps(a0, bf0, acc0);
}
} else {
__m512 scale_vec = _mm512_set1_ps(scales[0]);
// Per-tensor: defer scale multiplication until after accumulation.
// sum(a[k] * b[k] * scale) = scale * sum(a[k] * b[k])
// This saves one vmulps per 16 elements in the hot loop.
for (; k < vec_end; k += 32) {
__m128i raw0 = _mm_loadu_si128(reinterpret_cast<const __m128i*>(b_row + k));
__m512i i32_0 = _mm512_cvtepi8_epi32(raw0);
__m512 bf0 = _mm512_cvtepi32_ps(i32_0);
bf0 = _mm512_mul_ps(bf0, scale_vec);
__m512 a0 = _mm512_loadu_ps(a_row + k);
acc0 = _mm512_fmadd_ps(a0, bf0, acc0);

__m128i raw1 = _mm_loadu_si128(reinterpret_cast<const __m128i*>(b_row + k + 16));
__m512i i32_1 = _mm512_cvtepi8_epi32(raw1);
__m512 bf1 = _mm512_cvtepi32_ps(i32_1);
bf1 = _mm512_mul_ps(bf1, scale_vec);
__m512 a1 = _mm512_loadu_ps(a_row + k + 16);
acc1 = _mm512_fmadd_ps(a1, bf1, acc1);
}
for (; k + 16 <= K; k += 16) {
__m128i raw0 = _mm_loadu_si128(reinterpret_cast<const __m128i*>(b_row + k));
__m512i i32_0 = _mm512_cvtepi8_epi32(raw0);
__m512 bf0 = _mm512_cvtepi32_ps(i32_0);
bf0 = _mm512_mul_ps(bf0, scale_vec);
__m512 a0 = _mm512_loadu_ps(a_row + k);
acc0 = _mm512_fmadd_ps(a0, bf0, acc0);
}
Expand All @@ -251,9 +253,15 @@ FusedDotInt8_Avx512(
float dot = _mm512_reduce_add_ps(acc0);

// Scalar tail
for (; k < K; ++k) {
float sc = per_channel ? scales[k] : scales[0];
dot += a_row[k] * static_cast<float>(b_row[k]) * sc;
if (per_channel) {
for (; k < K; ++k) {
dot += a_row[k] * static_cast<float>(b_row[k]) * scales[k];
}
} else {
for (; k < K; ++k) {
dot += a_row[k] * static_cast<float>(b_row[k]);
}
dot *= scales[0];
}
return dot;
}
Expand Down Expand Up @@ -402,11 +410,11 @@ VnniMultiDot4Int8PerTensor(
bs[3] += static_cast<int32_t>(b3[k]);
}

const float zp = 128.0f;
out[0] = (static_cast<float>(dot[0]) - zp * static_cast<float>(bs[0])) * combined_scale;
out[1] = (static_cast<float>(dot[1]) - zp * static_cast<float>(bs[1])) * combined_scale;
out[2] = (static_cast<float>(dot[2]) - zp * static_cast<float>(bs[2])) * combined_scale;
out[3] = (static_cast<float>(dot[3]) - zp * static_cast<float>(bs[3])) * combined_scale;
// Zero-point correction in int32 for precision (see VnniDotInt8PerTensor).
out[0] = static_cast<float>(dot[0] - 128 * bs[0]) * combined_scale;
out[1] = static_cast<float>(dot[1] - 128 * bs[1]) * combined_scale;
out[2] = static_cast<float>(dot[2] - 128 * bs[2]) * combined_scale;
out[3] = static_cast<float>(dot[3] - 128 * bs[3]) * combined_scale;
}

// ============================================================================
Expand Down Expand Up @@ -569,7 +577,7 @@ SVGemm_Avx512Vnni(
}
}
} else {
__m512 scale_vec = _mm512_set1_ps(Scales[0]);
// Per-tensor: accumulate unscaled dot products, then scale the output row once.
for (size_t k = 0; k < K; ++k) {
const int8_t* b_row = reinterpret_cast<const int8_t*>(B_bytes + k * row_bytes);
const float a_val = a_row[k];
Expand All @@ -580,15 +588,25 @@ SVGemm_Avx512Vnni(
__m128i raw = _mm_loadu_si128(reinterpret_cast<const __m128i*>(b_row + n));
__m512i i32 = _mm512_cvtepi8_epi32(raw);
__m512 bf = _mm512_cvtepi32_ps(i32);
bf = _mm512_mul_ps(bf, scale_vec);
__m512 c_vec = _mm512_loadu_ps(c_row + n);
c_vec = _mm512_fmadd_ps(a_broadcast, bf, c_vec);
_mm512_storeu_ps(c_row + n, c_vec);
}
for (; n < N; ++n) {
c_row[n] += a_val * static_cast<float>(b_row[n]) * Scales[0];
c_row[n] += a_val * static_cast<float>(b_row[n]);
}
}

__m512 scale_vec = _mm512_set1_ps(Scales[0]);
n = 0;
for (; n < vec_end_n; n += 16) {
__m512 c_vec = _mm512_loadu_ps(c_row + n);
c_vec = _mm512_mul_ps(c_vec, scale_vec);
_mm512_storeu_ps(c_row + n, c_vec);
}
for (; n < N; ++n) {
c_row[n] *= Scales[0];
}
}
} else {
// INT4 path: 512-bit wide
Expand Down
57 changes: 43 additions & 14 deletions onnxruntime/core/mlas/lib/qkv_quant_kernel_neon.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,13 @@ using namespace MlasKVQuantInternal;
namespace {

//
// Dequantize 8 INT8 values starting at `col` and scale them.
// Dequantize 8 INT8 values starting at `col`.
// Per-channel rows are always scaled. Per-tensor rows may defer scaling.
// Produces two float32x4_t (8 floats total) stored into dst.
//
inline void
DequantInt8x8_Neon(const int8_t* src, size_t col, bool per_channel,
const float* scales, float* dst)
const float* scales, bool apply_per_tensor_scale, float* dst)
{
// Load 8 int8 values
int8x8_t raw = vld1_s8(src + col);
Expand All @@ -52,7 +53,7 @@ DequantInt8x8_Neon(const int8_t* src, size_t col, bool per_channel,
float32x4_t sc_hi = vld1q_f32(scales + col + 4);
f_lo = vmulq_f32(f_lo, sc_lo);
f_hi = vmulq_f32(f_hi, sc_hi);
} else {
} else if (apply_per_tensor_scale) {
float32x4_t sc = vdupq_n_f32(scales[0]);
f_lo = vmulq_f32(f_lo, sc);
f_hi = vmulq_f32(f_hi, sc);
Expand All @@ -64,10 +65,11 @@ DequantInt8x8_Neon(const int8_t* src, size_t col, bool per_channel,

//
// Dequantize 8 INT4 values (4 packed bytes) starting at even column `col`.
// Per-channel rows are always scaled. Per-tensor rows may defer scaling.
//
inline void
DequantInt4x8_Neon(const uint8_t* src, size_t col, bool per_channel,
const float* scales, float* dst)
const float* scales, bool apply_per_tensor_scale, float* dst)
{
const uint8_t* base = src + col / 2;

Expand All @@ -94,7 +96,7 @@ DequantInt4x8_Neon(const uint8_t* src, size_t col, bool per_channel,
float32x4_t sc_hi = vld1q_f32(scales + col + 4);
f_lo = vmulq_f32(f_lo, sc_lo);
f_hi = vmulq_f32(f_hi, sc_hi);
} else {
} else if (apply_per_tensor_scale) {
float32x4_t sc = vdupq_n_f32(scales[0]);
f_lo = vmulq_f32(f_lo, sc);
f_hi = vmulq_f32(f_hi, sc);
Expand All @@ -106,14 +108,17 @@ DequantInt4x8_Neon(const uint8_t* src, size_t col, bool per_channel,

//
// Dequantize one row of length `cols` from packed quantized buffer into `dst`.
// `apply_per_tensor_scale=false` leaves per-tensor rows unscaled so callers can
// factor the single scale out of an outer accumulation loop.
//
void
DequantRow_Neon(
const void* src_raw,
float* dst,
size_t cols,
MLAS_KV_QUANT_TYPE qt,
const float* scales)
const float* scales,
bool apply_per_tensor_scale)
{
const bool int4 = IsInt4Mode(qt);
const bool per_channel = IsPerChannelMode(qt);
Expand All @@ -124,22 +129,32 @@ DequantRow_Neon(
if (!int4) {
const auto* src = static_cast<const int8_t*>(src_raw);
for (; c < vec_end; c += 8) {
DequantInt8x8_Neon(src, c, per_channel, scales, dst + c);
DequantInt8x8_Neon(src, c, per_channel, scales, apply_per_tensor_scale, dst + c);
}
for (; c < cols; ++c) {
float sc = per_channel ? scales[c] : scales[0];
dst[c] = static_cast<float>(src[c]) * sc;
if (per_channel) {
dst[c] = static_cast<float>(src[c]) * scales[c];
} else if (apply_per_tensor_scale) {
dst[c] = static_cast<float>(src[c]) * scales[0];
} else {
dst[c] = static_cast<float>(src[c]);
}
}
} else {
const auto* src = static_cast<const uint8_t*>(src_raw);
for (; c < vec_end; c += 8) {
DequantInt4x8_Neon(src, c, per_channel, scales, dst + c);
DequantInt4x8_Neon(src, c, per_channel, scales, apply_per_tensor_scale, dst + c);
}
for (; c < cols; ++c) {
uint8_t packed = src[c / 2];
int nibble = (c & 1) == 0 ? (packed & 0x0F) : ((packed >> 4) & 0x0F);
float sc = per_channel ? scales[c] : scales[0];
dst[c] = static_cast<float>(nibble - kInt4Bias) * sc;
if (per_channel) {
dst[c] = static_cast<float>(nibble - kInt4Bias) * scales[c];
} else if (apply_per_tensor_scale) {
dst[c] = static_cast<float>(nibble - kInt4Bias) * scales[0];
} else {
dst[c] = static_cast<float>(nibble - kInt4Bias);
}
}
}
}
Expand Down Expand Up @@ -174,7 +189,7 @@ QKGemm_Neon(

for (size_t n = 0; n < N; ++n) {
const uint8_t* b_row = B_bytes + n * row_bytes;
DequantRow_Neon(b_row, b_buf, K, QuantType, Scales);
DequantRow_Neon(b_row, b_buf, K, QuantType, Scales, true);

for (size_t m = 0; m < M; ++m) {
const float* a_row = A + m * lda;
Expand Down Expand Up @@ -246,6 +261,7 @@ SVGemm_Neon(
{
const size_t row_bytes = MlasKVQuantPackedRowBytes(QuantType, N);
const auto* B_bytes = static_cast<const uint8_t*>(B);
const bool per_channel = IsPerChannelMode(QuantType);

float b_stack[256];
float* b_buf = b_stack;
Expand All @@ -272,7 +288,7 @@ SVGemm_Neon(

for (size_t k = 0; k < K; ++k) {
const uint8_t* b_row_packed = B_bytes + k * row_bytes;
DequantRow_Neon(b_row_packed, b_buf, N, QuantType, Scales);
DequantRow_Neon(b_row_packed, b_buf, N, QuantType, Scales, per_channel);

const float a_val = a_row[k];
float32x4_t a_broadcast = vdupq_n_f32(a_val);
Expand All @@ -288,6 +304,19 @@ SVGemm_Neon(
c_row[n] += a_val * b_buf[n];
}
}

if (!per_channel) {
const float32x4_t scale_vec = vdupq_n_f32(Scales[0]);
n = 0;
for (; n < vec_end_n; n += 4) {
float32x4_t c_vec = vld1q_f32(c_row + n);
c_vec = vmulq_f32(c_vec, scale_vec);
vst1q_f32(c_row + n, c_vec);
}
for (; n < N; ++n) {
c_row[n] *= Scales[0];
}
}
}
}

Expand Down
Loading