diff --git a/libDF/src/lib.rs b/libDF/src/lib.rs index 7ab568856..a8f8d361b 100644 --- a/libDF/src/lib.rs +++ b/libDF/src/lib.rs @@ -221,7 +221,9 @@ impl DFState { } pub fn apply_mask(&self, output: &mut [Complex32], gains: &[f32]) { - apply_interp_band_gain(output, gains, &self.erb) + // apply_band_gain is the Complex32 specialisation of apply_interp_band_gain + // and carries a SIMD-vectorised inner loop on wasm32. + apply_band_gain(output, gains, &self.erb) } } @@ -242,6 +244,144 @@ pub fn band_mean_norm_freq(xs: &[Complex32], xout: &mut [f32], state: &mut [f32] } pub fn band_mean_norm_erb(xs: &mut [f32], state: &mut [f32], alpha: f32) { + debug_assert_eq!(xs.len(), state.len()); + band_mean_norm_erb_inner(xs, state, alpha); +} + +pub fn band_unit_norm(xs: &mut [Complex32], state: &mut [f32], alpha: f32) { + debug_assert_eq!(xs.len(), state.len()); + band_unit_norm_inner(xs, state, alpha); +} + +/// Band unit norm, but with transposed output type. I.e. out contains first all real elements, +/// followed by all imaginary elements. This memory layout is different from Complex32 slice which +/// contains real and imaginary part as interleaved values. +pub fn band_unit_norm_t(xs: &[Complex32], state: &mut [f32], alpha: f32, out: &mut [f32]) { + debug_assert_eq!(xs.len(), state.len()); + debug_assert_eq!(xs.len(), out.len() / 2); + let (o_re, o_im) = out.split_at_mut(xs.len()); + band_unit_norm_t_inner(xs, state, alpha, o_re, o_im); +} + +pub fn compute_band_corr(out: &mut [f32], x: &[Complex32], p: &[Complex32], erb_fb: &[usize]) { + for y in out.iter_mut() { + *y = 0.0; + } + debug_assert_eq!(erb_fb.len(), out.len()); + debug_assert_eq!(x.len(), p.len()); + + // Each Complex32 occupies 2 contiguous f32 (re, im). Reinterpret the slices + // as flat &[f32] of length 2*N so we can vectorize with f32x4 loads. + // SAFETY: Complex32 is #[repr(C)] { re: f32, im: f32 } -> 8 bytes, alignment 4, + // identical to two contiguous f32. Length is exactly 2 * x.len(). 
+ let xf: &[f32] = + unsafe { core::slice::from_raw_parts(x.as_ptr() as *const f32, x.len() * 2) }; + let pf: &[f32] = + unsafe { core::slice::from_raw_parts(p.as_ptr() as *const f32, p.len() * 2) }; + + let mut bcsum = 0usize; + for (&band_size, out_b) in erb_fb.iter().zip(out.iter_mut()) { + let k = 1.0f32 / band_size as f32; + let f_start = bcsum * 2; + let f_len = band_size * 2; + let xb = &xf[f_start..f_start + f_len]; + let pb = &pf[f_start..f_start + f_len]; + // sum := sum over band of x[i].re*p[i].re + x[i].im*p[i].im + // == sum over flattened pairs of xb[2j]*pb[2j] + xb[2j+1]*pb[2j+1] + // == sum_lanes( sum over 4-wide chunks of xb[..]*pb[..] ) + let sum: f32 = compute_band_corr_inner(xb, pb); + *out_b = sum * k; + bcsum += band_size; + } +} + +#[cfg(target_arch = "wasm32")] +#[inline] +fn compute_band_corr_inner(xb: &[f32], pb: &[f32]) -> f32 { + use core::arch::wasm32::*; + debug_assert_eq!(xb.len(), pb.len()); + let n = xb.len(); + let n4 = n & !3; // round down to multiple of 4 + let mut acc = f32x4_splat(0.0); + let xp = xb.as_ptr(); + let pp = pb.as_ptr(); + let mut i = 0usize; + while i < n4 { + // SAFETY: xp/pp are aligned to f32 (4 bytes); v128_load uses unaligned semantics. + // We bounds-check via i < n4 <= n == xb.len() == pb.len(). + unsafe { + let xv = v128_load(xp.add(i) as *const v128); + let pv = v128_load(pp.add(i) as *const v128); + let prod = f32x4_mul(xv, pv); + acc = f32x4_add(acc, prod); + } + i += 4; + } + // Horizontal reduce the 4 lanes. + let mut sum = f32x4_extract_lane::<0>(acc) + + f32x4_extract_lane::<1>(acc) + + f32x4_extract_lane::<2>(acc) + + f32x4_extract_lane::<3>(acc); + // Tail: 0..3 leftover f32 (i.e. 0 or 1 trailing complex pair if band_size is odd). 
+ while i < n { + sum += unsafe { *xp.add(i) * *pp.add(i) }; + i += 1; + } + sum +} + +#[cfg(not(target_arch = "wasm32"))] +#[inline] +fn compute_band_corr_inner(xb: &[f32], pb: &[f32]) -> f32 { + debug_assert_eq!(xb.len(), pb.len()); + let mut sum = 0.0f32; + for (a, b) in xb.iter().zip(pb.iter()) { + sum += a * b; + } + sum +} + +// Element-wise IIR mean-norm: state[i] = x[i]*(1-α) + state[i]*α; x[i] = (x[i] - state[i])/40. +// Per-bin independent (no recurrence between bins) — straightforward SIMD. +#[cfg(target_arch = "wasm32")] +#[inline] +fn band_mean_norm_erb_inner(xs: &mut [f32], state: &mut [f32], alpha: f32) { + use core::arch::wasm32::*; + debug_assert_eq!(xs.len(), state.len()); + let n = xs.len(); + let n4 = n & !3; + let one_minus_a = f32x4_splat(1.0 - alpha); + let alpha_v = f32x4_splat(alpha); + let inv40 = f32x4_splat(1.0 / 40.0); + let xp = xs.as_mut_ptr(); + let sp = state.as_mut_ptr(); + let mut i = 0usize; + while i < n4 { + // SAFETY: i < n4 <= n == xs.len() == state.len(). v128_load takes 16 bytes + // (4 f32). xp/sp are aligned to f32 (4 bytes); v128_load uses unaligned semantics. + unsafe { + let xv = v128_load(xp.add(i) as *const v128); + let sv = v128_load(sp.add(i) as *const v128); + let new_s = f32x4_add(f32x4_mul(xv, one_minus_a), f32x4_mul(sv, alpha_v)); + v128_store(sp.add(i) as *mut v128, new_s); + let x_norm = f32x4_mul(f32x4_sub(xv, new_s), inv40); + v128_store(xp.add(i) as *mut v128, x_norm); + } + i += 4; + } + while i < n { + unsafe { + let new_s = *xp.add(i) * (1.0 - alpha) + *sp.add(i) * alpha; + *sp.add(i) = new_s; + *xp.add(i) = (*xp.add(i) - new_s) / 40.0; + } + i += 1; + } +} + +#[cfg(not(target_arch = "wasm32"))] +#[inline] +fn band_mean_norm_erb_inner(xs: &mut [f32], state: &mut [f32], alpha: f32) { debug_assert_eq!(xs.len(), state.len()); for (x, s) in xs.iter_mut().zip(state.iter_mut()) { *s = *x * (1. 
- alpha) + *s * alpha; @@ -250,21 +390,294 @@ pub fn band_mean_norm_erb(xs: &mut [f32], state: &mut [f32], alpha: f32) { } } -pub fn band_unit_norm(xs: &mut [Complex32], state: &mut [f32], alpha: f32) { +// Multiply every f32 lane in `xs` by scalar `k`, in place. +#[cfg(target_arch = "wasm32")] +#[inline] +fn f32_scale_inplace(xs: &mut [f32], k: f32) { + use core::arch::wasm32::*; + let n = xs.len(); + let n4 = n & !3; + let kv = f32x4_splat(k); + let xp = xs.as_mut_ptr(); + let mut i = 0usize; + while i < n4 { + unsafe { + let xv = v128_load(xp.add(i) as *const v128); + v128_store(xp.add(i) as *mut v128, f32x4_mul(xv, kv)); + } + i += 4; + } + while i < n { + unsafe { + *xp.add(i) *= k; + } + i += 1; + } +} + +#[cfg(not(target_arch = "wasm32"))] +#[inline] +fn f32_scale_inplace(xs: &mut [f32], k: f32) { + for x in xs.iter_mut() { + *x *= k; + } +} + +// Element-wise multiply: xs[i] *= ws[i] for the whole slice, in place. +#[cfg(target_arch = "wasm32")] +#[inline] +fn f32_mul_inplace(xs: &mut [f32], ws: &[f32]) { + use core::arch::wasm32::*; + debug_assert_eq!(xs.len(), ws.len()); + let n = xs.len(); + let n4 = n & !3; + let xp = xs.as_mut_ptr(); + let wp = ws.as_ptr(); + let mut i = 0usize; + while i < n4 { + unsafe { + let xv = v128_load(xp.add(i) as *const v128); + let wv = v128_load(wp.add(i) as *const v128); + v128_store(xp.add(i) as *mut v128, f32x4_mul(xv, wv)); + } + i += 4; + } + while i < n { + unsafe { + *xp.add(i) *= *wp.add(i); + } + i += 1; + } +} + +#[cfg(not(target_arch = "wasm32"))] +#[inline] +fn f32_mul_inplace(xs: &mut [f32], ws: &[f32]) { + debug_assert_eq!(xs.len(), ws.len()); + for (x, &w) in xs.iter_mut().zip(ws.iter()) { + *x *= w; + } +} + +// Three-slice element-wise add: out[i] = a[i] + b[i]. 
#[cfg(target_arch = "wasm32")]
#[inline]
fn f32_add_to(a: &[f32], b: &[f32], out: &mut [f32]) {
    // wasm32 SIMD path: four f32 lanes per iteration, then a scalar tail.
    use core::arch::wasm32::*;
    debug_assert_eq!(a.len(), b.len());
    debug_assert_eq!(a.len(), out.len());
    // The debug assertions vanish in release builds, so every raw-pointer
    // access below is bounded by the shortest of the three slices. A length
    // mismatch then truncates, like the `zip`-based portable variant, instead
    // of reading or writing out of bounds.
    let n = a.len().min(b.len()).min(out.len());
    let n4 = n & !3; // largest multiple of 4 not exceeding n
    let ap = a.as_ptr();
    let bp = b.as_ptr();
    let op = out.as_mut_ptr();
    let mut i = 0usize;
    while i < n4 {
        // SAFETY: i + 4 <= n4 <= n and n does not exceed the length of `a`,
        // `b` or `out`, so each 16-byte access covers four in-bounds f32.
        // v128_load / v128_store use unaligned semantics, so the 4-byte
        // alignment of f32 is sufficient.
        unsafe {
            let av = v128_load(ap.add(i) as *const v128);
            let bv = v128_load(bp.add(i) as *const v128);
            v128_store(op.add(i) as *mut v128, f32x4_add(av, bv));
        }
        i += 4;
    }
    // Scalar tail: the 0..=3 elements left after the last full vector.
    while i < n {
        // SAFETY: i < n, and n is within bounds of all three slices.
        unsafe {
            *op.add(i) = *ap.add(i) + *bp.add(i);
        }
        i += 1;
    }
}

#[cfg(not(target_arch = "wasm32"))]
#[inline]
fn f32_add_to(a: &[f32], b: &[f32], out: &mut [f32]) {
    // Portable path: out[i] = a[i] + b[i]. `zip` stops at the shortest slice;
    // equal lengths are expected and checked in debug builds.
    debug_assert_eq!(a.len(), b.len());
    debug_assert_eq!(a.len(), out.len());
    for ((&x, &y), o) in a.iter().zip(b.iter()).zip(out.iter_mut()) {
        *o = x + y;
    }
}

// In-place element-wise add: xs[i] += ys[i].
#[cfg(target_arch = "wasm32")]
#[inline]
fn f32_add_inplace(xs: &mut [f32], ys: &[f32]) {
    use core::arch::wasm32::*;
    debug_assert_eq!(xs.len(), ys.len());
    // Bound by the shorter slice so a release-build length mismatch truncates
    // (as the portable variant does) rather than going out of bounds.
    let n = xs.len().min(ys.len());
    let n4 = n & !3; // largest multiple of 4 not exceeding n
    let xp = xs.as_mut_ptr();
    let yp = ys.as_ptr();
    let mut i = 0usize;
    while i < n4 {
        // SAFETY: i + 4 <= n4 <= n <= min(xs.len(), ys.len()), so both
        // 16-byte accesses are in bounds; v128_load / v128_store use
        // unaligned semantics.
        unsafe {
            let xv = v128_load(xp.add(i) as *const v128);
            let yv = v128_load(yp.add(i) as *const v128);
            v128_store(xp.add(i) as *mut v128, f32x4_add(xv, yv));
        }
        i += 4;
    }
    // Scalar tail: the 0..=3 elements left after the last full vector.
    while i < n {
        // SAFETY: i < n <= min(xs.len(), ys.len()).
        unsafe {
            *xp.add(i) += *yp.add(i);
        }
        i += 1;
    }
}

#[cfg(not(target_arch = "wasm32"))]
#[inline]
fn f32_add_inplace(xs: &mut [f32], ys: &[f32]) {
    // Portable path; `zip` stops at the shorter slice.
    debug_assert_eq!(xs.len(), ys.len());
    for (x, &y) in xs.iter_mut().zip(ys.iter()) {
        *x += y;
    }
}

// IIR per-bin unit-norm on interleaved Complex32:
//   state[i] = sqrt(re[i]^2 + im[i]^2) * (1 - α) + state[i] * α;
//   xs[i] /= sqrt(state[i])      (Complex32 / f32 = each component / f32)
//
// SIMD path processes 4 Complex32 per iteration.
The interleaved layout +// [re0,im0,re1,im1,re2,im2,re3,im3] is loaded as two v128s, de-interleaved +// via i32x4_shuffle into pure-real and pure-imag vectors so the norm can be +// computed lane-wise. The normalisation step then divides each Complex32 +// component by sqrt(state[i]) by re-interleaving the divisor. +#[cfg(target_arch = "wasm32")] +#[inline] +fn band_unit_norm_inner(xs: &mut [Complex32], state: &mut [f32], alpha: f32) { + use core::arch::wasm32::*; debug_assert_eq!(xs.len(), state.len()); + let n = xs.len(); + let n4 = n & !3; + let one_minus_a = f32x4_splat(1.0 - alpha); + let alpha_v = f32x4_splat(alpha); + let xf = xs.as_mut_ptr() as *mut f32; + let sp = state.as_mut_ptr(); + let mut i = 0usize; + while i < n4 { + // SAFETY: i < n4 <= n, and Complex32 is #[repr(C)] {re: f32, im: f32}, + // so xs as &mut [f32] of length 2N is valid. v128_load is unaligned. + unsafe { + let lo = v128_load(xf.add(i * 2) as *const v128); + let hi = v128_load(xf.add(i * 2 + 4) as *const v128); + // De-interleave: re_v = [re0, re1, re2, re3], im_v = [im0, im1, im2, im3] + let re_v = i32x4_shuffle::<0, 2, 4, 6>(lo, hi); + let im_v = i32x4_shuffle::<1, 3, 5, 7>(lo, hi); + // norm = sqrt(re² + im²) (note: this is (re²+im²).sqrt(), not libm hypot) + let norm_sq = f32x4_add(f32x4_mul(re_v, re_v), f32x4_mul(im_v, im_v)); + let norm_v = f32x4_sqrt(norm_sq); + // state update + let sv = v128_load(sp.add(i) as *const v128); + let new_s = f32x4_add(f32x4_mul(norm_v, one_minus_a), f32x4_mul(sv, alpha_v)); + v128_store(sp.add(i) as *mut v128, new_s); + // xs /= sqrt(state): build duplicated divisor per Complex32 + // for lo: [sqrt_s0, sqrt_s0, sqrt_s1, sqrt_s1] + // for hi: [sqrt_s2, sqrt_s2, sqrt_s3, sqrt_s3] + let sqrt_s = f32x4_sqrt(new_s); + let div_lo = i32x4_shuffle::<0, 0, 1, 1>(sqrt_s, sqrt_s); + let div_hi = i32x4_shuffle::<2, 2, 3, 3>(sqrt_s, sqrt_s); + v128_store(xf.add(i * 2) as *mut v128, f32x4_div(lo, div_lo)); + v128_store(xf.add(i * 2 + 4) as *mut v128, 
f32x4_div(hi, div_hi)); + } + i += 4; + } + // Tail: 0..3 trailing Complex32. Use the SAME (re²+im²).sqrt() as the SIMD + // path (NOT Complex32::norm() which is libm hypot) so vectorised + tail + // produce identical results across the full length. + while i < n { + unsafe { + let xi_re = *xf.add(i * 2); + let xi_im = *xf.add(i * 2 + 1); + let norm = (xi_re * xi_re + xi_im * xi_im).sqrt(); + let new_s = norm * (1.0 - alpha) + *sp.add(i) * alpha; + *sp.add(i) = new_s; + let sqrt_s = new_s.sqrt(); + *xf.add(i * 2) = xi_re / sqrt_s; + *xf.add(i * 2 + 1) = xi_im / sqrt_s; + } + i += 1; + } +} + +#[cfg(not(target_arch = "wasm32"))] +#[inline] +fn band_unit_norm_inner(xs: &mut [Complex32], state: &mut [f32], alpha: f32) { for (x, s) in xs.iter_mut().zip(state.iter_mut()) { *s = x.norm() * (1. - alpha) + *s * alpha; *x /= s.sqrt(); } } -/// Band unit norm, but with transposed output type. I.e. out contains first all real elements, -/// followed by all imaginary elements. This memory layout is different from Complex32 slice which -/// contains real and imaginary part as interleaved values. -pub fn band_unit_norm_t(xs: &[Complex32], state: &mut [f32], alpha: f32, out: &mut [f32]) { +// Same IIR norm as band_unit_norm but writes to o_re / o_im split halves of +// the output (xs read-only). The output halves are CONTIGUOUS so no +// re-interleave step is needed for the divide — simpler than band_unit_norm. 
#[cfg(target_arch = "wasm32")]
#[inline]
fn band_unit_norm_t_inner(
    xs: &[Complex32],
    state: &mut [f32],
    alpha: f32,
    o_re: &mut [f32],
    o_im: &mut [f32],
) {
    // wasm32 SIMD path. Per bin i:
    //   state[i] = sqrt(re² + im²) * (1 - α) + state[i] * α
    //   o_re[i]  = xs[i].re / sqrt(state[i])
    //   o_im[i]  = xs[i].im / sqrt(state[i])
    // `xs` is only read; the normalised values are written to the two output
    // halves, which must not be assumed to hold meaningful data on entry.
    use core::arch::wasm32::*;
    debug_assert_eq!(xs.len(), state.len());
    debug_assert_eq!(xs.len(), o_re.len());
    debug_assert_eq!(xs.len(), o_im.len());
    // The debug assertions are compiled out in release builds, so bound every
    // raw-pointer access by the shortest of the four slices; a length mismatch
    // then truncates instead of going out of bounds.
    let n = xs.len().min(state.len()).min(o_re.len()).min(o_im.len());
    let n4 = n & !3; // largest multiple of 4 not exceeding n
    let one_minus_a = f32x4_splat(1.0 - alpha);
    let alpha_v = f32x4_splat(alpha);
    // View the Complex32 buffer as 2 * len contiguous f32 (re, im, re, im, ...).
    let xf = xs.as_ptr() as *const f32;
    let sp = state.as_mut_ptr();
    let rp = o_re.as_mut_ptr();
    let ip = o_im.as_mut_ptr();
    let mut i = 0usize;
    while i < n4 {
        // SAFETY: i + 4 <= n4 <= n, and n does not exceed the length of any of
        // the four slices. The two loads on `xf` cover f32 indices 2i..2i+8,
        // i.e. xs[i..i+4] (Complex32 is #[repr(C)] { re: f32, im: f32 }); the
        // accesses on `sp`, `rp`, `ip` cover elements i..i+4 of their slices.
        // v128_load / v128_store use unaligned semantics.
        unsafe {
            let lo = v128_load(xf.add(i * 2) as *const v128);
            let hi = v128_load(xf.add(i * 2 + 4) as *const v128);
            // De-interleave: re_v = [re0, re1, re2, re3], im_v = [im0, im1, im2, im3]
            let re_v = i32x4_shuffle::<0, 2, 4, 6>(lo, hi);
            let im_v = i32x4_shuffle::<1, 3, 5, 7>(lo, hi);
            let norm_sq = f32x4_add(f32x4_mul(re_v, re_v), f32x4_mul(im_v, im_v));
            let norm_v = f32x4_sqrt(norm_sq);
            let sv = v128_load(sp.add(i) as *const v128);
            let new_s = f32x4_add(f32x4_mul(norm_v, one_minus_a), f32x4_mul(sv, alpha_v));
            v128_store(sp.add(i) as *mut v128, new_s);
            let sqrt_s = f32x4_sqrt(new_s);
            // The de-interleaved input lanes already match the contiguous
            // o_re / o_im layout: divide them and store directly. (Dividing
            // the previous contents of the output buffers instead would never
            // copy the input spectrum into the output.)
            v128_store(rp.add(i) as *mut v128, f32x4_div(re_v, sqrt_s));
            v128_store(ip.add(i) as *mut v128, f32x4_div(im_v, sqrt_s));
        }
        i += 4;
    }
    // Scalar tail: 0..3 trailing bins, using the same (re²+im²).sqrt() as the
    // SIMD loop so both parts of the slice are computed identically.
    while i < n {
        // SAFETY: i < n, and n is within bounds of all four slices; f32
        // indices 2i and 2i + 1 are the re / im fields of xs[i].
        unsafe {
            let xi_re = *xf.add(i * 2);
            let xi_im = *xf.add(i * 2 + 1);
            let norm = (xi_re * xi_re + xi_im * xi_im).sqrt();
            let new_s = norm * (1.0 - alpha) + *sp.add(i) * alpha;
            *sp.add(i) = new_s;
            let sqrt_s = new_s.sqrt();
            *rp.add(i) = xi_re / sqrt_s;
            *ip.add(i) = xi_im / sqrt_s;
        }
        i += 1;
    }
}

