@@ -215,6 +215,97 @@ pub unsafe fn int8_gemm_vpdpbusd_zmm(a_u8: &[u8], b_i8: &[i8], c: &mut [i32], m:
215215 }
216216}
217217
218+ // ═════════════════════════════════════════════════════════════════════
219+ // VPDPBUSD-ymm AVX-VNNI tier (Arrow Lake / Meteor Lake U / Alder Lake)
220+ // ═════════════════════════════════════════════════════════════════════
221+
222+ /// AVX-VNNI ymm `u8 × i8 → i32` GEMM kernel for arbitrary M × N × K.
223+ ///
224+ /// One `_mm256_dpbusd_avx_epi32` instruction: 8 i32 accumulator lanes,
225+ /// each receiving the sum of 4 `u8 × i8` products = **32 MACs per
226+ /// instruction**. Half the throughput-per-instruction of the
227+ /// `_mm512_dpbusd_epi32` zmm version (which does 64 MACs); fires on
228+ /// Arrow Lake / Meteor Lake U / Alder Lake silicon that has AVX-VNNI
229+ /// but NOT AVX-512.
230+ ///
231+ /// Same B pre-packing scheme as the zmm version (quad-interleaved per
232+ /// 8-wide j-block), same K-tail and N-tail handling, just narrower.
233+ /// Mirrors the `vnni2_dot_u8_i8` shape in `simd_amx.rs` but as a
234+ /// matrix-product instead of single-row dot.
235+ ///
236+ /// Output behavior: overwrites `c` (does NOT accumulate). Caller's
237+ /// responsibility to zero `c` first if needed.
238+ ///
239+ /// # Safety
240+ /// Caller must have feature-detected `avxvnni + avx2` at runtime.
241+ #[ cfg( target_arch = "x86_64" ) ]
242+ #[ target_feature( enable = "avxvnni,avx2" ) ]
243+ pub unsafe fn int8_gemm_vpdpbusd_ymm ( a_u8 : & [ u8 ] , b_i8 : & [ i8 ] , c : & mut [ i32 ] , m : usize , n : usize , k : usize ) {
244+ use core:: arch:: x86_64:: {
245+ __m256i, _mm256_dpbusd_avx_epi32, _mm256_loadu_si256, _mm256_set1_epi32, _mm256_setzero_si256,
246+ _mm256_storeu_si256,
247+ } ;
248+
249+ let k_quads = k / 4 ;
250+ let k_tail = k % 4 ;
251+
252+ // Pre-pack scratch: 8 i32 lanes per k_quad (vs 16 in the zmm
253+ // version). Same per-lane layout: each i32 holds 4 consecutive
254+ // B K-bytes for output column j+lane.
255+ let mut b_col_quads = vec ! [ 0i32 ; k_quads. max( 1 ) * 8 ] ;
256+ let mut out_buf = [ 0i32 ; 8 ] ;
257+
258+ for j_base in ( 0 ..n) . step_by ( 8 ) {
259+ let j_count = 8 . min ( n - j_base) ;
260+
261+ for k_quad in 0 ..k_quads {
262+ let row0 = 4 * k_quad * n;
263+ let row1 = ( 4 * k_quad + 1 ) * n;
264+ let row2 = ( 4 * k_quad + 2 ) * n;
265+ let row3 = ( 4 * k_quad + 3 ) * n;
266+ for jj in 0 ..j_count {
267+ let b0 = b_i8[ row0 + j_base + jj] as u8 as u32 ;
268+ let b1 = b_i8[ row1 + j_base + jj] as u8 as u32 ;
269+ let b2 = b_i8[ row2 + j_base + jj] as u8 as u32 ;
270+ let b3 = b_i8[ row3 + j_base + jj] as u8 as u32 ;
271+ b_col_quads[ k_quad * 8 + jj] = ( b0 | ( b1 << 8 ) | ( b2 << 16 ) | ( b3 << 24 ) ) as i32 ;
272+ }
273+ for jj in j_count..8 {
274+ b_col_quads[ k_quad * 8 + jj] = 0 ;
275+ }
276+ }
277+
278+ for i in 0 ..m {
279+ let mut acc = _mm256_setzero_si256 ( ) ;
280+ let a_row_off = i * k;
281+ for k_quad in 0 ..k_quads {
282+ let a0 = a_u8[ a_row_off + 4 * k_quad] as u32 ;
283+ let a1 = a_u8[ a_row_off + 4 * k_quad + 1 ] as u32 ;
284+ let a2 = a_u8[ a_row_off + 4 * k_quad + 2 ] as u32 ;
285+ let a3 = a_u8[ a_row_off + 4 * k_quad + 3 ] as u32 ;
286+ let packed_a = a0 | ( a1 << 8 ) | ( a2 << 16 ) | ( a3 << 24 ) ;
287+ let a_v = _mm256_set1_epi32 ( packed_a as i32 ) ;
288+ let b_v = _mm256_loadu_si256 ( b_col_quads. as_ptr ( ) . add ( k_quad * 8 ) as * const __m256i ) ;
289+ acc = _mm256_dpbusd_avx_epi32 ( acc, a_v, b_v) ;
290+ }
291+ _mm256_storeu_si256 ( out_buf. as_mut_ptr ( ) as * mut __m256i , acc) ;
292+
293+ if k_tail > 0 {
294+ for kk in ( k_quads * 4 ) ..k {
295+ let a_val = a_u8[ a_row_off + kk] as i32 ;
296+ let tail_row = kk * n;
297+ for jj in 0 ..j_count {
298+ out_buf[ jj] += a_val * b_i8[ tail_row + j_base + jj] as i32 ;
299+ }
300+ }
301+ }
302+
303+ let dst_off = i * n + j_base;
304+ c[ dst_off..dst_off + j_count] . copy_from_slice ( & out_buf[ ..j_count] ) ;
305+ }
306+ }
307+ }
308+
218309// ═════════════════════════════════════════════════════════════════════
219310// Scalar fallback (i32 reference)
220311// ═════════════════════════════════════════════════════════════════════
@@ -231,6 +322,71 @@ fn fallback_path(a_u8: &[u8], b_i8: &[i8], c: &mut [i32], k: usize) {
231322 }
232323}
233324
325+ // ═════════════════════════════════════════════════════════════════════
326+ // AMX tiled helper — arbitrary 16/16/64-aligned M × N × K via 16×16 tile loop
327+ // ═════════════════════════════════════════════════════════════════════
328+
329+ /// `u8 × i8 → i32` GEMM using AMX `TDPBUSD` for arbitrary M × N × K
330+ /// shapes that satisfy `m % 16 == 0 && n % 16 == 0 && k % 64 == 0`.
331+ ///
332+ /// Tile-decomposes the M × N output into 16×16 blocks and calls
333+ /// [`int8_tile_gemm_16x16`] per (i_tile, j_tile). B sub-block extracted
334+ /// into K × 16 scratch once per j-tile, reused across all M i-tiles —
335+ /// amortizes the column gather cost.
336+ ///
337+ /// **Overwrite semantics**: `c` is written, not accumulated. Caller
338+ /// does NOT need to zero `c` beforehand. (The underlying
339+ /// `int8_tile_gemm_16x16` accumulates into its tile buffer, but we
340+ /// zero the tile buffer before each call so the per-tile write to `c`
341+ /// is pure overwrite.)
342+ ///
343+ /// # Panics
344+ /// Panics if `a_u8`, `b_i8`, or `c` are too small for the requested
345+ /// `(m, n, k)`, mirroring the boundary contract from `gemm_u8_i8`. Also
346+ /// panics in debug builds when AMX isn't OS-enabled or when the shape
347+ /// alignment constraints aren't met (production builds skip those for
348+ /// performance — callers must runtime-check
349+ /// `crate::hpc::amx_matmul::amx_available()` and the 16/16/64
350+ /// alignment themselves).
351+ pub fn int8_gemm_amx_tiled ( a_u8 : & [ u8 ] , b_i8 : & [ i8 ] , c : & mut [ i32 ] , m : usize , n : usize , k : usize ) {
352+ // Length assertions (codex P1 from PR #185 — the function reads
353+ // `b_i8` via a 16-wide window per (kk, j_tile) iteration and a_u8
354+ // via a 16-row slice per i_tile, so mismatched shapes would
355+ // trigger out-of-bounds reads without these gates).
356+ assert ! ( a_u8. len( ) >= m * k, "int8_gemm_amx_tiled: a_u8.len()={} < m*k={}" , a_u8. len( ) , m * k) ;
357+ assert ! ( b_i8. len( ) >= k * n, "int8_gemm_amx_tiled: b_i8.len()={} < k*n={}" , b_i8. len( ) , k * n) ;
358+ assert ! ( c. len( ) >= m * n, "int8_gemm_amx_tiled: c.len()={} < m*n={}" , c. len( ) , m * n) ;
359+
360+ debug_assert ! ( crate :: hpc:: amx_matmul:: amx_available( ) ) ;
361+ debug_assert_eq ! ( m % 16 , 0 , "int8_gemm_amx_tiled: M must be multiple of 16" ) ;
362+ debug_assert_eq ! ( n % 16 , 0 , "int8_gemm_amx_tiled: N must be multiple of 16" ) ;
363+ debug_assert_eq ! ( k % 64 , 0 , "int8_gemm_amx_tiled: K must be multiple of 64" ) ;
364+
365+ let mut b_tile = vec ! [ 0i8 ; k * 16 ] ;
366+ let mut tile_c = vec ! [ 0i32 ; 256 ] ;
367+
368+ for j_tile in ( 0 ..n) . step_by ( 16 ) {
369+ // Pack B[0..k, j_tile..j_tile+16] into 16-wide K-rows
370+ // (contiguous memory for int8_tile_gemm_16x16's input shape).
371+ // Safe slicing — the row..row+16 range is bounded by
372+ // `b_i8.len() >= k * n` asserted at function entry.
373+ for kk in 0 ..k {
374+ let row = kk * n + j_tile;
375+ b_tile[ kk * 16 ..( kk + 1 ) * 16 ] . copy_from_slice ( & b_i8[ row..row + 16 ] ) ;
376+ }
377+ for i_tile in ( 0 ..m) . step_by ( 16 ) {
378+ let a_tile = & a_u8[ i_tile * k..( i_tile + 16 ) * k] ;
379+ tile_c. fill ( 0 ) ;
380+ int8_tile_gemm_16x16 ( a_tile, & b_tile, & mut tile_c, k) ;
381+ // Write tile_c (16 × 16, row-major) into c (M × N, row-major).
382+ for ii in 0 ..16 {
383+ let dst_off = ( i_tile + ii) * n + j_tile;
384+ c[ dst_off..dst_off + 16 ] . copy_from_slice ( & tile_c[ ii * 16 ..( ii + 1 ) * 16 ] ) ;
385+ }
386+ }
387+ }
388+ }
389+
234390// ═════════════════════════════════════════════════════════════════════
235391// Tests
236392// ═════════════════════════════════════════════════════════════════════
@@ -370,6 +526,65 @@ mod tests {
370526 }
371527 }
372528
529+ /// Codex P1 regression on PR #185: `int8_gemm_amx_tiled` is a
530+ /// safe public function — mismatched (m, n, k) vs slice lengths
531+ /// must panic at the function boundary, not trigger UB inside
532+ /// the unsafe slice/pointer arithmetic in the inner loop. This
533+ /// test passes deliberately-undersized buffers and expects a
534+ /// panic (which `#[should_panic]` catches).
535+ #[ test]
536+ #[ should_panic( expected = "b_i8.len()" ) ]
537+ fn amx_tiled_panics_on_undersized_b ( ) {
538+ let m = 16 ;
539+ let n = 32 ;
540+ let k = 64 ;
541+ let a = vec ! [ 0u8 ; m * k] ;
542+ let b = vec ! [ 0i8 ; k * ( n - 16 ) ] ; // half a j_tile short of what's claimed
543+ let mut c = vec ! [ 0i32 ; m * n] ;
544+ // Even on non-AMX hosts the assertion fires before reaching
545+ // the (debug-asserted) amx_available() check.
546+ int8_gemm_amx_tiled ( & a, & b, & mut c, m, n, k) ;
547+ }
548+
549+ /// Direct test for the VPDPBUSD-ymm arm (AVX-VNNI tier of
550+ /// `matmul_i8_to_i32`). Same shape / bit-exactness contract as
551+ /// the zmm version's test, just on the narrower 8-wide kernel.
552+ #[ cfg( target_arch = "x86_64" ) ]
553+ #[ test]
554+ fn vpdpbusd_ymm_matches_scalar ( ) {
555+ if !std:: is_x86_feature_detected!( "avxvnni" ) {
556+ eprintln ! ( "avxvnni not detected; skipping" ) ;
557+ return ;
558+ }
559+
560+ fn ref_gemm ( a : & [ u8 ] , b : & [ i8 ] , m : usize , n : usize , k : usize ) -> Vec < i32 > {
561+ let mut c = vec ! [ 0i32 ; m * n] ;
562+ for i in 0 ..m {
563+ for kk in 0 ..k {
564+ let av = a[ i * k + kk] as i32 ;
565+ for j in 0 ..n {
566+ c[ i * n + j] += av * b[ kk * n + j] as i32 ;
567+ }
568+ }
569+ }
570+ c
571+ }
572+
573+ // Sweep shapes spanning 8-aligned, K-tail (k % 4), N-tail
574+ // (n % 8), and small shapes to exercise every code path.
575+ for ( m, n, k) in [ ( 16 , 8 , 64 ) , ( 3 , 5 , 7 ) , ( 17 , 33 , 100 ) , ( 1 , 17 , 12 ) , ( 8 , 8 , 4 ) ] {
576+ let a: Vec < u8 > = ( 0 ..m * k) . map ( |i| ( ( i * 31 + 7 ) % 256 ) as u8 ) . collect ( ) ;
577+ let b: Vec < i8 > = ( 0 ..k * n)
578+ . map ( |i| ( ( i * 17 + 3 ) % 256 ) as u8 as i8 )
579+ . collect ( ) ;
580+ let expected = ref_gemm ( & a, & b, m, n, k) ;
581+ let mut got = vec ! [ 0i32 ; m * n] ;
582+ // SAFETY: avxvnni confirmed at the top of the test.
583+ unsafe { int8_gemm_vpdpbusd_ymm ( & a, & b, & mut got, m, n, k) } ;
584+ assert_eq ! ( got, expected, "VPDPBUSD-ymm mismatch at (M={}, N={}, K={})" , m, n, k) ;
585+ }
586+ }
587+
373588 #[ test]
374589 fn vnni_pack_i8_roundtrip ( ) {
375590 // Pack then verify the VNNI layout matches the spec:
0 commit comments