Skip to content

Commit 71f1973

Browse files
authored
Merge pull request #185 from AdaWorldAPI/claude/continue-ndarray-x0Oaw
simd_int_ops, hpc: AMX TDPBUSD arm for gemm_u8_i8 slice surface
2 parents ddf0905 + c18e3fa commit 71f1973

10 files changed

Lines changed: 1036 additions & 26 deletions

File tree

Cargo.toml

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,27 @@ rayon = ["dep:rayon", "std"]
201201
# cfg-dispatch in `simd.rs` remains the production path.
202202
nightly-simd = ["std"]
203203

204+
# Runtime SIMD dispatch — release-binary distribution path. Compiles all
205+
# x86_64 backends into one artifact and selects per-op kernels via
206+
# `LazyLock<fn>` trampolines that read `simd_caps()` on first call. The
207+
# `crate::simd_runtime::*` module becomes reachable under this feature;
208+
# the existing compile-time `crate::simd::*` / `crate::simd_ops::*`
209+
# cascade is unchanged (additive). Use case: shipping one binary that
210+
# adapts across heterogeneous deployment silicon (AVX-512 server +
211+
# AVX2-only laptop) from the same artifact.
212+
#
213+
# Mutually exclusive with `nightly-simd` (the portable-SIMD polyfill
214+
# replaces the architecture-specific intrinsics that the runtime
215+
# trampolines select between; they can't coexist coherently).
216+
#
217+
# Per-call overhead: ~2-3 ns indirect-call through a static fn pointer
218+
# (LazyLock fires once at first call, every subsequent call is a
219+
# pointer deref). Invisible against any SIMD op's actual work.
220+
#
221+
# See `.claude/knowledge/simd-dispatch-architecture.md` § 7.1 / Phase 5
222+
# for the design rationale.
223+
runtime-dispatch = ["std"]
224+
204225
# HPC extras: p64 palette/NARS bridge + fractal manifold.
205226
# (blake3 was previously listed here; it is now part of `std` directly
206227
# because the cognitive substrate modules under hpc/ that import blake3

src/hpc/amx_matmul.rs