#[cfg(not(target_arch = "wasm32"))]
#[inline]
fn band_unit_norm_t_inner(
    xs: &[Complex32],
    state: &mut [f32],
    alpha: f32,
+ o_re: &mut [f32], + o_im: &mut [f32], +) { for (x, s, o_re, o_im) in izip!( xs.iter(), state.iter_mut(), @@ -277,23 +690,6 @@ pub fn band_unit_norm_t(xs: &[Complex32], state: &mut [f32], alpha: f32, out: &m } } -pub fn compute_band_corr(out: &mut [f32], x: &[Complex32], p: &[Complex32], erb_fb: &[usize]) { - for y in out.iter_mut() { - *y = 0.0; - } - debug_assert_eq!(erb_fb.len(), out.len()); - - let mut bcsum = 0; - for (&band_size, out_b) in erb_fb.iter().zip(out.iter_mut()) { - let k = 1. / band_size as f32; - for j in 0..band_size { - let idx = bcsum + j; - *out_b += (x[idx].re * p[idx].re + x[idx].im * p[idx].im) * k; - } - bcsum += band_size; - } -} - pub fn band_compr(out: &mut [f32], x: &[f32], erb_fb: &[usize]) { for y in out.iter_mut() { *y = 0.0; @@ -337,12 +733,18 @@ fn interp_band_gain(out: &mut [f32], band_e: &[f32], erb_fb: &[usize]) { } fn apply_band_gain(out: &mut [Complex32], band_e: &[f32], erb_fb: &[usize]) { - let mut bcsum = 0; - for (&band_size, b) in erb_fb.iter().zip(band_e.iter()) { - for j in 0..band_size { - let idx = bcsum + j; - out[idx] *= *b; - } + // Reinterpret &mut [Complex32] as &mut [f32] of length 2*N. Complex32 is + // #[repr(C)] { re: f32, im: f32 }: 8 bytes, alignment 4 — identical layout + // to two contiguous f32. Multiplying each Complex32 by a real f32 scalar `b` + // is equivalent to multiplying every f32 lane by `b`. 
+ let n = out.len(); + let outf: &mut [f32] = + unsafe { core::slice::from_raw_parts_mut(out.as_mut_ptr() as *mut f32, n * 2) }; + let mut bcsum = 0usize; + for (&band_size, &b) in erb_fb.iter().zip(band_e.iter()) { + let f_start = bcsum * 2; + let f_len = band_size * 2; + f32_scale_inplace(&mut outf[f_start..f_start + f_len], b); bcsum += band_size; } } @@ -405,10 +807,12 @@ fn frame_synthesis(input: &mut [Complex32], output: &mut [f32], state: &mut DFSt } apply_window_in_place(&mut x, &state.window); let (x_first, x_second) = x.split_at(state.frame_size); - for ((&xi, &mem), out) in x_first.iter().zip(state.synthesis_mem.iter()).zip(output.iter_mut()) - { - *out = xi + mem; - } + // out[i] = x_first[i] + synthesis_mem[i] (zip-3 stops at shortest; + // x_first.len() == output.len() == frame_size; synthesis_mem may be longer). + let n_out = output.len(); + debug_assert_eq!(x_first.len(), n_out); + debug_assert!(state.synthesis_mem.len() >= n_out); + f32_add_to(x_first, &state.synthesis_mem[..n_out], output); let split = state.synthesis_mem.len() - state.frame_size; if split > 0 { @@ -416,14 +820,12 @@ fn frame_synthesis(input: &mut [Complex32], output: &mut [f32], state: &mut DFSt } let (s_first, s_second) = state.synthesis_mem.split_at_mut(split); let (xs_first, xs_second) = x_second.split_at(split); - for (&xi, mem) in xs_first.iter().zip(s_first.iter_mut()) { - // Overlap add for next frame - *mem += xi; - } - for (&xi, mem) in xs_second.iter().zip(s_second.iter_mut()) { - // Override left shifted buffer - *mem = xi; - } + // Overlap-add for next frame: s_first[i] += xs_first[i]. + let n_first = xs_first.len().min(s_first.len()); + f32_add_inplace(&mut s_first[..n_first], &xs_first[..n_first]); + // Override left-shifted buffer: s_second[i] = xs_second[i] (memcpy-shaped). 
+ let n_second = xs_second.len().min(s_second.len()); + s_second[..n_second].copy_from_slice(&xs_second[..n_second]); } fn apply_window(xs: &[f32], window: &[f32]) -> Vec { @@ -434,13 +836,9 @@ fn apply_window(xs: &[f32], window: &[f32]) -> Vec { out } -fn apply_window_in_place<'a, I>(xs: &mut [f32], window: I) -where - I: IntoIterator, -{ - for (x, &w) in xs.iter_mut().zip(window) { - *x *= w; - } +fn apply_window_in_place(xs: &mut [f32], window: &[f32]) { + debug_assert_eq!(xs.len(), window.len()); + f32_mul_inplace(xs, window); } pub fn post_filter(noisy: &[Complex32], enh: &mut [Complex32], beta: f32) {