Skip to content

Commit 341216d

Browse files
committed
refactor(hpc): bf16_tile_gemm fallback delegates to the polyfill (dedup)
PR #222 added ndarray::simd::bf16_tile_gemm_16x16 by copying the F32x16 kernel out of hpc::bf16_tile_gemm::fallback_path, leaving the same kernel in two places. Collapse it: the polyfill fn is the single source of truth; the hpc AMX wrapper's fallback now calls crate::simd::bf16_tile_gemm_16x16, with the AMX TDPBF16PS tile path still layered on top. Drops the now-unused F32x16 / bf16_to_f32_batch import. Both suites pass (hpc fallback + simd_ops parity); clippy -D warnings + fmt clean. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> Claude-Session: https://claude.ai/code/session_01GJ4NVBSjq1w5h7RmTbVafb
1 parent afb53c2 commit 341216d

1 file changed

Lines changed: 5 additions & 33 deletions

File tree

src/hpc/bf16_tile_gemm.rs

Lines changed: 5 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@ use crate::hpc::amx_matmul::{
2121
amx_available, tile_dpbf16ps, tile_load, tile_loadconfig, tile_release, tile_store, tile_zero, vnni_pack_bf16,
2222
TileConfig,
2323
};
24-
use crate::simd::{bf16_to_f32_batch, F32x16};
2524

2625
// ═════════════════════════════════════════════════════════════════════
2726
// Public API — safe dispatching wrapper
@@ -104,39 +103,12 @@ unsafe fn amx_path(a_bf16: &[u16], b_vnni: &[u16], c: &mut [f32], k: usize) {
104103
// AVX-512 fallback (F32x16 + mul_add FMA)
105104
// ═════════════════════════════════════════════════════════════════════
106105

107-
/// Fallback: decode BF16→f32 and run a tight F32x16 GEMM with mul_add FMA.
108-
/// When AVX-512 is the compile-time baseline, this uses native __m512 FMA;
109-
/// on AVX2 it uses the emulated F32x16 = (F32x8, F32x8) pair — same logic.
106+
/// Fallback: delegate to the single source-of-truth SIMD-polyfill kernel
107+
/// [`crate::simd::bf16_tile_gemm_16x16`] (BF16→f32 decode + `F32x16` FMA). The
108+
/// `F32x16` wrapper owns the AVX-512 / AVX2 / NEON / scalar dispatch, so this
109+
/// AMX wrapper only adds the TDPBF16PS tile path on top of the same kernel.
110110
fn fallback_path(a_bf16: &[u16], b_bf16: &[u16], c: &mut [f32], k: usize) {
111-
// Decode BF16 → f32 (batch via SIMD when avx512bf16 / avx2 available)
112-
let mut a_f32 = vec![0.0f32; a_bf16.len()];
113-
let mut b_f32 = vec![0.0f32; b_bf16.len()];
114-
bf16_to_f32_batch(a_bf16, &mut a_f32);
115-
bf16_to_f32_batch(b_bf16, &mut b_f32);
116-
117-
// Tight GEMM: for each output (i,j), dot row-of-A with col-of-B via F32x16+FMA.
118-
// B is row-major [K, 16]; j-th column is b_f32[kk*16 + j] over kk=0..K.
119-
// We gather the column into a stack-sized buffer once per (i,j) pair to hit
120-
// the chunks_exact(16) + mul_add fast path on contiguous memory.
121-
for i in 0..16 {
122-
let a_row = &a_f32[i * k..i * k + k];
123-
for j in 0..16 {
124-
// Stream the column into a contiguous buffer
125-
let mut col = vec![0.0f32; k];
126-
for kk in 0..k {
127-
col[kk] = b_f32[kk * 16 + j];
128-
}
129-
130-
// Accumulate via F32x16::mul_add (FMA)
131-
let mut acc = F32x16::splat(0.0);
132-
for (ra, rb) in a_row.chunks_exact(16).zip(col.chunks_exact(16)) {
133-
let av = F32x16::from_slice(ra);
134-
let bv = F32x16::from_slice(rb);
135-
acc = av.mul_add(bv, acc);
136-
}
137-
c[i * 16 + j] += acc.reduce_sum();
138-
}
139-
}
111+
crate::simd::bf16_tile_gemm_16x16(a_bf16, b_bf16, c, k);
140112
}
141113

142114
// ═════════════════════════════════════════════════════════════════════

0 commit comments

Comments
 (0)