Lines changed: 15 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -606,28 +606,10 @@ pub fn matmul_i8_to_i32(
606606

607607
if amx_available() && m % 16 == 0 && n % 16 == 0 && k % 64 == 0 {
608608
// Tier 1 — AMX TDPBUSD tile path: shift LHS i8 → u8 (+128),
609-
// tile-GEMM via int8_tile_gemm_16x16, subtract bias.
609+
// delegate to the shared int8_gemm_amx_tiled helper, subtract
610+
// the sign-shift bias.
610611
let a_u8: Vec<u8> = a_i8.iter().map(|&v| (v as i32 + 128) as u8).collect();
611-
612-
let mut b_tile = vec![0i8; k * 16];
613-
let mut tile_c = vec![0i32; 256];
614-
615-
for j_tile in (0..n).step_by(16) {
616-
for kk in 0..k {
617-
let row = kk * n + j_tile;
618-
b_tile[kk * 16..(kk + 1) * 16]
619-
.copy_from_slice(unsafe { core::slice::from_raw_parts(b_i8.as_ptr().add(row), 16) });
620-
}
621-
for i_tile in (0..m).step_by(16) {
622-
let a_tile = &a_u8[i_tile * k..(i_tile + 16) * k];
623-
tile_c.fill(0);
624-
crate::hpc::int8_tile_gemm::int8_tile_gemm_16x16(a_tile, &b_tile, &mut tile_c, k);
625-
for ii in 0..16 {
626-
let dst_off = (i_tile + ii) * n + j_tile;
627-
c[dst_off..dst_off + 16].copy_from_slice(&tile_c[ii * 16..(ii + 1) * 16]);
628-
}
629-
}
630-
}
612+
crate::hpc::int8_tile_gemm::int8_gemm_amx_tiled(&a_u8, &b_i8, &mut c, m, n, k);
631613
subtract_i8_to_u8_bias(&mut c, &b_i8, m, n, k);
632614
} else if cfg!(target_arch = "x86_64") && std::is_x86_feature_detected!("avx512vnni") {
633615
// Tier 2 — AVX-512 VPDPBUSD zmm: 64 MACs per instruction, no
@@ -639,9 +621,19 @@ pub fn matmul_i8_to_i32(
639621
crate::hpc::int8_tile_gemm::int8_gemm_vpdpbusd_zmm(&a_u8, &b_i8, &mut c, m, n, k);
640622
}
641623
subtract_i8_to_u8_bias(&mut c, &b_i8, m, n, k);
624+
} else if cfg!(target_arch = "x86_64") && std::is_x86_feature_detected!("avxvnni") {
625+
// Tier 3 — AVX-VNNI ymm VPDPBUSD: 32 MACs per instruction.
626+
// Arrow Lake, Meteor Lake U, Alder Lake silicon that has
627+
// AVX-VNNI but dropped AVX-512. Same sign-shift bias trick.
628+
let a_u8: Vec<u8> = a_i8.iter().map(|&v| (v as i32 + 128) as u8).collect();
629+
// SAFETY: runtime feature-detected avxvnni above.
630+
unsafe {
631+
crate::hpc::int8_tile_gemm::int8_gemm_vpdpbusd_ymm(&a_u8, &b_i8, &mut c, m, n, k);
632+
}
633+
subtract_i8_to_u8_bias(&mut c, &b_i8, m, n, k);
642634
} else {
643-
// Tier 3 — Scalar i8×i8 → i32 reference for non-x86 hosts,
644-
// pre-AVX-512 silicon, or shapes that don't satisfy either of
635+
// Tier 4 — Scalar i8×i8 → i32 reference for non-x86 hosts,
636+
// pre-AVX-VNNI silicon, or shapes that don't satisfy any of
645637
// the SIMD tiers' alignment requirements.
646638
for i in 0..m {
647639
for p in 0..k {

src/hpc/int8_tile_gemm.rs

Lines changed: 215 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,97 @@ pub unsafe fn int8_gemm_vpdpbusd_zmm(a_u8: &[u8], b_i8: &[i8], c: &mut [i32], m:
215215
}
216216
}
217217

218+
// ═════════════════════════════════════════════════════════════════════
219+
// VPDPBUSD-ymm AVX-VNNI tier (Arrow Lake / Meteor Lake U / Alder Lake)
220+
// ═════════════════════════════════════════════════════════════════════
221+
222+
/// AVX-VNNI ymm `u8 × i8 → i32` GEMM kernel for arbitrary M × N × K.
223+
///
224+
/// One `_mm256_dpbusd_avx_epi32` instruction: 8 i32 accumulator lanes,
225+
/// each receiving the sum of 4 `u8 × i8` products = **32 MACs per
226+
/// instruction**. Half the throughput-per-instruction of the
227+
/// `_mm512_dpbusd_epi32` zmm version (which does 64 MACs); fires on
228+
/// Arrow Lake / Meteor Lake U / Alder Lake silicon that has AVX-VNNI
229+
/// but NOT AVX-512.
230+
///
231+
/// Same B pre-packing scheme as the zmm version (quad-interleaved, but
232+
/// per 8-wide j-block rather than 16-wide), same K-tail and N-tail handling.
233+
/// Mirrors the `vnni2_dot_u8_i8` shape in `simd_amx.rs` but as a
234+
/// matrix-product instead of single-row dot.
235+
///
236+
/// Output behavior: overwrites `c` (does NOT accumulate), so the
236+
/// caller does not need to zero `c` beforehand.
238+
///
239+
/// # Safety
240+
/// Caller must have feature-detected `avxvnni + avx2` at runtime.
241+
#[cfg(target_arch = "x86_64")]
242+
#[target_feature(enable = "avxvnni,avx2")]
243+
pub unsafe fn int8_gemm_vpdpbusd_ymm(a_u8: &[u8], b_i8: &[i8], c: &mut [i32], m: usize, n: usize, k: usize) {
244+
use core::arch::x86_64::{
245+
__m256i, _mm256_dpbusd_avx_epi32, _mm256_loadu_si256, _mm256_set1_epi32, _mm256_setzero_si256,
246+
_mm256_storeu_si256,
247+
};
248+
249+
let k_quads = k / 4;
250+
let k_tail = k % 4;
251+
252+
// Pre-pack scratch: 8 i32 lanes per k_quad (vs 16 in the zmm
253+
// version). Same per-lane layout: each i32 holds 4 consecutive
254+
// B K-bytes for output column j+lane.
255+
let mut b_col_quads = vec![0i32; k_quads.max(1) * 8];
256+
let mut out_buf = [0i32; 8];
257+
258+
for j_base in (0..n).step_by(8) {
259+
let j_count = 8.min(n - j_base);
260+
261+
for k_quad in 0..k_quads {
262+
let row0 = 4 * k_quad * n;
263+
let row1 = (4 * k_quad + 1) * n;
264+
let row2 = (4 * k_quad + 2) * n;
265+
let row3 = (4 * k_quad + 3) * n;
266+
for jj in 0..j_count {
267+
let b0 = b_i8[row0 + j_base + jj] as u8 as u32;
268+
let b1 = b_i8[row1 + j_base + jj] as u8 as u32;
269+
let b2 = b_i8[row2 + j_base + jj] as u8 as u32;
270+
let b3 = b_i8[row3 + j_base + jj] as u8 as u32;
271+
b_col_quads[k_quad * 8 + jj] = (b0 | (b1 << 8) | (b2 << 16) | (b3 << 24)) as i32;
272+
}
273+
for jj in j_count..8 {
274+
b_col_quads[k_quad * 8 + jj] = 0;
275+
}
276+
}
277+
278+
for i in 0..m {
279+
let mut acc = _mm256_setzero_si256();
280+
let a_row_off = i * k;
281+
for k_quad in 0..k_quads {
282+
let a0 = a_u8[a_row_off + 4 * k_quad] as u32;
283+
let a1 = a_u8[a_row_off + 4 * k_quad + 1] as u32;
284+
let a2 = a_u8[a_row_off + 4 * k_quad + 2] as u32;
285+
let a3 = a_u8[a_row_off + 4 * k_quad + 3] as u32;
286+
let packed_a = a0 | (a1 << 8) | (a2 << 16) | (a3 << 24);
287+
let a_v = _mm256_set1_epi32(packed_a as i32);
288+
let b_v = _mm256_loadu_si256(b_col_quads.as_ptr().add(k_quad * 8) as *const __m256i);
289+
acc = _mm256_dpbusd_avx_epi32(acc, a_v, b_v);
290+
}
291+
_mm256_storeu_si256(out_buf.as_mut_ptr() as *mut __m256i, acc);
292+
293+
if k_tail > 0 {
294+
for kk in (k_quads * 4)..k {
295+
let a_val = a_u8[a_row_off + kk] as i32;
296+
let tail_row = kk * n;
297+
for jj in 0..j_count {
298+
out_buf[jj] += a_val * b_i8[tail_row + j_base + jj] as i32;
299+
}
300+
}
301+
}
302+
303+
let dst_off = i * n + j_base;
304+
c[dst_off..dst_off + j_count].copy_from_slice(&out_buf[..j_count]);
305+
}
306+
}
307+
}
308+
218309
// ═════════════════════════════════════════════════════════════════════
219310
// Scalar fallback (i32 reference)
220311
// ═════════════════════════════════════════════════════════════════════
@@ -231,6 +322,71 @@ fn fallback_path(a_u8: &[u8], b_i8: &[i8], c: &mut [i32], k: usize) {
231322
}
232323
}
233324

325+
// ═════════════════════════════════════════════════════════════════════
326+
// AMX tiled helper — arbitrary 16/16/64-aligned M × N × K via 16×16 tile loop
327+
// ═════════════════════════════════════════════════════════════════════
328+
329+
/// `u8 × i8 → i32` GEMM using AMX `TDPBUSD` for arbitrary M × N × K
330+
/// shapes that satisfy `m % 16 == 0 && n % 16 == 0 && k % 64 == 0`.
331+
///
332+
/// Tile-decomposes the M × N output into 16×16 blocks and calls
333+
/// [`int8_tile_gemm_16x16`] per (i_tile, j_tile). B sub-block extracted
334+
/// into K × 16 scratch once per j-tile, reused across all M i-tiles —
335+
/// amortizes the column gather cost.
336+
///
337+
/// **Overwrite semantics**: `c` is written, not accumulated. Caller
338+
/// does NOT need to zero `c` beforehand. (The underlying
339+
/// `int8_tile_gemm_16x16` accumulates into its tile buffer, but we
340+
/// zero the tile buffer before each call so the per-tile write to `c`
341+
/// is pure overwrite.)
342+
///
343+
/// # Panics
344+
/// Panics if `a_u8`, `b_i8`, or `c` are too small for the requested
345+
/// `(m, n, k)`, mirroring the boundary contract from `gemm_u8_i8`. Also
346+
/// panics in debug builds when AMX isn't OS-enabled or when the shape
347+
/// alignment constraints aren't met (production builds skip those for
348+
/// performance — callers must runtime-check
349+
/// `crate::hpc::amx_matmul::amx_available()` and the 16/16/64
350+
/// alignment themselves).
351+
pub fn int8_gemm_amx_tiled(a_u8: &[u8], b_i8: &[i8], c: &mut [i32], m: usize, n: usize, k: usize) {
352+
// Length assertions (codex P1 from PR #185 — the function reads
353+
// `b_i8` via a 16-wide window per (kk, j_tile) iteration and a_u8
354+
// via a 16-row slice per i_tile, so mismatched shapes would
355+
// trigger out-of-bounds reads without these gates).
356+
assert!(a_u8.len() >= m * k, "int8_gemm_amx_tiled: a_u8.len()={} < m*k={}", a_u8.len(), m * k);
357+
assert!(b_i8.len() >= k * n, "int8_gemm_amx_tiled: b_i8.len()={} < k*n={}", b_i8.len(), k * n);
358+
assert!(c.len() >= m * n, "int8_gemm_amx_tiled: c.len()={} < m*n={}", c.len(), m * n);
359+
360+
debug_assert!(crate::hpc::amx_matmul::amx_available());
361+
debug_assert_eq!(m % 16, 0, "int8_gemm_amx_tiled: M must be multiple of 16");
362+
debug_assert_eq!(n % 16, 0, "int8_gemm_amx_tiled: N must be multiple of 16");
363+
debug_assert_eq!(k % 64, 0, "int8_gemm_amx_tiled: K must be multiple of 64");
364+
365+
let mut b_tile = vec![0i8; k * 16];
366+
let mut tile_c = vec![0i32; 256];
367+
368+
for j_tile in (0..n).step_by(16) {
369+
// Pack B[0..k, j_tile..j_tile+16] into 16-wide K-rows
370+
// (contiguous memory for int8_tile_gemm_16x16's input shape).
371+
// Safe slicing — the row..row+16 range is bounded by
372+
// `b_i8.len() >= k * n` asserted at function entry.
373+
for kk in 0..k {
374+
let row = kk * n + j_tile;
375+
b_tile[kk * 16..(kk + 1) * 16].copy_from_slice(&b_i8[row..row + 16]);
376+
}
377+
for i_tile in (0..m).step_by(16) {
378+
let a_tile = &a_u8[i_tile * k..(i_tile + 16) * k];
379+
tile_c.fill(0);
380+
int8_tile_gemm_16x16(a_tile, &b_tile, &mut tile_c, k);
381+
// Write tile_c (16 × 16, row-major) into c (M × N, row-major).
382+
for ii in 0..16 {
383+
let dst_off = (i_tile + ii) * n + j_tile;
384+
c[dst_off..dst_off + 16].copy_from_slice(&tile_c[ii * 16..(ii + 1) * 16]);
385+
}
386+
}
387+
}
388+
}
389+
234390
// ═════════════════════════════════════════════════════════════════════
235391
// Tests
236392
// ═════════════════════════════════════════════════════════════════════
@@ -370,6 +526,65 @@ mod tests {
370526
}
371527
}
372528

529+
/// Codex P1 regression on PR #185: `int8_gemm_amx_tiled` is a
530+
/// safe public function — mismatched (m, n, k) vs slice lengths
531+
/// must panic at the function boundary, not trigger UB inside
532+
/// the unsafe slice/pointer arithmetic in the inner loop. This
533+
/// test passes deliberately-undersized buffers and expects a
534+
/// panic (which `#[should_panic]` catches).
535+
#[test]
536+
#[should_panic(expected = "b_i8.len()")]
537+
fn amx_tiled_panics_on_undersized_b() {
538+
let m = 16;
539+
let n = 32;
540+
let k = 64;
541+
let a = vec![0u8; m * k];
542+
let b = vec![0i8; k * (n - 16)]; // one full j_tile (16 columns) short of what's claimed
543+
let mut c = vec![0i32; m * n];
544+
// Even on non-AMX hosts the assertion fires before reaching
545+
// the (debug-asserted) amx_available() check.
546+
int8_gemm_amx_tiled(&a, &b, &mut c, m, n, k);
547+
}
548+
549+
/// Direct test for the VPDPBUSD-ymm arm (AVX-VNNI tier of
550+
/// `matmul_i8_to_i32`). Same shape / bit-exactness contract as
551+
/// the zmm version's test, just on the narrower 8-wide kernel.
552+
#[cfg(target_arch = "x86_64")]
553+
#[test]
554+
fn vpdpbusd_ymm_matches_scalar() {
555+
if !std::is_x86_feature_detected!("avxvnni") {
556+
eprintln!("avxvnni not detected; skipping");
557+
return;
558+
}
559+
560+
fn ref_gemm(a: &[u8], b: &[i8], m: usize, n: usize, k: usize) -> Vec<i32> {
561+
let mut c = vec![0i32; m * n];
562+
for i in 0..m {
563+
for kk in 0..k {
564+
let av = a[i * k + kk] as i32;
565+
for j in 0..n {
566+
c[i * n + j] += av * b[kk * n + j] as i32;
567+
}
568+
}
569+
}
570+
c
571+
}
572+
573+
// Sweep shapes spanning 8-aligned, K-tail (k % 4), N-tail
574+
// (n % 8), and small shapes to exercise every code path.
575+
for (m, n, k) in [(16, 8, 64), (3, 5, 7), (17, 33, 100), (1, 17, 12), (8, 8, 4)] {
576+
let a: Vec<u8> = (0..m * k).map(|i| ((i * 31 + 7) % 256) as u8).collect();
577+
let b: Vec<i8> = (0..k * n)
578+
.map(|i| ((i * 17 + 3) % 256) as u8 as i8)
579+
.collect();
580+
let expected = ref_gemm(&a, &b, m, n, k);
581+
let mut got = vec![0i32; m * n];
582+
// SAFETY: avxvnni confirmed at the top of the test.
583+
unsafe { int8_gemm_vpdpbusd_ymm(&a, &b, &mut got, m, n, k) };
584+
assert_eq!(got, expected, "VPDPBUSD-ymm mismatch at (M={}, N={}, K={})", m, n, k);
585+
}
586+
}
587+
373588
#[test]
374589
fn vnni_pack_i8_roundtrip() {
375590
// Pack then verify the VNNI layout matches the spec:

src/lib.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -313,6 +313,15 @@ pub mod simd_ops;
313313
#[allow(clippy::all, missing_docs, dead_code, unused_variables, unused_imports)]
314314
pub mod simd_half;
315315

316+
/// Runtime SIMD dispatch — release-binary distribution path.
317+
///
318+
/// Gated under `--features runtime-dispatch`. Mutually exclusive with
319+
/// `nightly-simd` (the cfg in `simd_runtime/mod.rs` enforces this with
320+
/// a `compile_error!`). See `.claude/knowledge/simd-dispatch-architecture.md`
321+
/// § 7.1 / Phase 5 for the design.
322+
#[cfg(feature = "runtime-dispatch")]
323+
pub mod simd_runtime;
324+
316325
/// Pluggable linear algebra backends (native SIMD, MKL, OpenBLAS).
317326
#[cfg(feature = "std")]
318327
pub mod backend;

src/simd_int_ops.rs

Lines changed: 46 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -255,9 +255,31 @@ pub fn gemm_u8_i8(a: &[u8], b: &[i8], c: &mut [i32], m: usize, n: usize, k: usiz
255255
assert!(b.len() >= k * n, "gemm_u8_i8: b.len()={} < k*n={}", b.len(), k * n);
256256
assert!(c.len() >= m * n, "gemm_u8_i8: c.len()={} < m*n={}", c.len(), m * n);
257257

258-
// Compile-time dispatch chain. Exactly one arm survives per build;
259-
// the others are stripped by `#[cfg]` so the compiler emits a direct
260-
// call to the chosen kernel with no runtime branch.
258+
// Tier 0 — runtime AMX check. AMX is a different feature class than
259+
// the rest of the dispatch chain: it requires CPUID + XCR0 + a Linux
260+
// `prctl(ARCH_REQ_XCOMP_PERM, 18)` to be granted, none of which fit
261+
// a `target_feature` compile-time gate. The check is one CPUID +
262+
// one XGETBV + one prctl (idempotent, cached after first call). On
263+
// aligned shapes (16/16/64) this dispatches to TDPBUSD via the
264+
// shared `int8_gemm_amx_tiled` helper — 16 384 MACs per instruction
265+
// vs VPDPBUSD-zmm's 64. Since `gemm_u8_i8` is u8×i8 natively (no
266+
// sign-shift bias needed), the AMX path is a direct call with no
267+
// bias correction — simpler than `matmul_i8_to_i32`'s i8×i8 path.
268+
#[cfg(target_arch = "x86_64")]
269+
{
270+
if crate::hpc::amx_matmul::amx_available()
271+
&& m.is_multiple_of(16)
272+
&& n.is_multiple_of(16)
273+
&& k.is_multiple_of(64)
274+
{
275+
crate::hpc::int8_tile_gemm::int8_gemm_amx_tiled(a, b, c, m, n, k);
276+
return;
277+
}
278+
}
279+
280+
// Compile-time dispatch chain (tiers 1-3). Exactly one arm survives
281+
// per build; the others are stripped by `#[cfg]` so the compiler
282+
// emits a direct call to the chosen kernel with no runtime branch.
261283

262284
#[cfg(all(target_arch = "x86_64", target_feature = "avx512vnni"))]
263285
{
@@ -731,4 +753,25 @@ mod tests {
731753
);
732754
}
733755
}
756+
757+
/// Exercises the AMX dispatch tier added on top of `gemm_u8_i8`'s
758+
/// compile-time cascade. On AMX-enabled silicon (Sapphire Rapids+
759+
/// with the right OS prctl), 16/16/64-aligned shapes go through
760+
/// TDPBUSD via `int8_gemm_amx_tiled`. Anywhere else this falls back
761+
/// to the compile-time cascade — the assertion still holds because
762+
/// the scalar reference is exact integer arithmetic.
763+
#[test]
764+
fn gemm_u8_i8_amx_aligned_32x32x128() {
765+
let m = 32; // 2 × 16-wide M-tiles
766+
let n = 32; // 2 × 16-wide N-tiles
767+
let k = 128; // 2 × 64-wide K-blocks per tile
768+
let a: Vec<u8> = (0..m * k).map(|i| ((i * 13 + 7) % 256) as u8).collect();
769+
let b: Vec<i8> = (0..k * n)
770+
.map(|i| ((i * 19 + 11) % 256) as u8 as i8)
771+
.collect();
772+
let expected = ref_gemm_u8_i8(&a, &b, m, n, k);
773+
let mut c = vec![0i32; m * n];
774+
gemm_u8_i8(&a, &b, &mut c, m, n, k);
775+
assert_eq!(c, expected, "gemm_u8_i8 AMX path mismatch");
776+
}
734777
}

0 commit comments

Comments
 (0)