diff --git a/Cargo.lock b/Cargo.lock index a1e508d94..cc84bd751 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -104,7 +104,7 @@ dependencies = [ "mt-symetric", "mt-utils", "mt-whir", - "rayon", + "parallel", "tracing", ] @@ -240,31 +240,6 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "790eea4361631c5e7d22598ecd5723ff611904e3344ce8720784c93e3d83d40b" -[[package]] -name = "crossbeam-deque" -version = "0.8.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9dd111b7b7f7d55b72c0a6ae361660ee5853c9af73f70c3c2ef6858b950e2e51" -dependencies = [ - "crossbeam-epoch", - "crossbeam-utils", -] - -[[package]] -name = "crossbeam-epoch" -version = "0.9.18" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e" -dependencies = [ - "crossbeam-utils", -] - -[[package]] -name = "crossbeam-utils" -version = "0.8.21" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28" - [[package]] name = "crypto-common" version = "0.1.7" @@ -493,6 +468,7 @@ dependencies = [ "backend", "clap", "lean_vm", + "libc", "rand", "rec_aggregation", "serde_json", @@ -620,7 +596,7 @@ dependencies = [ "mt-koala-bear", "mt-symetric", "mt-utils", - "rayon", + "parallel", "serde", "tracing", ] @@ -632,9 +608,9 @@ dependencies = [ "itertools", "mt-utils", "num-bigint", + "parallel", "paste", "rand", - "rayon", "serde", "tracing", ] @@ -649,7 +625,6 @@ dependencies = [ "num-bigint", "paste", "rand", - "rayon", "serde", "tracing", ] @@ -662,8 +637,8 @@ dependencies = [ "mt-field", "mt-koala-bear", "mt-utils", + "parallel", "rand", - "rayon", "serde", "system-info", ] @@ -677,7 +652,7 @@ dependencies = [ "mt-field", "mt-koala-bear", "mt-poly", - "rayon", + "parallel", "tracing", ] @@ -687,7 +662,7 @@ version = "0.1.0" dependencies = [ "mt-field", "mt-koala-bear", - 
"rayon", + "parallel", ] [[package]] @@ -709,8 +684,8 @@ dependencies = [ "mt-sumcheck", "mt-symetric", "mt-utils", + "parallel", "rand", - "rayon", "system-info", "tracing", "tracing-forest", @@ -791,6 +766,13 @@ version = "1.70.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "384b8ab6d37215f3c5301a95a4accb5d64aa607f1fcb26a11b5303878451b4fe" +[[package]] +name = "parallel" +version = "0.1.0" +dependencies = [ + "system-info", +] + [[package]] name = "paste" version = "1.0.15" @@ -910,26 +892,6 @@ version = "0.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0c8d0fd677905edcbeedbf2edb6494d676f0e98d54d5cf9bda0b061cb8fb8aba" -[[package]] -name = "rayon" -version = "1.11.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "368f01d005bf8fd9b1206fb6fa653e6c4a81ceb1466406b81792d87c5677a58f" -dependencies = [ - "either", - "rayon-core", -] - -[[package]] -name = "rayon-core" -version = "1.13.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "22e18b0f0062d30d4230b2e85ff77fdfe4326feb054b9783a3460d8435c8ab91" -dependencies = [ - "crossbeam-deque", - "crossbeam-utils", -] - [[package]] name = "rec_aggregation" version = "0.1.0" @@ -1118,7 +1080,6 @@ name = "system-info" version = "0.1.0" dependencies = [ "libc", - "rayon", ] [[package]] @@ -1478,7 +1439,6 @@ name = "zk-alloc" version = "0.1.0" dependencies = [ "libc", - "rayon", "system-info", ] diff --git a/Cargo.toml b/Cargo.toml index f8e2ada76..7270ec157 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,6 +19,7 @@ members = [ "crates/backend/fiat-shamir", "crates/backend/sumcheck", "crates/backend/system-info", + "crates/backend/parallel", "crates/backend/zk-alloc", ] @@ -61,14 +62,14 @@ lean_compiler = { path = "crates/lean_compiler" } lean_prover = { path = "crates/lean_prover" } rec_aggregation = { path = "crates/rec_aggregation" } backend = { path = "crates/backend" } -zk-alloc = { path = 
"crates/backend/zk-alloc" } system-info = { path = "crates/backend/system-info" } +parallel = { path = "crates/backend/parallel" } +zk-alloc = { path = "crates/backend/zk-alloc" } # External sha3 = "0.11.0" clap = { version = "4.5.59", features = ["derive"] } rand = "0.10.0" -rayon = "1.11.0" pest = "2.7" pest_derive = "2.7" itertools = "0.14.0" @@ -83,12 +84,14 @@ include_dir = "0.7" [features] prox-gaps-conjecture = ["rec_aggregation/prox-gaps-conjecture"] +# Build with the plain system allocator instead of zk-alloc (for comparison/debugging). standard-alloc = ["rec_aggregation/standard-alloc"] [dependencies] clap.workspace = true rec_aggregation.workspace = true zk-alloc.workspace = true +libc = "0.2" rand.workspace = true sub_protocols.workspace = true utils.workspace = true diff --git a/crates/backend/Cargo.toml b/crates/backend/Cargo.toml index 3f61957af..d56cf56a8 100644 --- a/crates/backend/Cargo.toml +++ b/crates/backend/Cargo.toml @@ -9,7 +9,7 @@ poly = { path = "poly", package = "mt-poly" } sumcheck = { path = "sumcheck", package = "mt-sumcheck" } field = { path = "field", package = "mt-field" } air = { path = "air", package = "mt-air" } -rayon.workspace = true +parallel.workspace = true whir = { path = "../whir", package = "mt-whir" } tracing.workspace = true fiat-shamir = { path = "fiat-shamir", package = "mt-fiat-shamir" } diff --git a/crates/backend/fiat-shamir/Cargo.toml b/crates/backend/fiat-shamir/Cargo.toml index ec8649bc2..57f32d2ba 100644 --- a/crates/backend/fiat-shamir/Cargo.toml +++ b/crates/backend/fiat-shamir/Cargo.toml @@ -10,4 +10,4 @@ symetric = { path = "../symetric", package = "mt-symetric" } utils = { path = "../utils", package = "mt-utils" } tracing.workspace = true serde.workspace = true -rayon.workspace = true +parallel.workspace = true diff --git a/crates/backend/fiat-shamir/src/prover.rs b/crates/backend/fiat-shamir/src/prover.rs index 80bb6d13e..79d3859bb 100644 --- a/crates/backend/fiat-shamir/src/prover.rs +++ 
b/crates/backend/fiat-shamir/src/prover.rs @@ -9,8 +9,7 @@ use field::PrimeCharacteristicRing; use field::integers::QuotientMap; use field::{ExtensionField, PrimeField64}; use koala_bear::symmetric::Permutation; -use rayon::prelude::*; -use std::sync::atomic::{AtomicU64, Ordering}; +use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; use std::time::Duration; use std::{fmt::Debug, sync::Mutex, time::Instant}; @@ -132,9 +131,22 @@ where let witness_found = Mutex::>>::new(None); // each batch tests lanes witnesses simultaneously let num_batches = PF::::ORDER_U64.div_ceil(lanes as u64); - (0..num_batches) - .into_par_iter() - .find_any(|&batch| { + + // Parallel short-circuiting search (replaces rayon `find_any`): spawn one + // searcher per worker, each claiming batches from a shared counter and bailing + // as soon as any worker finds a witness. Bounds the work to ~expected + a few + // extra batches instead of enumerating all `num_batches` (which can be ~2^31). + let next_batch = AtomicU64::new(0); + let found = AtomicBool::new(false); + parallel::for_each_index(parallel::num_threads(), |_| { + loop { + if found.load(Ordering::Relaxed) { + return; + } + let batch = next_batch.fetch_add(1, Ordering::Relaxed); + if batch >= num_batches { + return; + } let base = batch * lanes as u64; let packed_witnesses = Packed::::from_fn(|lane| { @@ -159,12 +171,13 @@ where let rand_usize = sample.as_canonical_u64() as usize; if (rand_usize & ((1 << bits) - 1)) == 0 { *witness_found.lock().unwrap() = Some(*witness); - return true; + found.store(true, Ordering::Relaxed); + return; } } - false - }) - .expect("failed to find witness"); + } + }); + assert!(found.load(Ordering::Relaxed), "failed to find witness"); let witness = witness_found.lock().unwrap().unwrap(); diff --git a/crates/backend/field/Cargo.toml b/crates/backend/field/Cargo.toml index 89e87c133..cde41bb61 100644 --- a/crates/backend/field/Cargo.toml +++ b/crates/backend/field/Cargo.toml @@ -9,7 +9,7 @@ utils = { 
path = "../utils", package = "mt-utils" } itertools.workspace = true num-bigint = "*" paste = "*" +parallel.workspace = true rand.workspace = true -rayon.workspace = true serde.workspace = true tracing.workspace = true diff --git a/crates/backend/field/src/field.rs b/crates/backend/field/src/field.rs index b44ed45ed..836529cf9 100644 --- a/crates/backend/field/src/field.rs +++ b/crates/backend/field/src/field.rs @@ -9,7 +9,6 @@ use core::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAss use core::{array, slice}; use num_bigint::BigUint; -use rayon::{current_num_threads, prelude::*}; use serde::Serialize; use serde::de::DeserializeOwned; use utils::{flatten_to_base, iter_array_chunks_padded}; @@ -1020,7 +1019,7 @@ impl BoundedPowers { let mut points_packed = F::Packing::zero_vec(num_packed); // Split computation evenly among threads - let num_threads = current_num_threads().max(1); + let num_threads = parallel::num_threads().max(1); let chunk_size = num_packed.div_ceil(num_threads); // Precompute base for each chunk. @@ -1028,16 +1027,13 @@ impl BoundedPowers { let chunk_base = base.exp_u64((chunk_size * width) as u64); let shift = self.iter.current; - points_packed - .par_chunks_mut(chunk_size) - .enumerate() - .for_each(|(chunk_idx, chunk_slice)| { - // First power in this chunk - let chunk_start = shift * chunk_base.exp_u64(chunk_idx as u64); + parallel::par_chunks_mut(&mut points_packed, chunk_size, |chunk_idx, chunk_slice| { + // First power in this chunk + let chunk_start = shift * chunk_base.exp_u64(chunk_idx as u64); - // Fill the chunk with packed powers. - F::Packing::packed_shifted_powers(base, chunk_start).fill(chunk_slice); - }); + // Fill the chunk with packed powers. + F::Packing::packed_shifted_powers(base, chunk_start).fill(chunk_slice); + }); // return the number of requested points, discarding the unused packed powers // SAFETY: size_of:: always divides size_of::. 
diff --git a/crates/backend/koala-bear/Cargo.toml b/crates/backend/koala-bear/Cargo.toml index aba2ab231..5ce4ad111 100644 --- a/crates/backend/koala-bear/Cargo.toml +++ b/crates/backend/koala-bear/Cargo.toml @@ -8,7 +8,6 @@ field = { path = "../field", package = "mt-field" } utils = { path = "../utils", package = "mt-utils" } rand.workspace = true -rayon.workspace = true serde.workspace = true itertools.workspace = true tracing.workspace = true diff --git a/crates/backend/parallel/Cargo.toml b/crates/backend/parallel/Cargo.toml new file mode 100644 index 000000000..731b5163d --- /dev/null +++ b/crates/backend/parallel/Cargo.toml @@ -0,0 +1,11 @@ +[package] +name = "parallel" +version.workspace = true +edition.workspace = true +description = "Minimal fixed-size thread pool for static data-parallel kernels" + +[dependencies] +system-info.workspace = true + +[lints] +workspace = true diff --git a/crates/backend/parallel/src/lib.rs b/crates/backend/parallel/src/lib.rs new file mode 100644 index 000000000..d75624914 --- /dev/null +++ b/crates/backend/parallel/src/lib.rs @@ -0,0 +1,434 @@ +//! Minimal fixed-size thread pool for flat data-parallel kernels ("split a range, run a +//! closure on each piece"). No work-stealing, no per-dispatch allocation; owning the runtime +//! lets us pin per-worker scratch and drop rayon. +//! +//! - **Model.** `NUM_THREADS-1` background workers (ids `1..NUM_THREADS`); the dispatcher is +//! worker 0 and runs its share inline. Workers claim task ranges from a shared atomic +//! counter (guided self-scheduling) for dynamic load balance. +//! - **Lock-free dispatch.** Dispatch bumps a `generation` counter idle workers spin on +//! (back-to-back dispatches pay no syscall), parking only after `SPIN_LIMIT` idle spins. +//! Completion is a `working` countdown the dispatcher spins on. The per-worker `parked` flag +//! is SeqCst-ordered against `generation`, so for every dispatch at least one side sees the +//! 
other — a wakeup can't be lost — and unpark is skipped while a worker spins. +//! - **No nesting.** Dispatching from inside a task would deadlock the dispatch lock; an +//! `IN_TASK` guard panics instead (the outer level already saturates every core). +//! - **Panics.** A task panic is caught on its worker and re-raised on the dispatcher once the +//! dispatch quiesces (rayon's propagate-to-caller behavior); the pool stays usable. +//! - **One dispatcher at a time**, serialized by the `dispatch` mutex. + +use std::any::Any; +use std::cell::{Cell, UnsafeCell}; +use std::panic::{AssertUnwindSafe, catch_unwind, resume_unwind}; +use std::ptr::NonNull; +use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; +use std::sync::{Mutex, Once, OnceLock}; +use std::thread::Thread; + +use system_info::NUM_THREADS; + +/// Idle spins before a worker parks: long enough that back-to-back dispatches stay hot, short +/// enough that sequential gaps free the core for the active thread. +const SPIN_LIMIT: u32 = 1 << 12; + +/// Max tasks claimed in one guided-self-scheduling step: bounds load imbalance while keeping +/// million-task kernels to a few thousand claims. +const MAX_CLAIM_BATCH: usize = 1 << 12; + +/// Worker count including the dispatcher (= build-time `NUM_THREADS`). +#[must_use] +pub const fn num_threads() -> usize { + NUM_THREADS +} + +/// Chunk size for a flat fan-out: a few chunks per worker — fine enough for the counter to +/// rebalance heterogeneous cores, coarse enough to amortize dispatch. Always returns at least +/// 1, including for `n_items == 0`. +#[must_use] +#[inline] +pub fn recommended_chunk_size(n_items: usize) -> usize { + // Clamp the divisor the same way `pool()` clamps its worker count: a `NUM_THREADS == 0` + // build must not divide by zero here. + n_items.div_ceil(NUM_THREADS.max(1) * 4).max(1) +} + +thread_local! { + /// Stable pool id of this thread; `0` on the dispatcher and off-pool threads. + static WORKER_ID: Cell = const { Cell::new(0) }; + /// Set while running a task; a dispatch in this state is forbidden nesting (panics). 
+ static IN_TASK: Cell = const { Cell::new(false) }; +} + +/// Calling worker's id in `0..NUM_THREADS` (`0` off-pool). The hook for per-worker scratch. +#[must_use] +pub fn current_worker_id() -> usize { + WORKER_ID.with(Cell::get) +} + +/// Type-erased work unit. The `&dyn Fn` lifetime is erased to `'static`; it is dereferenced +/// only inside a dispatch window during which the dispatcher blocks, so the borrow outlives +/// every call. Range-based (`f(start, end)`) so a reduction looks up its per-worker +/// accumulator once per claimed batch, not per element. +struct Job { + f: NonNull, + n_tasks: usize, +} + +/// Park/unpark state, indexed by worker id (slot 0, the dispatcher, never parks). +#[derive(Debug)] +struct Worker { + /// "Currently parked", SeqCst-ordered against `Pool::generation`. + parked: AtomicBool, + /// Handle for `unpark`, published once at worker start-up. + handle: OnceLock, +} + +struct Pool { + /// Current job: written by the dispatcher before the `generation` bump, read by workers + /// after observing it (the bump supplies the happens-before). + job: UnsafeCell>, + /// Bumped once per dispatch; idle workers watch it (spin, then park). + generation: AtomicUsize, + /// Next task index to claim; reset to 0 per dispatch. + counter: AtomicUsize, + /// Background workers still draining; the dispatcher spins this to 0. + working: AtomicUsize, + /// Park flag + unpark handle per worker (slot 0 unused). + workers: Vec, + /// Serializes dispatchers: one driver at a time. + dispatch: Mutex<()>, + /// First task-panic payload of the current dispatch, re-raised by the dispatcher. Caught + /// here so it can't unwind across `worker_main` (which would skip the `working` decrement + /// and deadlock the completion spin). 
+ panic: Mutex>>, +} + +// SAFETY: `job` is written only by the sole dispatcher (while workers are parked or before +// they observe the generation bump) and read only after; the generation release/acquire and +// SeqCst park protocol order the phases. The erased `Job` pointer is used only within a +// dispatch window where its borrow is live. +unsafe impl Sync for Pool {} +unsafe impl Send for Pool {} + +/// Idempotent warm-up: spawn workers and run one empty dispatch so the pool and the (macOS) +/// lazily-allocated mutex exist before timed work. Otherwise the pool inits on first use. +/// +/// Also fail-fast if the machine's core count differs from the build-time [`NUM_THREADS`] +/// (which sizes the whole pool): a mismatch silently over/under-subscribes every kernel, so +/// surface it as "rebuild" rather than a quiet perf cliff. +pub fn init() { + static INIT: Once = Once::new(); + INIT.call_once(|| { + let actual = std::thread::available_parallelism().unwrap().get(); + assert_eq!( + actual, NUM_THREADS, + "parallel pool built for {NUM_THREADS} threads but this machine reports {actual} -> please rebuild" + ); + let _ = pool(); + if NUM_THREADS > 1 { + for_each_index(NUM_THREADS, |_| {}); + } + }); +} + +fn pool() -> &'static Pool { + static POOL: OnceLock<&'static Pool> = OnceLock::new(); + POOL.get_or_init(|| { + let n = NUM_THREADS.max(1); + let p: &'static Pool = Box::leak(Box::new(Pool { + job: UnsafeCell::new(None), + generation: AtomicUsize::new(0), + counter: AtomicUsize::new(0), + working: AtomicUsize::new(0), + workers: (0..n) + .map(|_| Worker { + parked: AtomicBool::new(false), + handle: OnceLock::new(), + }) + .collect(), + dispatch: Mutex::new(()), + panic: Mutex::new(None), + })); + for id in 1..n { + std::thread::Builder::new() + .name(format!("parallel-worker-{id}")) + .spawn(move || worker_main(p, id)) + .expect("failed to spawn pool worker"); + } + p + }) +} + +fn worker_main(pool: &'static Pool, id: usize) { + WORKER_ID.with(|c| c.set(id)); 
+ let _ = pool.workers[id].handle.set(std::thread::current()); + // Leaked, lives for the whole process; workers never shut down. One iteration per dispatch. + let mut last_gen = 0usize; + loop { + last_gen = wait_for_dispatch(pool, id, last_gen); + drain(pool); + pool.working.fetch_sub(1, Ordering::Release); + } +} + +/// Block until a new job is published, returning its generation. Spins up to [`SPIN_LIMIT`], +/// then parks. The park is delicate: publish `parked = true`, then re-check `generation`, both +/// SeqCst — the same total order the dispatcher's bump and `parked` load observe, so a wakeup +/// can never be lost. +fn wait_for_dispatch(pool: &Pool, id: usize, last_gen: usize) -> usize { + let mut spins = 0u32; + loop { + let g = pool.generation.load(Ordering::Acquire); + if g != last_gen { + return g; + } + if spins < SPIN_LIMIT { + spins += 1; + std::hint::spin_loop(); + continue; + } + // Announce intent to park, then re-check: park only if nothing changed, else re-loop. + pool.workers[id].parked.store(true, Ordering::SeqCst); + if pool.generation.load(Ordering::SeqCst) == last_gen { + std::thread::park(); + } + pool.workers[id].parked.store(false, Ordering::SeqCst); + spins = 0; + } +} + +/// Claim and run task ranges until the counter is exhausted (guided self-scheduling: each claim +/// takes `remaining / (NUM_THREADS*2)`, clamped to `1..=`[`MAX_CLAIM_BATCH`]). Big early claims +/// cut counter contention; the proportional shrink keeps the tail balanced. +fn drain(pool: &Pool) { + // SAFETY: the dispatcher published `Some(job)` before the bump this worker observed and + // overwrites it only on the next dispatch (gated on `working == 0`); no writer during drain. + let job = unsafe { (*pool.job.get()).as_ref().expect("drain without a published job") }; + // SAFETY: `job.f` borrows a `&dyn Fn` the blocked dispatcher keeps live. 
+ let f = unsafe { job.f.as_ref() }; + let n = job.n_tasks; + let prev = IN_TASK.replace(true); // catch nested dispatch (see `for_each_chunk`) + // Catch a task panic so it can't unwind across `worker_main` (skipping the `working` + // decrement → deadlock) or poison the dispatch lock; `for_each_chunk` re-raises it. + let result = catch_unwind(AssertUnwindSafe(|| { + loop { + // Stale read only affects granularity: `fetch_add` tiles `0..n` into disjoint claims. + let observed = pool.counter.load(Ordering::Relaxed); + if observed >= n { + break; + } + let batch = ((n - observed) / (NUM_THREADS * 2)).clamp(1, MAX_CLAIM_BATCH); + let start = pool.counter.fetch_add(batch, Ordering::Relaxed); + if start >= n { + break; + } + f(start, (start + batch).min(n)); + } + })); + IN_TASK.set(prev); + if let Err(payload) = result { + pool.panic.lock().unwrap().get_or_insert(payload); // keep the first + } +} + +/// Run `f(start, end)` over disjoint ranges tiling `0..n_tasks`, in parallel; a worker may get +/// several (guided self-scheduling, see [`drain`]). Blocks until done, the dispatcher acting as +/// worker 0. The base primitive — range-based so reductions amortize per-worker lookups. +pub fn for_each_chunk(n_tasks: usize, f: F) { + // Nesting would deadlock the dispatch lock — panic so it's caught, not silently serial. + assert!(!IN_TASK.get(), "nested parallel dispatch from within a pool task"); + + // Trivial sizes / single-core builds run inline. + if NUM_THREADS <= 1 || n_tasks <= 1 { + if n_tasks > 0 { + f(0, n_tasks); + } + return; + } + + let pool = pool(); + let _guard = pool.dispatch.lock().unwrap(); + + // SAFETY: erase the borrow to `'static` so it fits the `Job`. The dispatcher blocks on + // `working` before returning, so `f` outlives every deref. 
`transmute` (not a `*const dyn` + // cast) is required: a bare cast would default the trait object to `'static` and force + // `F: 'static` (E0310); the transmute reinterprets the same fat pointer without that bound. + let f_ref: &(dyn Fn(usize, usize) + Sync) = &f; + let f_erased: NonNull = unsafe { std::mem::transmute(NonNull::from(f_ref)) }; + + // SAFETY: sole writer — prior dispatch fully drained (`working == 0`), next not yet observed. + unsafe { *pool.job.get() = Some(Job { f: f_erased, n_tasks }) }; + pool.counter.store(0, Ordering::Relaxed); + pool.working.store(NUM_THREADS - 1, Ordering::Release); + pool.generation.fetch_add(1, Ordering::SeqCst); // publish; SeqCst guards the park protocol + + // Wake only parked workers; spinning ones see the bump for free. + for worker in &pool.workers[1..] { + if worker.parked.load(Ordering::SeqCst) + && let Some(t) = worker.handle.get() + { + t.unpark(); + } + } + + drain(pool); // dispatcher runs as worker 0 + while pool.working.load(Ordering::Acquire) != 0 { + std::hint::spin_loop(); // lock-free completion wait + } + + // Re-raise the first task panic (if any) after dropping `_guard`, so the lock releases + // cleanly (no poison) and the pool stays usable. + let panicked = pool.panic.lock().unwrap().take(); + drop(_guard); + if let Some(payload) = panicked { + resume_unwind(payload); + } +} + +/// `f(i)` for every `i` in `0..n_tasks`, in parallel. `#[inline]` folds the range→index adapter +/// into the monomorphized [`for_each_chunk`]. +#[inline] +pub fn for_each_index(n_tasks: usize, f: F) { + for_each_chunk(n_tasks, |start, end| { + for i in start..end { + f(i); + } + }); +} + +/// A base `*mut` shareable across workers. Sound only because callers partition the allocation +/// by task index (disjoint regions). Reuse this instead of redefining the pattern per crate. +#[derive(Debug)] +pub struct SendPtr(pub *mut T); +// SAFETY: accesses are partitioned by task index (see callers). 
+unsafe impl Send for SendPtr {} +unsafe impl Sync for SendPtr {} + +impl SendPtr { + /// Offset the base by `n` elements. + /// # Safety + /// `n` stays in the allocation; any write targets a slot no concurrent task touches. + #[inline] + pub unsafe fn add(&self, n: usize) -> *mut T { + unsafe { self.0.add(n) } + } + + /// Reconstruct the `len`-element slice at element offset `off`. + /// # Safety + /// `off`/`len` in-bounds and disjoint from every other concurrent task's slice. + #[inline] + pub unsafe fn slice<'a>(&self, off: usize, len: usize) -> &'a mut [T] { + unsafe { std::slice::from_raw_parts_mut(self.0.add(off), len) } + } +} + +/// Parallel `data.chunks_mut(chunk).enumerate().for_each(f)`; the final chunk may be shorter. +pub fn par_chunks_mut(data: &mut [T], chunk: usize, f: F) +where + F: Fn(usize, &mut [T]) + Sync, +{ + assert!(chunk > 0, "chunk size must be non-zero"); + let len = data.len(); + let base = SendPtr(data.as_mut_ptr()); + for_each_index(len.div_ceil(chunk), |i| { + let start = i * chunk; + // SAFETY: distinct `i` give disjoint in-bounds ranges; `data` stays borrowed. + let slice = unsafe { base.slice(start, chunk.min(len - start)) }; + f(i, slice); + }); +} + +/// Parallel `data.iter_mut().enumerate().for_each(f)`, chunked by [`recommended_chunk_size`]. +/// Hands the closure each element's **global** index. `#[inline]` to recover hand-written codegen. +#[inline] +pub fn par_for_each_mut(data: &mut [T], f: F) +where + F: Fn(usize, &mut T) + Sync, +{ + let chunk = recommended_chunk_size(data.len()); + par_chunks_mut(data, chunk, |ci, sub| { + for (k, slot) in sub.iter_mut().enumerate() { + f(ci * chunk + k, slot); + } + }); +} + +/// Parallel `(0..n_tasks).map(f).collect::>()`: runs `f(i)` across the pool and gathers +/// the results in index order, writing each straight into the output (no `Option` slots, one +/// allocation). Folds away the common "fill a `Vec>` in parallel, then unwrap" dance. 
+pub fn par_map_collect T + Sync>(n_tasks: usize, f: F) -> Vec { + let mut out: Vec = Vec::with_capacity(n_tasks); + let base = SendPtr(out.as_mut_ptr()); + for_each_index(n_tasks, |i| { + // SAFETY: distinct `i` write disjoint, in-bounds slots (each exactly once) and the + // dispatch blocks until all writes finish. If `f` panics, the unwind reaches this frame + // while `out.len()` is still 0, so the values written so far are never dropped or read + // (their resources leak; the buffer itself is freed) — safe, and the pool stays usable. + unsafe { base.add(i).write(f(i)) }; + }); + // SAFETY: every slot in `0..n_tasks` was initialized exactly once above. + unsafe { out.set_len(n_tasks) }; + out +} + +/// Give each worker its own persistent `Option` slot while it drains `0..n_tasks`: +/// `run(slot, start, end)` fires once per claimed batch with that worker's slot, so state +/// accumulates across its batches. Returns the slots (rest `None`) for the caller to combine. +/// The sole home of the cross-worker slot `unsafe`. +fn drain_into_slots(n_tasks: usize, run: impl Fn(&mut Option, usize, usize) + Sync) -> Vec> { + // `.max(1)`: slot 0 must exist even in a `NUM_THREADS == 0` build, where `for_each_chunk` + // runs the closure inline on the caller (worker id 0). + let mut slots: Vec> = (0..NUM_THREADS.max(1)).map(|_| None).collect(); + let ptr = SendPtr(slots.as_mut_ptr()); + for_each_chunk(n_tasks, |start, end| { + // SAFETY: `current_worker_id() < NUM_THREADS.max(1)` is unique per live worker → + // disjoint slots; `slots` outlives the dispatch. + let slot = unsafe { &mut *ptr.add(current_worker_id()) }; + run(slot, start, end); + }); + slots +} + +/// Parallel map-reduce over `0..n_tasks` = `(0..n).map(map).reduce(identity, reduce)`. Each +/// worker folds its claimed indices into one local partial; the partials combine on the +/// dispatcher. `reduce` must be associative AND commutative, with `identity()` a neutral +/// element: a worker's partial covers whichever batches it happened to claim (not one +/// contiguous index range) and the partials are combined in worker-id order, so operands are +/// not combined in index order (a stricter requirement than rayon's `reduce` contract). 
+pub fn map_reduce(n_tasks: usize, identity: ID, map: M, reduce: R) -> T +where + T: Send, + ID: Fn() -> T, + M: Fn(usize) -> T + Sync, + R: Fn(T, T) -> T + Sync, +{ + let slots = drain_into_slots(n_tasks, |slot, start, end| { + // Fold the batch into the worker's partial, seeded by the first `map` so `identity` + // stays off the per-element path; take/replace the shared slot just once. + *slot = (start..end).fold(slot.take(), |acc, i| { + Some(acc.map_or_else(|| map(i), |a| reduce(a, map(i)))) + }); + }); + // `identity()` seeds the combine as a no-op left-identity; the empty and single-thread + // (`for_each_chunk` runs inline) cases then fall out without a special path. + slots.into_iter().flatten().fold(identity(), &reduce) +} + +/// Parallel reduce where each worker keeps reusable scratch beside its accumulator (so the +/// per-task body needn't allocate). `(scratch, acc)` are created once per worker and threaded +/// through its batches; the `acc`s combine on the dispatcher. `combine` must be associative +/// and commutative, and `fold` must not rely on index order across batches: a worker's +/// accumulator sees whichever batches that worker claimed, and the accumulators are combined +/// in worker-id order rather than index order. `init_acc()` must be a neutral element: it is +/// called once per participating worker and once more to seed the final combine. +pub fn map_reduce_with_state(n_tasks: usize, init_state: IS, init_acc: IA, fold: F, combine: C) -> A +where + S: Send, + A: Send, + IS: Fn() -> S + Sync, + IA: Fn() -> A + Sync, + F: Fn(&mut S, &mut A, usize) + Sync, + C: Fn(A, A) -> A, +{ + let slots = drain_into_slots(n_tasks, |slot, start, end| { + let (state, acc) = slot.get_or_insert_with(|| (init_state(), init_acc())); + for i in start..end { + fold(state, acc, i); + } + }); + // `init_acc()` seeds the combine as a neutral element; the empty and single-thread cases + // (`for_each_chunk` runs inline) then fall out without a special path. 
+ slots + .into_iter() + .flatten() + .map(|(_, acc)| acc) + .fold(init_acc(), &combine) +} diff --git a/crates/backend/poly/Cargo.toml b/crates/backend/poly/Cargo.toml index dcdf80aed..f198a2d19 100644 --- a/crates/backend/poly/Cargo.toml +++ b/crates/backend/poly/Cargo.toml @@ -7,9 +7,9 @@ edition.workspace = true field = { path = "../field", package = "mt-field" } utils = { path = "../utils", package = "mt-utils" } system-info.workspace = true +parallel.workspace = true itertools.workspace = true -rayon.workspace = true rand.workspace = true serde.workspace = true diff --git a/crates/backend/poly/src/eq_mle.rs b/crates/backend/poly/src/eq_mle.rs index 64d3733f5..ca54a2e4d 100644 --- a/crates/backend/poly/src/eq_mle.rs +++ b/crates/backend/poly/src/eq_mle.rs @@ -2,13 +2,69 @@ use crate::*; use crate::{EFPacking, PF}; use ::utils::{iter_array_chunks_padded, log2_ceil_usize, log2_strict_usize}; use field::*; -use rayon::prelude::*; use system_info::NUM_THREADS; const LOG_NUM_THREADS: usize = log2_ceil_usize(NUM_THREADS); -const NUM_THREADS_PADDED: usize = 1 << LOG_NUM_THREADS; const LOG_BATCHED_TILE_SIZE: usize = 14; +/// log2 oversubscription for the eq_mle fan-out: emit `1 << (LOG_NUM_THREADS + this)` chunks so +/// the pool's task counter rebalances across heterogeneous cores (e.g. P/E). `0` = one chunk +/// per (padded) worker; `2` (4x) is a conservative default that balances well without +/// over-fragmenting. +const PARALLEL_LOG_OVERSUB: usize = 2; + +/// `(log2(n_chunks), n_chunks)` for the parallel fan-out. +#[inline] +fn parallel_split() -> (usize, usize) { + let log_chunks = LOG_NUM_THREADS + PARALLEL_LOG_OVERSUB; + (log_chunks, 1 << log_chunks) +} + +/// Parallel equivalent of +/// `out.par_chunks_exact_mut(chunk).zip(buf.par_iter()).for_each(|(c, a)| g(c, a))`, +/// dispatched through the in-house [`parallel`] pool: chunk `i` of `out` is paired with +/// `buf[i]`. `chunk` must divide `out.len()` exactly into `buf.len()` chunks (the eq_mle +/// fan-out always does); this is checked only by a `debug_assert_eq!`. 
+#[inline] +fn par_chunks_zip(out: &mut [T], chunk: usize, buf: &[A], g: G) +where + T: Send, + A: Sync, + G: Fn(&mut [T], &A) + Sync, +{ + debug_assert_eq!(out.len(), chunk * buf.len()); + parallel::par_chunks_mut(out, chunk, |i, c| g(c, &buf[i])); +} + +/// Shared parallel tail of the `compute_eval_eq*` family. With `eval` split into +/// `log_chunks` leading variables (handled one-per-chunk), `log_packing_width` trailing +/// variables already folded into `seed = buffer[0]`, and the middle variables left for +/// `kernel`, this builds the per-chunk equality buffer and runs +/// `kernel(middle, out_chunk, buffer_val)` over `out` in parallel. `kernel` fires once +/// per chunk (not per element), so threading it through a closure costs nothing. +#[inline] +fn par_eval_eq( + eval: &[In], + out: &mut [Out], + log_chunks: usize, + n_chunks: usize, + log_packing_width: usize, + seed: Buf, + kernel: impl Fn(&[In], &mut [Out], Buf) + Sync, +) where + In: Field, + Buf: Algebra + Copy + Send + Sync, + Out: Send, +{ + let mut buffer = Buf::zero_vec(n_chunks); + buffer[0] = seed; + fill_buffer(eval[..log_chunks].iter().rev(), &mut buffer); + + let out_chunk_size = out.len() / n_chunks; + let middle = &eval[log_chunks..(eval.len() - log_packing_width)]; + par_chunks_zip(out, out_chunk_size, &buffer, |out_chunk, buffer_val| { + kernel(middle, out_chunk, *buffer_val); + }); +} + /// Given `evals` = (α_1, ..., α_n), returns a multilinear polynomial P in n variables, /// defined on the boolean hypercube by: ∀ (x_1, ..., x_n) ∈ {0, 1}^n, /// P(x_1, ..., x_n) = Π_{i=1}^{n} (x_i.α_i + (1 - x_i).(1 - α_i)) @@ -87,62 +143,33 @@ where F: Field, EF: ExtensionField, { - // It's possible for this to be called with F = EF (Despite F actually being an extension field). - // - // IMPORTANT: We previously checked here that `packing_width > 1`, - // but this check is **not viable** for Goldilocks on Neon or when not using `target-cpu=native`. - // - // Why? 
Because Neon SIMD vectors are 128 bits and Goldilocks elements are already 64 bits, - // so no packing happens (width stays 1), and there's no performance advantage. - // - // Be careful: this means code relying on packing optimizations should **not assume** - // `packing_width > 1` is always true. + // `packing_width` may be 1 (e.g. Goldilocks on Neon, or without `target-cpu=native`), + // so nothing here may assume it is > 1. let log_packing_width = log2_strict_usize(F::Packing::WIDTH); - - // Ensure that the output buffer size is correct: - // It should be of size `2^n`, where `n` is the number of variables. debug_assert_eq!(out.len(), 1 << eval.len()); - // If the number of variables is small, there is no need to use - // parallelization or packings. - if eval.len() <= log_packing_width + 1 + LOG_NUM_THREADS { - // A basic recursive approach. + let (log_chunks, n_chunks) = parallel_split(); + if eval.len() <= log_packing_width + 1 + log_chunks { + // Too small to be worth packing/parallelizing. eval_eq_basic::<_, _, _, INITIALIZED>(eval, out, scalar); return; } - let eval_len_min_packing = eval.len() - log_packing_width; - - // We split eval into three parts: - // - eval[..LOG_NUM_THREADS] (the first LOG_NUM_THREADS elements) - // - eval[LOG_NUM_THREADS..eval_len_min_packing] (the middle elements) - // - eval[eval_len_min_packing..] (the last log_packing_width elements) - - // The middle elements are the ones which will be computed in parallel. - // The last log_packing_width elements are the ones which will be packed. - - // We make a buffer of elements of size `NUM_THREADS`. - let mut parallel_buffer = EF::ExtensionPacking::zero_vec(NUM_THREADS_PADDED); - let out_chunk_size = out.len() / NUM_THREADS_PADDED; - - // Compute the equality polynomial corresponding to the last log_packing_width elements - // and pack these. 
- parallel_buffer[0] = packed_eq_poly(&eval[eval_len_min_packing..], scalar); - - // Update the buffer so it contains the evaluations of the equality polynomial - // with respect to parts one and three. - fill_buffer(eval[..LOG_NUM_THREADS].iter().rev(), &mut parallel_buffer); - - // Finally do all computations involving the middle elements in parallel. - out.par_chunks_exact_mut(out_chunk_size) - .zip(parallel_buffer.par_iter()) - .for_each(|(out_chunk, buffer_val)| { - eval_eq_with_packed_scalar::<_, _, INITIALIZED>( - &eval[LOG_NUM_THREADS..(eval.len() - log_packing_width)], - out_chunk, - *buffer_val, - ); - }); + // Split `eval` into [leading `log_chunks` | middle | trailing `log_packing_width`]: the + // trailing vars fold into the per-chunk seed, the leading vars index the chunks, the + // middle runs in parallel. + let seed = packed_eq_poly(&eval[eval.len() - log_packing_width..], scalar); + par_eval_eq( + eval, + out, + log_chunks, + n_chunks, + log_packing_width, + seed, + |middle, out_chunk, buffer_val| { + eval_eq_with_packed_scalar::<_, _, INITIALIZED>(middle, out_chunk, buffer_val); + }, + ); } #[inline] @@ -150,16 +177,8 @@ pub fn compute_eval_eq_packed(eval: &[EF], out: &mu where EF: ExtensionField>, { - // It's possible for this to be called with F = EF (Despite F actually being an extension field). - // - // IMPORTANT: We previously checked here that `packing_width > 1`, - // but this check is **not viable** for Goldilocks on Neon or when not using `target-cpu=native`. - // - // Why? Because Neon SIMD vectors are 128 bits and Goldilocks elements are already 64 bits, - // so no packing happens (width stays 1), and there's no performance advantage. - // - // Be careful: this means code relying on packing optimizations should **not assume** - // `packing_width > 1` is always true. + // `packing_width` may be 1 (e.g. Goldilocks on Neon, or without `target-cpu=native`), + // so nothing here may assume it is > 1. 
let packing_width = packing_width::(); let log_packing_width = log2_strict_usize(packing_width); @@ -168,12 +187,13 @@ where // If the number of variables is small, there is no need to use // parallelization or packings. - if eval.len() <= log_packing_width + 1 + LOG_NUM_THREADS { - // A basic recursive approach. - let mut output_no_packing = EF::zero_vec(1 << eval.len()); - eval_eq_basic::<_, _, _, false>(eval, &mut output_no_packing, scalar); - out.par_iter_mut() - .zip(output_no_packing.par_chunks_exact(packing_width)) + let (log_chunks, n_chunks) = parallel_split(); + if eval.len() <= log_packing_width + 1 + log_chunks { + // Small case: evaluate unpacked, then pack lanes into `out`. + let mut unpacked = EF::zero_vec(1 << eval.len()); + eval_eq_basic::<_, _, _, false>(eval, &mut unpacked, scalar); + out.iter_mut() + .zip(unpacked.chunks_exact(packing_width)) .for_each(|(out_elem, chunk)| { if INITIALIZED { *out_elem += EF::ExtensionPacking::from_ext_slice(chunk); @@ -181,40 +201,22 @@ where *out_elem = EF::ExtensionPacking::from_ext_slice(chunk); } }); - } else { - let eval_len_min_packing = eval.len() - log_packing_width; - - // We split eval into three parts: - // - eval[..LOG_NUM_THREADS] (the first LOG_NUM_THREADS elements) - // - eval[LOG_NUM_THREADS..eval_len_min_packing] (the middle elements) - // - eval[eval_len_min_packing..] (the last log_packing_width elements) - - // The middle elements are the ones which will be computed in parallel. - // The last log_packing_width elements are the ones which will be packed. - - // We make a buffer of elements of size `NUM_THREADS`. - let mut parallel_buffer = EF::ExtensionPacking::zero_vec(NUM_THREADS_PADDED); - let out_chunk_size = out.len() / NUM_THREADS_PADDED; - - // Compute the equality polynomial corresponding to the last log_packing_width elements - // and pack these. 
- parallel_buffer[0] = packed_eq_poly(&eval[eval_len_min_packing..], scalar); - - // Update the buffer so it contains the evaluations of the equality polynomial - // with respect to parts one and three. - fill_buffer(eval[..LOG_NUM_THREADS].iter().rev(), &mut parallel_buffer); - - // Finally do all computations involving the middle elements in parallel. - out.par_chunks_exact_mut(out_chunk_size) - .zip(parallel_buffer.par_iter()) - .for_each(|(out_chunk, buffer_val)| { - eval_eq_with_packed_output::<_, _, INITIALIZED>( - &eval[LOG_NUM_THREADS..(eval.len() - log_packing_width)], - out_chunk, - *buffer_val, - ); - }); + return; } + + // See `compute_eval_eq` for the leading/middle/trailing split. + let seed = packed_eq_poly(&eval[eval.len() - log_packing_width..], scalar); + par_eval_eq( + eval, + out, + log_chunks, + n_chunks, + log_packing_width, + seed, + |middle, out_chunk, buffer_val| { + eval_eq_with_packed_output::<_, _, INITIALIZED>(middle, out_chunk, buffer_val); + }, + ); } /// Computes the equality polynomial evaluations efficiently. @@ -240,57 +242,30 @@ where F: Field, EF: ExtensionField, { - // we assume that packing_width is a power of 2. let log_packing_width = log2_strict_usize(F::Packing::WIDTH); - - // Ensure that the output buffer size is correct: - // It should be of size `2^n`, where `n` is the number of variables. debug_assert_eq!(out.len(), 1 << eval.len()); - // If the number of variables is small, there is no need to use - // parallelization or packings. - if eval.len() <= log_packing_width + 1 + LOG_NUM_THREADS { - // A basic recursive approach. 
+ let (log_chunks, n_chunks) = parallel_split(); + if eval.len() <= log_packing_width + 1 + log_chunks { eval_eq_basic::<_, _, _, INITIALIZED>(eval, out, scalar); return; } - let eval_len_min_packing = eval.len() - log_packing_width; - - // We split eval into three parts: - // - eval[..LOG_NUM_THREADS] (the first LOG_NUM_THREADS elements) - // - eval[LOG_NUM_THREADS..eval_len_min_packing] (the middle elements) - // - eval[eval_len_min_packing..] (the last log_packing_width elements) - - // The middle elements are the ones which will be computed in parallel. - // The last log_packing_width elements are the ones which will be packed. - - // We make a buffer of PackedField elements of size `NUM_THREADS`. - // Note that this is a slightly different strategy to `eval_eq` which instead - // uses PackedExtensionField elements. Whilst this involves slightly more mathematical - // operations, it seems to be faster in practice due to less data moving around. - let mut parallel_buffer = F::Packing::zero_vec(NUM_THREADS_PADDED); - let out_chunk_size = out.len() / NUM_THREADS_PADDED; - - // Compute the equality polynomial corresponding to the last log_packing_width elements - // and pack these. - parallel_buffer[0] = packed_eq_poly(&eval[eval_len_min_packing..], F::ONE); - - // Update the buffer so it contains the evaluations of the equality polynomial - // with respect to parts one and three. - fill_buffer(eval[..LOG_NUM_THREADS].iter().rev(), &mut parallel_buffer); - - // Finally do all computations involving the middle elements in parallel. 
- out.par_chunks_exact_mut(out_chunk_size) - .zip(parallel_buffer.par_iter()) - .for_each(|(out_chunk, buffer_val)| { - base_eval_eq_packed::<_, _, INITIALIZED>( - &eval[LOG_NUM_THREADS..(eval.len() - log_packing_width)], - out_chunk, - *buffer_val, - scalar, - ); - }); + // Base-field input: seed the per-chunk buffer with `F::Packing` (not `EF::ExtensionPacking`) + // and apply `scalar` inside the kernel — slightly more ops but less data movement, which is + // faster here in practice. See `compute_eval_eq` for the leading/middle/trailing split. + let seed = packed_eq_poly(&eval[eval.len() - log_packing_width..], F::ONE); + par_eval_eq( + eval, + out, + log_chunks, + n_chunks, + log_packing_width, + seed, + |middle, out_chunk, buffer_val| { + base_eval_eq_packed::<_, _, INITIALIZED>(middle, out_chunk, buffer_val, scalar); + }, + ); } #[inline] @@ -302,24 +277,18 @@ pub fn compute_eval_eq_base_packed( F: Field, EF: ExtensionField, { - // we assume that packing_width is a power of 2. let packing_width = F::Packing::WIDTH; let log_packing_width = log2_strict_usize(packing_width); assert!(log_packing_width <= eval.len()); assert_eq!(out.len(), 1 << (eval.len() - log_packing_width)); - // Ensure that the output buffer size is correct: - // It should be of size `2^n`, where `n` is the number of variables. - debug_assert_eq!(out.len(), 1 << (eval.len() - log_packing_width)); - - // If the number of variables is small, there is no need to use - // parallelization or packings. - if eval.len() <= log_packing_width + 1 + LOG_NUM_THREADS { - // A basic recursive approach. - let mut output_no_packing = EF::zero_vec(1 << eval.len()); - eval_eq_basic::<_, _, _, false>(eval, &mut output_no_packing, scalar); - out.par_iter_mut() - .zip(output_no_packing.par_chunks_exact(packing_width)) + let (log_chunks, n_chunks) = parallel_split(); + if eval.len() <= log_packing_width + 1 + log_chunks { + // Small case: evaluate unpacked, then pack lanes into `out`. 
+ let mut unpacked = EF::zero_vec(1 << eval.len()); + eval_eq_basic::<_, _, _, false>(eval, &mut unpacked, scalar); + out.iter_mut() + .zip(unpacked.chunks_exact(packing_width)) .for_each(|(out_elem, chunk)| { if INITIALIZED { *out_elem += EF::ExtensionPacking::from_ext_slice(chunk); @@ -327,45 +296,24 @@ pub fn compute_eval_eq_base_packed( *out_elem = EF::ExtensionPacking::from_ext_slice(chunk); } }); - } else { - let eval_len_min_packing = eval.len() - log_packing_width; - - // We split eval into three parts: - // - eval[..LOG_NUM_THREADS] (the first LOG_NUM_THREADS elements) - // - eval[LOG_NUM_THREADS..eval_len_min_packing] (the middle elements) - // - eval[eval_len_min_packing..] (the last log_packing_width elements) - - // The middle elements are the ones which will be computed in parallel. - // The last log_packing_width elements are the ones which will be packed. - - // We make a buffer of PackedField elements of size `NUM_THREADS`. - // Note that this is a slightly different strategy to `eval_eq` which instead - // uses PackedExtensionField elements. Whilst this involves slightly more mathematical - // operations, it seems to be faster in practice due to less data moving around. - let mut parallel_buffer = F::Packing::zero_vec(NUM_THREADS_PADDED); - let out_chunk_size = out.len() / NUM_THREADS_PADDED; - - // Compute the equality polynomial corresponding to the last log_packing_width elements - // and pack these. - parallel_buffer[0] = packed_eq_poly(&eval[eval_len_min_packing..], F::ONE); - - // Update the buffer so it contains the evaluations of the equality polynomial - // with respect to parts one and three. - fill_buffer(eval[..LOG_NUM_THREADS].iter().rev(), &mut parallel_buffer); - - // Finally do all computations involving the middle elements in parallel. 
- let scalar_packed = EF::ExtensionPacking::from(scalar); - out.par_chunks_exact_mut(out_chunk_size) - .zip(parallel_buffer.par_iter()) - .for_each(|(out_chunk, buffer_val)| { - base_eval_eq_packed_with_packed_output::( - &eval[LOG_NUM_THREADS..(eval.len() - log_packing_width)], - out_chunk, - *buffer_val, - scalar_packed, - ); - }); + return; } + + // Base-field input: seed with `F::Packing` and apply `scalar` in the kernel (less data + // movement — see `compute_eval_eq_base`). See `compute_eval_eq` for the split. + let scalar_packed = EF::ExtensionPacking::from(scalar); + let seed = packed_eq_poly(&eval[eval.len() - log_packing_width..], F::ONE); + par_eval_eq( + eval, + out, + log_chunks, + n_chunks, + log_packing_width, + seed, + |middle, out_chunk, buffer_val| { + base_eval_eq_packed_with_packed_output::(middle, out_chunk, buffer_val, scalar_packed); + }, + ); } #[inline] @@ -412,21 +360,21 @@ pub fn compute_eval_eq_base_packed_batched( }) .collect(); - out.par_chunks_exact_mut(tile_packed_size) - .enumerate() - .for_each(|(tile_idx, out_tile)| { - for (eq_prefix, middle, eq_suffix) in &per_query { - // Here e could precompute the eq poly, trading some memory for less computation - // (2x faster on M4 max, but 2x slower on machines with smaller caches. - // TODO implement both and choose based on cache size?) - base_eval_eq_packed_with_packed_output::( - middle, - out_tile, - *eq_suffix, - EF::ExtensionPacking::from(eq_prefix[tile_idx]), - ); - } - }); + // `out` already splits into `2^n_prefix_levels` tiles — many more than there are + // workers — so the pool's task counter load-balances these directly. + parallel::par_chunks_mut(out, tile_packed_size, |tile_idx, out_tile| { + for (eq_prefix, middle, eq_suffix) in &per_query { + // Here e could precompute the eq poly, trading some memory for less computation + // (2x faster on M4 max, but 2x slower on machines with smaller caches. + // TODO implement both and choose based on cache size?) 
+ base_eval_eq_packed_with_packed_output::( + middle, + out_tile, + *eq_suffix, + EF::ExtensionPacking::from(eq_prefix[tile_idx]), + ); + } + }); } /// Fills the `buffer` with evaluations of the equality polynomial @@ -944,39 +892,40 @@ pub fn compute_eval_eq_packed_dual( assert!(log_packing_width <= eval_a.len()); assert_eq!(out.len(), 1 << (eval_a.len() - log_packing_width)); - if eval_a.len() <= log_packing_width + 1 + LOG_NUM_THREADS { + let (log_chunks, n_chunks) = parallel_split(); + if eval_a.len() <= log_packing_width + 1 + log_chunks { let mut output_no_packing = EF::zero_vec(1 << eval_a.len()); eval_eq_basic::<_, _, _, false>(eval_a, &mut output_no_packing, scalar_a); eval_eq_basic::<_, _, _, true>(eval_b, &mut output_no_packing, scalar_b); - out.par_iter_mut() - .zip(output_no_packing.par_chunks_exact(packing_width)) + out.iter_mut() + .zip(output_no_packing.chunks_exact(packing_width)) .for_each(|(out_elem, chunk)| { *out_elem = EF::ExtensionPacking::from_ext_slice(chunk); }); } else { let eval_len_min_packing = eval_a.len() - log_packing_width; - let mut parallel_buffer_a = EF::ExtensionPacking::zero_vec(NUM_THREADS_PADDED); - let mut parallel_buffer_b = EF::ExtensionPacking::zero_vec(NUM_THREADS_PADDED); - let out_chunk_size = out.len() / NUM_THREADS_PADDED; + let mut parallel_buffer_a = EF::ExtensionPacking::zero_vec(n_chunks); + let mut parallel_buffer_b = EF::ExtensionPacking::zero_vec(n_chunks); + let out_chunk_size = out.len() / n_chunks; parallel_buffer_a[0] = packed_eq_poly(&eval_a[eval_len_min_packing..], scalar_a); - fill_buffer(eval_a[..LOG_NUM_THREADS].iter().rev(), &mut parallel_buffer_a); + fill_buffer(eval_a[..log_chunks].iter().rev(), &mut parallel_buffer_a); parallel_buffer_b[0] = packed_eq_poly(&eval_b[eval_len_min_packing..], scalar_b); - fill_buffer(eval_b[..LOG_NUM_THREADS].iter().rev(), &mut parallel_buffer_b); - - out.par_chunks_exact_mut(out_chunk_size) - .enumerate() - .for_each(|(i, out_chunk)| { - 
eval_eq_with_packed_output_dual::, EF>( - &eval_a[LOG_NUM_THREADS..eval_len_min_packing], - &eval_b[LOG_NUM_THREADS..eval_len_min_packing], - out_chunk, - parallel_buffer_a[i], - parallel_buffer_b[i], - ); - }); + fill_buffer(eval_b[..log_chunks].iter().rev(), &mut parallel_buffer_b); + + let middle_a = &eval_a[log_chunks..eval_len_min_packing]; + let middle_b = &eval_b[log_chunks..eval_len_min_packing]; + parallel::par_chunks_mut(out, out_chunk_size, |i, out_chunk| { + eval_eq_with_packed_output_dual::, EF>( + middle_a, + middle_b, + out_chunk, + parallel_buffer_a[i], + parallel_buffer_b[i], + ); + }); } } @@ -1312,7 +1261,7 @@ mod tests { let time = Instant::now(); compute_eval_eq::(&eval, &mut out_3, scalar); let out_3_packed = out_3 - .par_chunks_exact(packing_width) + .chunks_exact(packing_width) .map(>::ExtensionPacking::from_ext_slice) .collect::>(); println!("EXTENSION PACKED AFTER: {:?}", time.elapsed()); @@ -1347,7 +1296,7 @@ mod tests { let time = Instant::now(); compute_eval_eq_base::(&eval, &mut out_3, scalar); let out_3_packed = out_3 - .par_chunks_exact(packing_width) + .chunks_exact(packing_width) .map(>::ExtensionPacking::from_ext_slice) .collect::>(); println!("BASE PACKED AFTER: {:?}", time.elapsed()); diff --git a/crates/backend/poly/src/evals.rs b/crates/backend/poly/src/evals.rs index 7e0e07b4f..b11f2ff81 100644 --- a/crates/backend/poly/src/evals.rs +++ b/crates/backend/poly/src/evals.rs @@ -1,8 +1,8 @@ use crate::*; use crate::{EFPacking, PF}; +use ::utils::log2_ceil_usize; use field::{ExtensionField, Field, PrimeCharacteristicRing}; use itertools::Itertools; -use rayon::{join, prelude::*}; use std::borrow::Borrow; pub trait EvaluationsList { @@ -87,7 +87,11 @@ pub fn scale_poly>(poly: &[F], factor: EF) -> Ve if poly.len() < PARALLEL_THRESHOLD { poly.iter().map(|&e| factor * e).collect() } else { - poly.par_iter().map(|&e| factor * e).collect() + let mut out: Vec = unsafe { uninitialized_vec(poly.len()) }; + parallel::par_for_each_mut(&mut 
out, |i, o| { + *o = factor * poly[i]; + }); + out } } @@ -257,20 +261,23 @@ where // // This chain of operations computes the regrouped sum: // Σ_{v_high} eq(v_high, p_high) * (Σ_{v_low} f(v_high, v_low) * eq(v_low, p_low)) - evals - .par_chunks(left.len()) - .zip_eq(right.par_iter()) - .map(|(part, &c)| { + let left_len = left.len(); + parallel::map_reduce( + right.len(), + || Res::ZERO, + |i| { + let part = &evals[i * left_len..][..left_len]; // This is the inner sum: a dot product between the evaluation chunk and the `left` basis values. mul_res_point( part.iter() .zip_eq(left.iter()) .map(|(&a, &b)| mul_coeffs_point(a, b)) .sum::(), - c, + right[i], ) - }) - .sum() + }, + |a, b| a + b, + ) } else { evals .chunks(left.len()) @@ -290,62 +297,76 @@ where } else { // For moderately sized inputs (5 to 19 variables), use the recursive strategy. // - // Split the evaluations into two halves, corresponding to the first variable being 0 or 1. - let (f0, f1) = evals.split_at(evals.len() / 2); - - // Recursively evaluate on the two smaller hypercubes. - let (f0_eval, f1_eval) = { - // Only spawn parallel tasks if the subproblem is large enough to overcome - // the overhead of threading. - let work_size: usize = (1 << 15) / std::mem::size_of::(); - if evals.len() > work_size && PARALLEL { - join( - || { - eval_multilinear_generic::<_, _, _, _, _, _, PARALLEL>( - f0, - tail, - mul_coeffs_point, - add_res_coeffs, - mul_res_point, - ) - }, - || { - eval_multilinear_generic::<_, _, _, _, _, _, PARALLEL>( - f1, - tail, - mul_coeffs_point, - add_res_coeffs, - mul_res_point, - ) - }, + // Only spawn parallel tasks if the subproblem is large enough to overcome + // the overhead of threading. 
+ let work_size: usize = (1 << 15) / std::mem::size_of::(); + if evals.len() > work_size && PARALLEL { + // Flat fan-out: peel the `n_split` leading variables into `2^n_split` + // independent subproblems, evaluate each over the remaining coordinates + // sequentially across the pool, then interpolate the partial results over + // the leading coordinates. Equivalent to the recursive `join` split, but + // flat so the in-house pool can parallelize it (nested dispatches fall + // back to sequential, so a recursive split would lose all parallelism). + let log_work = log2_ceil_usize(work_size.max(2)); + let n_split = point.len().saturating_sub(log_work).max(1); + let (lead, sub_point) = point.split_at(n_split); + let n_chunks = 1 << n_split; + let chunk = evals.len() >> n_split; + let partials = parallel::par_map_collect(n_chunks, |j| { + eval_multilinear_generic::<_, _, _, _, _, _, false>( + &evals[j * chunk..][..chunk], + sub_point, + mul_coeffs_point, + add_res_coeffs, + mul_res_point, ) - } else { - // For smaller subproblems, execute sequentially. - ( - eval_multilinear_generic::<_, _, _, _, _, _, false>( - f0, - tail, - mul_coeffs_point, - add_res_coeffs, - mul_res_point, - ), - eval_multilinear_generic::<_, _, _, _, _, _, false>( - f1, - tail, - mul_coeffs_point, - add_res_coeffs, - mul_res_point, - ), - ) - } - }; - // Perform the final linear interpolation for the first variable `x`. - f0_eval + mul_res_point(f1_eval - f0_eval, *x) + }); + interpolate_res(&partials, lead, mul_res_point) + } else { + let (f0, f1) = evals.split_at(evals.len() / 2); + let f0_eval = eval_multilinear_generic::<_, _, _, _, _, _, false>( + f0, + tail, + mul_coeffs_point, + add_res_coeffs, + mul_res_point, + ); + let f1_eval = eval_multilinear_generic::<_, _, _, _, _, _, false>( + f1, + tail, + mul_coeffs_point, + add_res_coeffs, + mul_res_point, + ); + // Perform the final linear interpolation for the first variable `x`. 
+ f0_eval + mul_res_point(f1_eval - f0_eval, *x) + } } } } } +/// Multilinear interpolation of `values` (the `2^point.len()` hypercube evaluations of a +/// function, indexed lexicographically) at `point`, using only `Res` arithmetic and the +/// `mul_res_point` scaling. Used to recombine the partial results of the flat parallel +/// fan-out in [`eval_multilinear_generic`]. +fn interpolate_res(values: &[Res], point: &[Point], mul_res_point: &MRP) -> Res +where + Point: Field, + Res: Copy + PrimeCharacteristicRing, + MRP: Fn(Res, Point) -> Res, +{ + match point { + [] => values[0], + [x, tail @ ..] => { + let (low, high) = values.split_at(values.len() / 2); + let p0 = interpolate_res(low, tail, mul_res_point); + let p1 = interpolate_res(high, tail, mul_res_point); + p0 + mul_res_point(p1 - p0, *x) + } + } +} + #[cfg(test)] mod tests { use std::time::Instant; diff --git a/crates/backend/poly/src/mle/mle_single_ref.rs b/crates/backend/poly/src/mle/mle_single_ref.rs index 61d607d76..269884fdf 100644 --- a/crates/backend/poly/src/mle/mle_single_ref.rs +++ b/crates/backend/poly/src/mle/mle_single_ref.rs @@ -119,13 +119,15 @@ impl<'a, EF: ExtensionField>> MleRef<'a, EF> { pub fn fold(&self, alpha: EF) -> MleOwned { match self { - Self::Base(pols) => MleOwned::Extension(fold_multilinear(pols, alpha, &|a, b| b * a)), - Self::Extension(pols) => MleOwned::Extension(fold_multilinear(pols, alpha, &|a, b| b * a)), + Self::Base(pols) => MleOwned::Extension(fold_multilinear(pols, alpha, &|a, b| b * a, false)), + Self::Extension(pols) => MleOwned::Extension(fold_multilinear(pols, alpha, &|a, b| b * a, false)), Self::BasePacked(pols) => { let alpha_packed = EFPacking::::from(alpha); - MleOwned::ExtensionPacked(fold_multilinear(pols, alpha_packed, &|a, b| b * a)) + MleOwned::ExtensionPacked(fold_multilinear(pols, alpha_packed, &|a, b| b * a, false)) + } + Self::ExtensionPacked(pols) => { + MleOwned::ExtensionPacked(fold_multilinear(pols, alpha, &|a, b| a * b, false)) } - 
Self::ExtensionPacked(pols) => MleOwned::ExtensionPacked(fold_multilinear(pols, alpha, &|a, b| a * b)), } } } diff --git a/crates/backend/poly/src/utils.rs b/crates/backend/poly/src/utils.rs index 5bb5fb1b4..3f531a19c 100644 --- a/crates/backend/poly/src/utils.rs +++ b/crates/backend/poly/src/utils.rs @@ -1,14 +1,9 @@ use std::{ mem::ManuallyDrop, - ops::{Add, Range, Sub}, + ops::{Add, Sub}, }; use field::*; -use rayon::{ - iter::Zip, - prelude::*, - slice::{Iter, IterMut}, -}; use crate::{EFPacking, PF, PFPacking}; @@ -26,9 +21,9 @@ pub fn pack_extension>>(slice: &[EF]) -> Vec>>(vec: &[EFPacking]) -> Ve write(chunk, x); } } else { - out.par_chunks_exact_mut(width) - .zip(vec.par_iter()) - .for_each(|(chunk, x)| write(chunk, x)); + // One pool task per group of `group` packed elements, each writing `group * width` + // contiguous output scalars from a disjoint slice of `vec`. + let group = parallel::recommended_chunk_size(vec.len()); + parallel::par_chunks_mut(&mut out, group * width, |ci, out_chunk| { + for (k, sub) in out_chunk.chunks_exact_mut(width).enumerate() { + write(sub, &vec[ci * group + k]); + } + }); } out } @@ -67,31 +67,24 @@ pub const fn must_unpack_multilinears(n_vars: usize) -> bool { n_vars <= 1 + packing_log_width::() } -pub fn batch_fold_multilinears< - EF: PrimeCharacteristicRing + Copy + Send + Sync, - IF: Copy + Sub + Send + Sync, - OF: Copy + Add + Send + Sync, - F: Fn(IF, EF) -> OF + Sync + Send, ->( - polys: &[&[IF]], - alpha: EF, - mul_if_of: F, -) -> Vec> { - let total_size: usize = polys.iter().map(|p| p.len()).sum(); - if total_size < PARALLEL_THRESHOLD { - polys - .iter() - .map(|poly| fold_multilinear(poly, alpha, &mul_if_of)) - .collect() +/// Fill `len` output slots with `compute(i)`, parallelizing via the pool when the work is +/// large enough. 
`seq` forces the sequential path: the batched wrappers below dispatch one +/// pool task per poly, so their inner fold must not nest a parallel dispatch (which would +/// panic in [`parallel`]). +#[inline] +fn fold_fill OF + Sync>(len: usize, seq: bool, compute: C) -> Vec { + let mut res = unsafe { uninitialized_vec(len) }; + if seq || len < PARALLEL_THRESHOLD { + for (i, r) in res.iter_mut().enumerate() { + *r = compute(i); + } } else { - polys - .par_iter() - .map(|poly| fold_multilinear(poly, alpha, &mul_if_of)) - .collect() + parallel::par_for_each_mut(&mut res, |i, r| *r = compute(i)); } + res } -pub fn fold_multilinear_lsb< +fn fold_multilinear_lsb< EF: PrimeCharacteristicRing + Copy + Send + Sync, IF: Copy + Sub + Send + Sync, OF: Copy + Add + Send + Sync, @@ -100,20 +93,14 @@ pub fn fold_multilinear_lsb< m: &[IF], alpha: EF, mul_if_of: &Mul, + seq: bool, ) -> Vec { - let new_size = m.len() / 2; - let mut res = unsafe { uninitialized_vec(new_size) }; - let compute = |(c, r_v): (&[IF], &mut OF)| { - *r_v = mul_if_of(c[1] - c[0], alpha) + c[0]; - }; - if new_size < PARALLEL_THRESHOLD { - m.chunks_exact(2).zip(res.iter_mut()).for_each(compute); - } else { - m.par_chunks_exact(2).zip(res.par_iter_mut()).for_each(compute); - } - res + fold_fill(m.len() / 2, seq, |j| { + mul_if_of(m[2 * j + 1] - m[2 * j], alpha) + m[2 * j] + }) } +/// Fold `m` at variable `bit`. `seq` forces sequential execution (see [`fold_fill`]). 
pub fn fold_multilinear_at_bit< EF: PrimeCharacteristicRing + Copy + Send + Sync, IF: Copy + Sub + Send + Sync, @@ -124,40 +111,24 @@ pub fn fold_multilinear_at_bit< alpha: EF, bit: usize, mul_if_of: &Mul, + seq: bool, ) -> Vec { - let new_size = m.len() / 2; assert!(m.len() >= 2 * (1 << bit), "bit out of range for slice length"); - if bit == 0 { - return fold_multilinear_lsb(m, alpha, mul_if_of); + return fold_multilinear_lsb(m, alpha, mul_if_of, seq); } - let stride = 1usize << bit; let lo_mask = stride - 1; - let mut res = unsafe { uninitialized_vec(new_size) }; - - let compute = |new_j: usize| { + fold_fill(m.len() / 2, seq, |new_j| { let i_hi = new_j >> bit; let i_lo = new_j & lo_mask; let i0 = (i_hi << (bit + 1)) | i_lo; let i1 = i0 | stride; mul_if_of(m[i1] - m[i0], alpha) + m[i0] - }; - - if new_size < PARALLEL_THRESHOLD { - for (new_j, res_v) in res.iter_mut().enumerate() { - *res_v = compute(new_j); - } - } else { - (0..new_size) - .into_par_iter() - .with_min_len(PARALLEL_THRESHOLD) - .map(compute) - .collect_into_vec(&mut res); - } - res + }) } +/// Fold `m` at its top variable. `seq` forces sequential execution (see [`fold_fill`]). 
pub fn fold_multilinear< EF: PrimeCharacteristicRing + Copy + Send + Sync, IF: Copy + Sub + Send + Sync, @@ -167,22 +138,31 @@ pub fn fold_multilinear< m: &[IF], alpha: EF, mul_if_of: &F, + seq: bool, ) -> Vec { let new_size = m.len() / 2; - let mut res = unsafe { uninitialized_vec(new_size) }; + fold_fill(new_size, seq, |i| mul_if_of(m[i + new_size] - m[i], alpha) + m[i]) +} - if new_size < PARALLEL_THRESHOLD { - for i in 0..new_size { - res[i] = mul_if_of(m[i + new_size] - m[i], alpha) + m[i]; - } +pub fn batch_fold_multilinears< + EF: PrimeCharacteristicRing + Copy + Send + Sync, + IF: Copy + Sub + Send + Sync, + OF: Copy + Add + Send + Sync, + F: Fn(IF, EF) -> OF + Sync + Send, +>( + polys: &[&[IF]], + alpha: EF, + mul_if_of: F, +) -> Vec> { + let total_size: usize = polys.iter().map(|p| p.len()).sum(); + if total_size < PARALLEL_THRESHOLD { + polys + .iter() + .map(|poly| fold_multilinear(poly, alpha, &mul_if_of, true)) + .collect() } else { - (0..new_size) - .into_par_iter() - .with_min_len(PARALLEL_THRESHOLD) - .map(|i| mul_if_of(m[i + new_size] - m[i], alpha) + m[i]) - .collect_into_vec(&mut res); + parallel::par_map_collect(polys.len(), |i| fold_multilinear(polys[i], alpha, &mul_if_of, true)) } - res } pub fn batch_fold_multilinears_at_bit< @@ -196,17 +176,17 @@ pub fn batch_fold_multilinears_at_bit< bit: usize, mul_if_of: F, ) -> Vec> { + // See `batch_fold_multilinears`: one task per poly, inner fold forced sequential. 
let total_size: usize = polys.iter().map(|p| p.len()).sum(); if total_size < PARALLEL_THRESHOLD { polys .iter() - .map(|poly| fold_multilinear_at_bit(poly, alpha, bit, &mul_if_of)) + .map(|poly| fold_multilinear_at_bit(poly, alpha, bit, &mul_if_of, true)) .collect() } else { - polys - .par_iter() - .map(|poly| fold_multilinear_at_bit(poly, alpha, bit, &mul_if_of)) - .collect() + parallel::par_map_collect(polys.len(), |i| { + fold_multilinear_at_bit(polys[i], alpha, bit, &mul_if_of, true) + }) } } @@ -281,54 +261,6 @@ pub fn split_at_mut_many<'a, A>(slice: &'a mut [A], indices: &[usize]) -> Vec<&' result } -// Parallel - -#[allow(clippy::type_complexity)] -pub fn par_iter_split_4<'a, A: Sync + Send>( - u: &'a [A], -) -> Zip, Iter<'a, A>>, Zip, Iter<'a, A>>> { - let n = u.len(); - assert!(n.is_multiple_of(4)); - let [u_ll, u_lr, u_rl, u_rr] = split_at_many(u, &[n / 4, n / 2, 3 * n / 4]).try_into().ok().unwrap(); - (u_ll.par_iter().zip(u_lr)).zip(u_rl.par_iter().zip(u_rr.par_iter())) -} - -pub fn par_iter_split_2<'a, A: Sync + Send>(u: &'a [A]) -> Zip, Iter<'a, A>> { - par_iter_split_2_capped(u, 0..u.len() / 2) -} - -pub fn par_iter_split_2_capped<'a, A: Sync + Send>(u: &'a [A], range: Range) -> Zip, Iter<'a, A>> { - let n = u.len(); - assert!(n.is_multiple_of(2)); - let (u_left, u_right) = u.split_at(n / 2); - u_left[range.clone()].par_iter().zip(u_right[range.clone()].par_iter()) -} - -pub fn par_iter_mut_split_2<'a, A: Sync + Send>(u: &'a mut [A]) -> Zip, IterMut<'a, A>> { - par_iter_mut_split_2_capped(u, 0..u.len() / 2) -} - -pub fn par_iter_mut_split_2_capped<'a, A: Sync + Send>( - u: &'a mut [A], - range: Range, -) -> Zip, IterMut<'a, A>> { - let n = u.len(); - assert!(n.is_multiple_of(2)); - let (u_left, u_right) = u.split_at_mut(n / 2); - u_left[range.clone()].par_iter_mut().zip(u_right[range].par_iter_mut()) -} - -#[allow(clippy::type_complexity)] -pub fn par_zip_fold_2<'a, 'b, A: Sync + Send, B: Sync + Send>( - u: &'a [A], - folded: &'b mut [B], -) -> Zip, 
Iter<'a, A>>, Zip, Iter<'a, A>>>, Zip, IterMut<'b, B>>> { - let n = u.len(); - assert!(n.is_multiple_of(4)); - assert_eq!(folded.len(), n / 2); - par_iter_split_4(u).zip(par_iter_mut_split_2(folded)) -} - // Sequential pub fn iter_split_2(u: &[A]) -> impl Iterator { diff --git a/crates/backend/src/lib.rs b/crates/backend/src/lib.rs index cbd44fb2b..f4cc2d18f 100644 --- a/crates/backend/src/lib.rs +++ b/crates/backend/src/lib.rs @@ -2,9 +2,8 @@ pub use air::*; pub use fiat_shamir::*; pub use field::*; pub use koala_bear::*; +pub use parallel; pub use poly::*; -pub use rayon; -pub use rayon::prelude::*; pub use sumcheck::*; pub use symetric::*; pub use utils::*; diff --git a/crates/backend/sumcheck/Cargo.toml b/crates/backend/sumcheck/Cargo.toml index 91085f352..1d5f486ca 100644 --- a/crates/backend/sumcheck/Cargo.toml +++ b/crates/backend/sumcheck/Cargo.toml @@ -8,8 +8,8 @@ field = { path = "../field", package = "mt-field" } air = { path = "../air", package = "mt-air" } poly = { path = "../poly", package = "mt-poly" } fiat-shamir = { path = "../fiat-shamir", package = "mt-fiat-shamir" } +parallel.workspace = true tracing.workspace = true -rayon.workspace = true [dev-dependencies] koala-bear = { path = "../koala-bear", package = "mt-koala-bear" } diff --git a/crates/backend/sumcheck/src/product_computation.rs b/crates/backend/sumcheck/src/product_computation.rs index 12d8882c5..f0e46e3f4 100644 --- a/crates/backend/sumcheck/src/product_computation.rs +++ b/crates/backend/sumcheck/src/product_computation.rs @@ -1,7 +1,6 @@ use fiat_shamir::*; use field::*; use poly::*; -use rayon::prelude::*; use tracing::instrument; use crate::{SumcheckComputation, sumcheck_prove_many_rounds}; @@ -146,15 +145,21 @@ pub fn compute_product_sumcheck_polynomial< (a0 + b0, a2 + b2) }) } else { - pol_0[..n / 2] - .par_iter() - .zip(pol_0[n / 2..].par_iter()) - .zip(pol_1[..n / 2].par_iter().zip(pol_1[n / 2..].par_iter())) - .map(sumcheck_quadratic) - .reduce( - || (EFPacking::ZERO, 
EFPacking::ZERO), - |(a0, a2), (b0, b2)| (a0 + b0, a2 + b2), - ) + // Per-worker in-place accumulation: each worker folds the contiguous range it + // claims straight into its own `(c0, c2)` accumulator (no per-chunk tuple to build + // and reduce, worker-slot lookup amortized once per batch by `for_each_chunk`). + let half = n / 2; + parallel::map_reduce_with_state( + half, + || (), + || (EFPacking::ZERO, EFPacking::ZERO), + |(), acc, i| { + let (b0, b2) = sumcheck_quadratic(((&pol_0[i], &pol_0[half + i]), (&pol_1[i], &pol_1[half + i]))); + acc.0 += b0; + acc.1 += b2; + }, + |(a0, a2), (b0, b2)| (a0 + b0, a2 + b2), + ) }; let c0 = decompose(c0_packed).into_iter().sum::(); @@ -212,13 +217,41 @@ pub fn fold_and_compute_product_sumcheck_polynomial< (a0 + b0, a2 + b2) }) } else { - par_zip_fold_2(pol_0, &mut pol_0_folded) - .zip(par_zip_fold_2(pol_1, &mut pol_1_folded)) - .map(|(p0, p1)| process_element(p0, p1)) - .reduce( - || (EFPacking::ZERO, EFPacking::ZERO), - |(a0, a2), (b0, b2)| (a0 + b0, a2 + b2), - ) + // Fused single pass with per-worker in-place accumulation: fold both polynomials + // (writing the disjoint `i` / `quarter + i` output slots) and accumulate the + // per-index quadratic straight into the worker's `(c0, c2)` — no per-chunk tuple. 
+ let quarter = n / 4; + let p0f = parallel::SendPtr(pol_0_folded.as_mut_ptr()); + let p1f = parallel::SendPtr(pol_1_folded.as_mut_ptr()); + parallel::map_reduce_with_state( + quarter, + || (), + || (EFPacking::ZERO, EFPacking::ZERO), + |(), acc, i| { + let diff_0 = pol_0[2 * quarter + i] - pol_0[i]; + let diff_1 = pol_0[3 * quarter + i] - pol_0[quarter + i]; + let x_0 = prev_folding_factor_packed * diff_0 + pol_0[i]; + let x_1 = prev_folding_factor_packed * diff_1 + pol_0[quarter + i]; + + let y_0 = prev_folding_factor_packed * (pol_1[2 * quarter + i] - pol_1[i]) + pol_1[i]; + let y_1 = + prev_folding_factor_packed * (pol_1[3 * quarter + i] - pol_1[quarter + i]) + pol_1[quarter + i]; + + // SAFETY: distinct `i` write disjoint slots `i` and `quarter + i` in + // `[0, n/2)`; the dispatcher keeps both buffers borrowed for the call. + unsafe { + *p0f.add(i) = x_0; + *p0f.add(quarter + i) = x_1; + *p1f.add(i) = y_0; + *p1f.add(quarter + i) = y_1; + } + + let (b0, b2) = sumcheck_quadratic(((&x_0, &x_1), (&y_0, &y_1))); + acc.0 += b0; + acc.1 += b2; + }, + |(a0, a2), (b0, b2)| (a0 + b0, a2 + b2), + ) }; let c0 = decompose(c0_packed).into_iter().sum::(); diff --git a/crates/backend/sumcheck/src/sc_computation.rs b/crates/backend/sumcheck/src/sc_computation.rs index 6f589bb7b..56d63d0b3 100644 --- a/crates/backend/sumcheck/src/sc_computation.rs +++ b/crates/backend/sumcheck/src/sc_computation.rs @@ -2,10 +2,16 @@ use crate::*; use air::*; use field::*; use poly::*; -use rayon::prelude::*; use std::any::TypeId; use std::ops::{Add, AddAssign, Mul, MulAssign, Sub}; +fn add_assign_vec(mut a: Vec, b: Vec) -> Vec { + for (x, y) in a.iter_mut().zip(b) { + *x += y; + } + a +} + pub trait SumcheckComputation>>: Sync { type ExtraData: Send + Sync + 'static; @@ -58,44 +64,12 @@ where } } -fn parallel_sum(size: usize, n: usize, init_state: IS, compute_iteration: F) -> Vec -where - T: PrimeCharacteristicRing + Send + Sync, - S: Send, - IS: Fn() -> S + Sync + Send, - F: Fn(&mut S, 
usize) -> Vec + Sync + Send, -{ - let accumulate = |mut acc: Vec, sums: Vec| { - for (j, sum) in sums.into_iter().enumerate() { - acc[j] += sum; - } - acc - }; - - if size < PARALLEL_THRESHOLD { - let mut state = init_state(); - (0..size).fold(T::zero_vec(n), |acc, i| { - accumulate(acc, compute_iteration(&mut state, i)) - }) - } else { - (0..size) - .into_par_iter() - .map_init(&init_state, |state, i| compute_iteration(state, i)) - .reduce(|| T::zero_vec(n), accumulate) - } -} - fn build_evals>>( sums: impl IntoIterator, missing_mul_factor: Option, ) -> Vec { sums.into_iter() - .map(|mut sum| { - if let Some(factor) = missing_mul_factor { - sum *= factor; - } - sum - }) + .map(|sum| missing_mul_factor.map_or(sum, |f| sum * f)) .collect() } @@ -425,49 +399,49 @@ where + MulAssign, SC: SumcheckComputation, { + // Per-worker scratch: `rows` (the [lo, diff, hi] triples) and `point` (the + // evaluation point handed to `eval_fn`) are reused across every task a worker + // owns, so the hot loop allocates nothing. `acc` (length `degree`) is the + // per-worker partial sum. let n_mult = multilinears.len(); - let compute_at = |(rows, point): &mut (Vec<[IF; 3]>, Vec), i: usize| -> Vec { - let eq_val = eq_at(i); - - rows.clear(); - rows.extend(multilinears.iter().map(|m| { - let lo = m[i]; - let hi = m[i + fold_size]; - [lo, hi - lo, hi] - })); - - // z = 0 - point.clear(); - point.extend(rows.iter().map(|row| row[0])); - let mut eval_0 = eval_fn(computation, point, extra_data); - if let Some(eq) = eq_val { - eval_0 *= eq; - } - - let mut evals = Vec::with_capacity(degree); - evals.push(eval_0); - - // z = 2, 3, ... 
- for _ in 1..degree { - for [_, diff_hi_lo, running] in rows.iter_mut() { - *running += *diff_hi_lo; - } + let sums = parallel::map_reduce_with_state( + fold_size, + || (Vec::<[IF; 3]>::with_capacity(n_mult), Vec::::with_capacity(n_mult)), + || EFT::zero_vec(degree), + |(rows, point), acc, i| { + let eq_val = eq_at(i); + + rows.clear(); + rows.extend(multilinears.iter().map(|m| { + let lo = m[i]; + let hi = m[i + fold_size]; + [lo, hi - lo, hi] + })); + + // z = 0 point.clear(); - point.extend(rows.iter().map(|row| row[2])); - let mut eval = eval_fn(computation, point, extra_data); + point.extend(rows.iter().map(|row| row[0])); + let mut eval_0 = eval_fn(computation, point, extra_data); if let Some(eq) = eq_val { - eval *= eq; + eval_0 *= eq; } - evals.push(eval); - } - evals - }; + acc[0] += eval_0; - let sums = parallel_sum( - fold_size, - degree, - || (Vec::<[IF; 3]>::with_capacity(n_mult), Vec::::with_capacity(n_mult)), - compute_at, + // z = 2, 3, ... + for acc_d in acc.iter_mut().skip(1) { + for [_, diff_hi_lo, running] in rows.iter_mut() { + *running += *diff_hi_lo; + } + point.clear(); + point.extend(rows.iter().map(|row| row[2])); + let mut eval = eval_fn(computation, point, extra_data); + if let Some(eq) = eq_val { + eval *= eq; + } + *acc_d += eval; + } + }, + add_assign_vec, ); let unpacked_sums = sums.into_iter().map(&unpack_sum); build_evals(unpacked_sums, missing_mul_factor) @@ -500,54 +474,54 @@ where .map(|_| FT::zero_vec(prev_folded_size)) .collect(); + // Per-worker scratch: `rows_f` (the [lo, diff, hi] triples) and `point` (the + // evaluation point handed to `eval_fn`) are reused across every task a worker + // owns, so the hot loop allocates nothing. `acc` (length `degree`) is the + // per-worker partial sum. 
let n_mult = multilinears.len(); - let compute_iteration = |(rows_f, point): &mut (Vec<[FT; 3]>, Vec), i: usize| -> Vec { - let eq_mle_eval = eq_at(i); - - rows_f.clear(); - rows_f.extend(multilinears.iter().enumerate().map(|(j, m)| { - let lo = fold_f(m, i); - let hi = fold_f(m, i + compute_fold_size); - unsafe { - let ptr = folded_f[j].as_ptr() as *mut FT; - *ptr.add(i) = lo; - *ptr.add(i + compute_fold_size) = hi; - } - [lo, hi - lo, hi] - })); - - // z = 0 - point.clear(); - point.extend(rows_f.iter().map(|row| row[0])); - let mut eval_0 = eval_fn(computation, point, extra_data); - if let Some(eq) = eq_mle_eval { - eval_0 *= eq; - } - - let mut evals = Vec::with_capacity(degree); - evals.push(eval_0); + let sums = parallel::map_reduce_with_state( + compute_fold_size, + || (Vec::<[FT; 3]>::with_capacity(n_mult), Vec::::with_capacity(n_mult)), + || FT::zero_vec(degree), + |(rows_f, point), acc, i| { + let eq_mle_eval = eq_at(i); + + rows_f.clear(); + rows_f.extend(multilinears.iter().enumerate().map(|(j, m)| { + let lo = fold_f(m, i); + let hi = fold_f(m, i + compute_fold_size); + unsafe { + let ptr = folded_f[j].as_ptr() as *mut FT; + *ptr.add(i) = lo; + *ptr.add(i + compute_fold_size) = hi; + } + [lo, hi - lo, hi] + })); - // z = 2, 3, ... - for _ in 1..degree { - for [_, diff_hi_lo, running] in rows_f.iter_mut() { - *running += *diff_hi_lo; - } + // z = 0 point.clear(); - point.extend(rows_f.iter().map(|row| row[2])); - let mut eval = eval_fn(computation, point, extra_data); + point.extend(rows_f.iter().map(|row| row[0])); + let mut eval_0 = eval_fn(computation, point, extra_data); if let Some(eq) = eq_mle_eval { - eval *= eq; + eval_0 *= eq; } - evals.push(eval); - } - evals - }; + acc[0] += eval_0; - let sums = parallel_sum( - compute_fold_size, - degree, - || (Vec::<[FT; 3]>::with_capacity(n_mult), Vec::::with_capacity(n_mult)), - compute_iteration, + // z = 2, 3, ... 
+ for acc_d in acc.iter_mut().skip(1) { + for [_, diff_hi_lo, running] in rows_f.iter_mut() { + *running += *diff_hi_lo; + } + point.clear(); + point.extend(rows_f.iter().map(|row| row[2])); + let mut eval = eval_fn(computation, point, extra_data); + if let Some(eq) = eq_mle_eval { + eval *= eq; + } + *acc_d += eval; + } + }, + add_assign_vec, ); let unpacked_sums = sums.into_iter().map(&unpack_sum); (build_evals(unpacked_sums, missing_mul_factor), wrap_f(folded_f)) @@ -575,65 +549,60 @@ where let eq_lo = &split_eq.eq_lo; let eq_hi = &split_eq.eq_hi_packed; - let zero = || EFPacking::::zero_vec(degree); - let accumulate = |mut acc: Vec>, vals: Vec>| -> Vec> { - for (a, v) in acc.iter_mut().zip(vals.iter()) { - *a += *v; - } - acc - }; - + // Per-worker scratch reused across every `b_lo` task: `rows` ([lo, diff, hi] + // triples), `point` (handed to `eval_fn`), and `block_acc` (per-`b_lo` partial + // sum, scaled by `eq_lo` before folding into the worker accumulator `acc`). let n_mult = multilinears.len(); - let sums: Vec> = (0..n_lo) - .into_par_iter() - .map_init( - || { - ( - Vec::<[EFPacking; 3]>::with_capacity(n_mult), - Vec::>::with_capacity(n_mult), - ) - }, - |(rows, point), b_lo| { - let eq_lo_bc = EFPacking::::from(eq_lo[b_lo]); - let base = b_lo << log_packed_hi; - let mut block_acc = zero(); - for k in 0..packed_hi { - let i = base + k; - let eq_val = eq_hi[k]; - - rows.clear(); - rows.extend(multilinears.iter().map(|m| { - let lo = m[i]; - let hi = m[i + fold_size]; - [lo, hi - lo, hi] - })); - - // z = 0 - point.clear(); - point.extend(rows.iter().map(|r| r[0])); - let mut e0 = eval_fn(computation, point, extra_data); - e0 *= eq_val; - block_acc[0] += e0; - - // z = 2, 3, ... 
- for d in 1..degree { - for [_, diff, running] in rows.iter_mut() { - *running += *diff; - } - point.clear(); - point.extend(rows.iter().map(|r| r[2])); - let mut ev = eval_fn(computation, point, extra_data); - ev *= eq_val; - block_acc[d] += ev; + let sums: Vec> = parallel::map_reduce_with_state( + n_lo, + || { + ( + Vec::<[EFPacking; 3]>::with_capacity(n_mult), + Vec::>::with_capacity(n_mult), + EFPacking::::zero_vec(degree), + ) + }, + || EFPacking::::zero_vec(degree), + |(rows, point, block_acc), acc, b_lo| { + let eq_lo_bc = EFPacking::::from(eq_lo[b_lo]); + let base = b_lo << log_packed_hi; + block_acc.iter_mut().for_each(|x| *x = EFPacking::::ZERO); + for k in 0..packed_hi { + let i = base + k; + let eq_val = eq_hi[k]; + + rows.clear(); + rows.extend(multilinears.iter().map(|m| { + let lo = m[i]; + let hi = m[i + fold_size]; + [lo, hi - lo, hi] + })); + + // z = 0 + point.clear(); + point.extend(rows.iter().map(|r| r[0])); + let mut e0 = eval_fn(computation, point, extra_data); + e0 *= eq_val; + block_acc[0] += e0; + + // z = 2, 3, ... 
+ for d in 1..degree { + for [_, diff, running] in rows.iter_mut() { + *running += *diff; } + point.clear(); + point.extend(rows.iter().map(|r| r[2])); + let mut ev = eval_fn(computation, point, extra_data); + ev *= eq_val; + block_acc[d] += ev; } - for a in &mut block_acc { - *a *= eq_lo_bc; - } - block_acc - }, - ) - .reduce(zero, accumulate); + } + for (a, b) in acc.iter_mut().zip(block_acc.iter()) { + *a += *b * eq_lo_bc; + } + }, + add_assign_vec, + ); let unpacked = sums.into_iter().map(&unpack_sum); build_evals(unpacked, missing_mul_factor) @@ -670,69 +639,64 @@ where let eq_lo = &split_eq.eq_lo; let eq_hi = &split_eq.eq_hi_packed; - let zero = || EFPacking::::zero_vec(degree); - let accumulate = |mut acc: Vec>, vals: Vec>| -> Vec> { - for (a, v) in acc.iter_mut().zip(vals.iter()) { - *a += *v; - } - acc - }; - + // Per-worker scratch reused across every `b_lo` task (see `sumcheck_compute_with_split_eq`): + // `rows_f` triples, `point` for `eval_fn`, and the per-`b_lo` `block_acc`. 
let n_mult = multilinears.len(); - let sums: Vec> = (0..n_lo) - .into_par_iter() - .map_init( - || { - ( - Vec::<[EFPacking; 3]>::with_capacity(n_mult), - Vec::>::with_capacity(n_mult), - ) - }, - |(rows_f, point), b_lo| { - let eq_lo_bc = EFPacking::::from(eq_lo[b_lo]); - let base = b_lo << log_packed_hi; - let mut block_acc = zero(); - for k in 0..packed_hi { - let i = base + k; - let eq_val = eq_hi[k]; - - rows_f.clear(); - rows_f.extend(multilinears.iter().enumerate().map(|(j, m)| { - let lo = fold_f(m, i); - let hi = fold_f(m, i + compute_fold_size); - unsafe { - let ptr = folded_f[j].as_ptr() as *mut EFPacking; - *ptr.add(i) = lo; - *ptr.add(i + compute_fold_size) = hi; - } - [lo, hi - lo, hi] - })); - - // z = 0 - point.clear(); - point.extend(rows_f.iter().map(|r| r[0])); - let mut e0 = eval_fn(computation, point, extra_data); - e0 *= eq_val; - block_acc[0] += e0; - - for d in 1..degree { - for [_, diff, running] in rows_f.iter_mut() { - *running += *diff; - } - point.clear(); - point.extend(rows_f.iter().map(|r| r[2])); - let mut ev = eval_fn(computation, point, extra_data); - ev *= eq_val; - block_acc[d] += ev; + let sums: Vec> = parallel::map_reduce_with_state( + n_lo, + || { + ( + Vec::<[EFPacking; 3]>::with_capacity(n_mult), + Vec::>::with_capacity(n_mult), + EFPacking::::zero_vec(degree), + ) + }, + || EFPacking::::zero_vec(degree), + |(rows_f, point, block_acc), acc, b_lo| { + let eq_lo_bc = EFPacking::::from(eq_lo[b_lo]); + let base = b_lo << log_packed_hi; + block_acc.iter_mut().for_each(|x| *x = EFPacking::::ZERO); + for k in 0..packed_hi { + let i = base + k; + let eq_val = eq_hi[k]; + + rows_f.clear(); + rows_f.extend(multilinears.iter().enumerate().map(|(j, m)| { + let lo = fold_f(m, i); + let hi = fold_f(m, i + compute_fold_size); + unsafe { + let ptr = folded_f[j].as_ptr() as *mut EFPacking; + *ptr.add(i) = lo; + *ptr.add(i + compute_fold_size) = hi; } + [lo, hi - lo, hi] + })); + + // z = 0 + point.clear(); + 
point.extend(rows_f.iter().map(|r| r[0])); + let mut e0 = eval_fn(computation, point, extra_data); + e0 *= eq_val; + block_acc[0] += e0; + + // z = 2, 3, ... + for d in 1..degree { + for [_, diff, running] in rows_f.iter_mut() { + *running += *diff; + } + point.clear(); + point.extend(rows_f.iter().map(|r| r[2])); + let mut ev = eval_fn(computation, point, extra_data); + ev *= eq_val; + block_acc[d] += ev; } - for a in &mut block_acc { - *a *= eq_lo_bc; - } - block_acc - }, - ) - .reduce(zero, accumulate); + } + for (a, b) in acc.iter_mut().zip(block_acc.iter()) { + *a += *b * eq_lo_bc; + } + }, + add_assign_vec, + ); let unpacked = sums.into_iter().map(&unpack_sum); (build_evals(unpacked, missing_mul_factor), wrap_f(folded_f)) diff --git a/crates/backend/symetric/Cargo.toml b/crates/backend/symetric/Cargo.toml index 125fb5535..86b7c3cd3 100644 --- a/crates/backend/symetric/Cargo.toml +++ b/crates/backend/symetric/Cargo.toml @@ -6,4 +6,4 @@ edition.workspace = true [dependencies] koala-bear = { path = "../koala-bear", package = "mt-koala-bear" } field = { path = "../field", package = "mt-field" } -rayon.workspace = true +parallel.workspace = true diff --git a/crates/backend/symetric/src/merkle.rs b/crates/backend/symetric/src/merkle.rs index 2fe194855..4b609a09d 100644 --- a/crates/backend/symetric/src/merkle.rs +++ b/crates/backend/symetric/src/merkle.rs @@ -4,7 +4,6 @@ use std::array; use field::PackedValue; -use rayon::prelude::*; use crate::Compression; @@ -67,18 +66,18 @@ where let default_digest = [P::Value::default(); DIGEST_ELEMS]; let mut next_digests = vec![default_digest; next_len_padded]; - next_digests[0..next_len] - .par_chunks_exact_mut(width) - .enumerate() - .for_each(|(i, digests_chunk)| { - let first_row = i * width; - let left = array::from_fn(|j| P::from_fn(|k| prev_layer[2 * (first_row + k)][j])); - let right = array::from_fn(|j| P::from_fn(|k| prev_layer[2 * (first_row + k) + 1][j])); - let packed_digest = crate::compress(comp, [left, 
right]); - for (dst, src) in digests_chunk.iter_mut().zip(unpack_array(packed_digest)) { - *dst = src; - } - }); + // Process only the full packed chunks in parallel (matches `par_chunks_exact_mut`); + // the `< width` remainder is handled by the sequential tail loop below. + let n_full = next_len / width * width; + parallel::par_chunks_mut(&mut next_digests[0..n_full], width, |i, digests_chunk| { + let first_row = i * width; + let left = array::from_fn(|j| P::from_fn(|k| prev_layer[2 * (first_row + k)][j])); + let right = array::from_fn(|j| P::from_fn(|k| prev_layer[2 * (first_row + k) + 1][j])); + let packed_digest = crate::compress(comp, [left, right]); + for (dst, src) in digests_chunk.iter_mut().zip(unpack_array(packed_digest)) { + *dst = src; + } + }); for i in (next_len / width * width)..next_len { let left = prev_layer[2 * i]; diff --git a/crates/backend/system-info/Cargo.toml b/crates/backend/system-info/Cargo.toml index c63ee1297..862e36e89 100644 --- a/crates/backend/system-info/Cargo.toml +++ b/crates/backend/system-info/Cargo.toml @@ -5,7 +5,6 @@ edition.workspace = true [dependencies] libc = "0.2" -rayon.workspace = true [lints] workspace = true diff --git a/crates/backend/system-info/src/lib.rs b/crates/backend/system-info/src/lib.rs index 07180559b..5323c1ce4 100644 --- a/crates/backend/system-info/src/lib.rs +++ b/crates/backend/system-info/src/lib.rs @@ -9,36 +9,3 @@ pub fn peak_rss_bytes() -> u64 { // ru_maxrss unit: bytes on macOS, KiB on Linux. if cfg!(target_os = "macos") { max } else { max * 1024 } } - -/// Number of jobs [`flush_rayon`] pushes. Must exceed -/// `crossbeam_deque::deque::BLOCK_CAP` (currently 63 — -/// `crossbeam-deque-0.8.6/src/deque.rs:1191`). -const RAYON_FLUSH_JOBS: usize = 256; - -/// Drain rayon's internal queues so they release any storage allocated during the -/// previous phase. 
-/// -/// Rayon's global pool owns a `crossbeam_deque::Injector`, internally a linked list -/// of fixed-size blocks (`Block` and `Injector::push` — -/// `crossbeam-deque-0.8.6/src/deque.rs:1219` and `:1371`). A block is freed only -/// once its last slot has been consumed. -/// -/// `rayon::join` from a non-worker thread reaches that injector via -/// `join` (`rayon-core-1.13.0/src/join/mod.rs:132`) -> -/// `registry::in_worker` (`registry.rs:946`) -> -/// `Registry::in_worker_cold` (`:517`) -> -/// `Registry::inject` (`:428`) -> `Injector::push`. -/// -/// Under an arena allocator that recycles memory between phases (e.g. `zk-alloc`), -/// a block allocated *during* a phase points into a slab the next `begin_phase()` -/// will reuse. The next push then writes a `JobRef` straight through whatever the -/// application has placed on top, silently corrupting it. -/// -/// Pushing more than `BLOCK_CAP` jobs while the arena is off forces the Injector -/// to allocate a fresh tail block (which lands in System), and forces workers to -/// steal the last slot of every preceding block (which destroys them). -pub fn flush_rayon() { - for _ in 0..RAYON_FLUSH_JOBS { - rayon::join(|| {}, || {}); - } -} diff --git a/crates/backend/zk-alloc/Cargo.toml b/crates/backend/zk-alloc/Cargo.toml index fe4c12233..0c4ab6a5f 100644 --- a/crates/backend/zk-alloc/Cargo.toml +++ b/crates/backend/zk-alloc/Cargo.toml @@ -7,9 +7,6 @@ description = "Bump+reset arena allocator for ZK proving workloads" [dependencies] system-info.workspace = true -[dev-dependencies] -rayon.workspace = true - [target.'cfg(not(all(target_os = "linux", target_arch = "x86_64")))'.dependencies] libc = "0.2" diff --git a/crates/backend/zk-alloc/src/lib.rs b/crates/backend/zk-alloc/src/lib.rs index 1b43143d6..f0433d3a4 100644 --- a/crates/backend/zk-alloc/src/lib.rs +++ b/crates/backend/zk-alloc/src/lib.rs @@ -7,7 +7,6 @@ //! back to the system allocator. //! //! ```ignore -//! init(); // once, at process start //! 
loop { //! begin_phase(); // arena ON; slabs reset lazily //! let res = heavy_work(); // fast increments @@ -88,15 +87,6 @@ fn ensure_region() -> usize { REGION_BASE.load(Ordering::Acquire) } -/// Call once at process start, before any `begin_phase()`. -pub fn init() { - let actual_num_threads = std::thread::available_parallelism().unwrap().get(); - assert_eq!( - actual_num_threads, NUM_THREADS, - "built for {NUM_THREADS} threads but this machine reports {actual_num_threads} -> please rebuild`" - ); -} - /// Activates the arena and resets every thread's slab. All allocations until the next /// `end_phase()` go to the arena; the previous phase's data is overwritten in place. pub fn begin_phase() { @@ -111,11 +101,11 @@ pub fn begin_phase() { /// Deactivates the arena. New allocations go to the system allocator; existing arena /// pointers stay valid until the next `begin_phase()` resets the slabs. /// -/// Also calls [`system_info::flush_rayon`] to release any rayon/crossbeam storage -/// still referencing this phase's arena memory. +/// Unlike the rayon-based build (which needed `flush_rayon` to drain crossbeam's +/// arena-allocated injector blocks), the in-house `parallel` pool allocates its state +/// once at startup and nothing per-dispatch, so no flush is required here. pub fn end_phase() { ARENA_ACTIVE.store(false, Ordering::Release); - system_info::flush_rayon(); } #[cold] diff --git a/crates/backend/zk-alloc/tests/test_rayon.rs b/crates/backend/zk-alloc/tests/test_rayon.rs deleted file mode 100644 index ae084af21..000000000 --- a/crates/backend/zk-alloc/tests/test_rayon.rs +++ /dev/null @@ -1,26 +0,0 @@ -//! Regression test for the bug prevented by `system_info::flush_rayon`. 
- -use rayon::prelude::*; - -#[global_allocator] -static A: zk_alloc::ZkAllocator = zk_alloc::ZkAllocator; - -#[test] -fn rayon_does_not_corrupt_zkalloc() { - zk_alloc::init(); - let _: u64 = (0..1_000_000_u64).into_par_iter().sum(); - - zk_alloc::begin_phase(); - for _ in 0..200 { - rayon::join(|| {}, || {}); - } - zk_alloc::end_phase(); - - zk_alloc::begin_phase(); - let canary = vec![0xAB_u8; 8192]; - rayon::join(|| {}, || {}); - zk_alloc::end_phase(); - - let pos = canary.iter().position(|&b| b != 0xAB); - assert!(pos.is_none(), "canary corrupted at offset {}", pos.unwrap()); -} diff --git a/crates/lean_compiler/src/c_compile_final.rs b/crates/lean_compiler/src/c_compile_final.rs index 3585abd84..003607d70 100644 --- a/crates/lean_compiler/src/c_compile_final.rs +++ b/crates/lean_compiler/src/c_compile_final.rs @@ -132,7 +132,10 @@ pub fn compile_to_low_level_bytecode( validate_instruction(instruction)?; } - let instructions_encoded = instructions.par_iter().map(field_representation).collect::>(); + let mut instructions_encoded: Vec<[F; N_INSTRUCTION_COLUMNS]> = unsafe { uninitialized_vec(instructions.len()) }; + parallel::par_for_each_mut(&mut instructions_encoded, |i, out| { + *out = field_representation(&instructions[i]); + }); let mut instructions_multilinear = vec![]; for instr in &instructions_encoded { diff --git a/crates/lean_prover/src/prove_execution.rs b/crates/lean_prover/src/prove_execution.rs index aaf50be3b..a952b373b 100644 --- a/crates/lean_prover/src/prove_execution.rs +++ b/crates/lean_prover/src/prove_execution.rs @@ -165,8 +165,9 @@ pub fn prove_execution( }) .collect(); let _span = info_span!("Computing shifted columns for AIR sumcheck").entered(); + // Only a few tables; run them serially and let `compute_shifted_columns` use the full pool. 
let shifted_rows: Vec>> = ALL_TABLES - .par_iter() + .iter() .zip(&column_refs) .map(|(table, cols)| compute_shifted_columns(table.n_shift_columns(), cols)) .collect(); diff --git a/crates/lean_prover/src/trace_gen.rs b/crates/lean_prover/src/trace_gen.rs index cd0e401be..86a736e25 100644 --- a/crates/lean_prover/src/trace_gen.rs +++ b/crates/lean_prover/src/trace_gen.rs @@ -1,7 +1,7 @@ use backend::*; use lean_vm::*; use std::{array, collections::BTreeMap}; -use utils::{ToUsize, get_poseidon_16_of_zero, transposed_par_iter_mut}; +use utils::{ToUsize, get_poseidon_16_of_zero, transposed_par_for_each_mut}; #[derive(Debug)] pub struct ExecutionTrace { @@ -27,74 +27,76 @@ pub fn get_execution_trace( } } - transposed_par_iter_mut(&mut main_trace) - .zip(execution_result.pcs.par_iter()) - .zip(execution_result.fps.par_iter()) - .for_each(|((trace_row, &pc), &fp)| { - let instruction = &bytecode.code[pc].instruction; - let field_repr = &bytecode.instructions_multilinear[pc * N_INSTRUCTION_COLUMNS.next_power_of_two()..] 
- [..N_INSTRUCTION_COLUMNS]; - - let flag_a = field_repr[instr_idx(EXEC_COL_FLAG_A)]; - let flag_b = field_repr[instr_idx(EXEC_COL_FLAG_B)]; - let flag_c = field_repr[instr_idx(EXEC_COL_FLAG_C)]; - let flag_c_fp = field_repr[instr_idx(EXEC_COL_FLAG_C_FP)]; - let flag_ab_fp = field_repr[instr_idx(EXEC_COL_FLAG_AB_FP)]; - let aux_1 = field_repr[instr_idx(EXEC_COL_AUX_1)]; - let is_deref = aux_1 == F::TWO; - - let mut addr_a = F::ZERO; - if flag_a.is_zero() && flag_ab_fp.is_zero() { - addr_a = F::from_usize(fp) + field_repr[instr_idx(EXEC_COL_OPERAND_A)]; - } - let value_a = memory.0.get(addr_a.to_usize()).copied().flatten().unwrap_or_default(); - - let mut addr_b = F::ZERO; - if flag_b.is_zero() && flag_ab_fp.is_zero() { - addr_b = F::from_usize(fp) + field_repr[instr_idx(EXEC_COL_OPERAND_B)]; - } else if is_deref { - // DEREF: addr_B = value_A + operand_B - addr_b = value_a + field_repr[instr_idx(EXEC_COL_OPERAND_B)]; - } - let value_b = memory.0.get(addr_b.to_usize()).copied().flatten().unwrap_or_default(); + transposed_par_for_each_mut(&mut main_trace, |i, trace_row| { + let pc = execution_result.pcs[i]; + let fp = execution_result.fps[i]; + let instruction = &bytecode.code[pc].instruction; + let field_repr = &bytecode.instructions_multilinear[pc * N_INSTRUCTION_COLUMNS.next_power_of_two()..] 
+ [..N_INSTRUCTION_COLUMNS]; + + let flag_a = field_repr[instr_idx(EXEC_COL_FLAG_A)]; + let flag_b = field_repr[instr_idx(EXEC_COL_FLAG_B)]; + let flag_c = field_repr[instr_idx(EXEC_COL_FLAG_C)]; + let flag_c_fp = field_repr[instr_idx(EXEC_COL_FLAG_C_FP)]; + let flag_ab_fp = field_repr[instr_idx(EXEC_COL_FLAG_AB_FP)]; + let aux_1 = field_repr[instr_idx(EXEC_COL_AUX_1)]; + let is_deref = aux_1 == F::TWO; + + let mut addr_a = F::ZERO; + if flag_a.is_zero() && flag_ab_fp.is_zero() { + addr_a = F::from_usize(fp) + field_repr[instr_idx(EXEC_COL_OPERAND_A)]; + } + let value_a = memory.0.get(addr_a.to_usize()).copied().flatten().unwrap_or_default(); + + let mut addr_b = F::ZERO; + if flag_b.is_zero() && flag_ab_fp.is_zero() { + addr_b = F::from_usize(fp) + field_repr[instr_idx(EXEC_COL_OPERAND_B)]; + } else if is_deref { + // DEREF: addr_B = value_A + operand_B + addr_b = value_a + field_repr[instr_idx(EXEC_COL_OPERAND_B)]; + } + let value_b = memory.0.get(addr_b.to_usize()).copied().flatten().unwrap_or_default(); - let mut addr_c = F::ZERO; - if flag_c.is_zero() && flag_c_fp.is_zero() { - addr_c = F::from_usize(fp) + field_repr[instr_idx(EXEC_COL_OPERAND_C)]; - } - let value_c = memory.0.get(addr_c.to_usize()).copied().flatten().unwrap_or_default(); + let mut addr_c = F::ZERO; + if flag_c.is_zero() && flag_c_fp.is_zero() { + addr_c = F::from_usize(fp) + field_repr[instr_idx(EXEC_COL_OPERAND_C)]; + } + let value_c = memory.0.get(addr_c.to_usize()).copied().flatten().unwrap_or_default(); - for (j, field) in field_repr.iter().enumerate() { - *trace_row[j + N_RUNTIME_COLUMNS] = *field; - } + for (j, field) in field_repr.iter().enumerate() { + *trace_row[j + N_RUNTIME_COLUMNS] = *field; + } - let nu_a = flag_a * field_repr[instr_idx(EXEC_COL_OPERAND_A)] - + (F::ONE - flag_a - flag_ab_fp) * value_a - + flag_ab_fp * (F::from_usize(fp) + field_repr[instr_idx(EXEC_COL_OPERAND_A)]); - let nu_b = flag_b * field_repr[instr_idx(EXEC_COL_OPERAND_B)] - + (F::ONE - flag_b - flag_ab_fp) 
* value_b - + flag_ab_fp * (F::from_usize(fp) + field_repr[instr_idx(EXEC_COL_OPERAND_B)]); - let nu_c = flag_c * field_repr[instr_idx(EXEC_COL_OPERAND_C)] - + (F::ONE - flag_c - flag_c_fp) * value_c - + flag_c_fp * (F::from_usize(fp) + field_repr[instr_idx(EXEC_COL_OPERAND_C)]); - if let Instruction::Precompile(..) = instruction { - *trace_row[EXEC_COL_FLAG_PRECOMPILE] = F::ONE; - } - *trace_row[EXEC_COL_NU_A] = nu_a; - *trace_row[EXEC_COL_NU_B] = nu_b; - *trace_row[EXEC_COL_NU_C] = nu_c; - - *trace_row[EXEC_COL_VALUE_A] = value_a; - *trace_row[EXEC_COL_VALUE_B] = value_b; - *trace_row[EXEC_COL_VALUE_C] = value_c; - *trace_row[EXEC_COL_PC] = F::from_usize(pc); - *trace_row[EXEC_COL_FP] = F::from_usize(fp); - *trace_row[EXEC_COL_ADDR_A] = addr_a; - *trace_row[EXEC_COL_ADDR_B] = addr_b; - *trace_row[EXEC_COL_ADDR_C] = addr_c; - }); + let nu_a = flag_a * field_repr[instr_idx(EXEC_COL_OPERAND_A)] + + (F::ONE - flag_a - flag_ab_fp) * value_a + + flag_ab_fp * (F::from_usize(fp) + field_repr[instr_idx(EXEC_COL_OPERAND_A)]); + let nu_b = flag_b * field_repr[instr_idx(EXEC_COL_OPERAND_B)] + + (F::ONE - flag_b - flag_ab_fp) * value_b + + flag_ab_fp * (F::from_usize(fp) + field_repr[instr_idx(EXEC_COL_OPERAND_B)]); + let nu_c = flag_c * field_repr[instr_idx(EXEC_COL_OPERAND_C)] + + (F::ONE - flag_c - flag_c_fp) * value_c + + flag_c_fp * (F::from_usize(fp) + field_repr[instr_idx(EXEC_COL_OPERAND_C)]); + if let Instruction::Precompile(..) 
= instruction { + *trace_row[EXEC_COL_FLAG_PRECOMPILE] = F::ONE; + } + *trace_row[EXEC_COL_NU_A] = nu_a; + *trace_row[EXEC_COL_NU_B] = nu_b; + *trace_row[EXEC_COL_NU_C] = nu_c; + + *trace_row[EXEC_COL_VALUE_A] = value_a; + *trace_row[EXEC_COL_VALUE_B] = value_b; + *trace_row[EXEC_COL_VALUE_C] = value_c; + *trace_row[EXEC_COL_PC] = F::from_usize(pc); + *trace_row[EXEC_COL_FP] = F::from_usize(fp); + *trace_row[EXEC_COL_ADDR_A] = addr_a; + *trace_row[EXEC_COL_ADDR_B] = addr_b; + *trace_row[EXEC_COL_ADDR_C] = addr_c; + }); - let mut memory_padded = memory.0.par_iter().map(|&v| v.unwrap_or(F::ZERO)).collect::>(); + let mut memory_padded: Vec = unsafe { uninitialized_vec(memory.0.len()) }; + parallel::par_for_each_mut(&mut memory_padded, |i, slot| { + *slot = memory.0[i].unwrap_or(F::ZERO); + }); // Write [0000000000000000 | poseidon_compress(0000000000000000)] (to make lookups work on padding-rows). let padding_zero_vec_ptr = memory_padded.len(); @@ -124,23 +126,22 @@ pub fn get_execution_trace( const N: usize = HALF_DIGEST_LEN + DIGEST_LEN; let cols: &mut [Vec; N] = (&mut right[..N]).try_into().unwrap(); - transposed_par_iter_mut(cols) - .zip(flag_out4_col) - .zip(flag_out8_col) - .zip(nu_c_col) - .for_each(|(((row, &flag_out4), &flag_out8), &nu_c)| { - let base = nu_c.to_usize(); - if flag_out4 == F::ONE { - for j in 0..HALF_DIGEST_LEN { - *row[j] = memory_padded[base + HALF_DIGEST_LEN + j]; - } + transposed_par_for_each_mut(cols, |i, row| { + let flag_out4 = flag_out4_col[i]; + let flag_out8 = flag_out8_col[i]; + let nu_c = nu_c_col[i]; + let base = nu_c.to_usize(); + if flag_out4 == F::ONE { + for j in 0..HALF_DIGEST_LEN { + *row[j] = memory_padded[base + HALF_DIGEST_LEN + j]; } - if flag_out8 == F::ONE || flag_out4 == F::ONE { - for j in 0..DIGEST_LEN { - *row[HALF_DIGEST_LEN + j] = memory_padded[base + DIGEST_LEN + j]; - } + } + if flag_out8 == F::ONE || flag_out4 == F::ONE { + for j in 0..DIGEST_LEN { + *row[HALF_DIGEST_LEN + j] = memory_padded[base + DIGEST_LEN 
+ j]; } - }); + } + }); } let extension_op_trace = traces.get_mut(&Table::extension_op()).unwrap(); @@ -197,7 +198,8 @@ fn pad_table( trace.log_n_rows = log2_ceil_usize(h + 1).max(min_log_n_rows); let n_rows = 1 << trace.log_n_rows; let padding_row = table.padding_row(zero_vec_ptr, null_poseidon_16_hash_ptr, ending_pc); - trace.columns.par_iter_mut().enumerate().for_each(|(i, col)| { + parallel::par_chunks_mut(&mut trace.columns, 1, |i, slot| { + let col = &mut slot[0]; assert!(col.len() <= h); // potentially some columns have not been filled (in Poseidon -> we fill it later with SIMD + parallelism), but the first one should always be representative col.resize(n_rows, padding_row[i]); }); diff --git a/crates/lean_vm/src/execution/memory.rs b/crates/lean_vm/src/execution/memory.rs index 364eb7471..38b5a41e2 100644 --- a/crates/lean_vm/src/execution/memory.rs +++ b/crates/lean_vm/src/execution/memory.rs @@ -11,6 +11,16 @@ pub trait MemoryAccess { (0..len).map(|i| self.get(start + i)).collect() } + /// In-place version of [`get_slice`] that writes into a caller-provided buffer, + /// avoiding a per-call heap allocation on the hot interpreter path (Poseidon / + /// extension-op slice reads run hundreds of thousands of times per proof). 
+ fn get_slice_into(&self, start: usize, dest: &mut [F]) -> Result<(), RunnerError> { + for (i, d) in dest.iter_mut().enumerate() { + *d = self.get(start + i)?; + } + Ok(()) + } + fn set_slice(&mut self, start: usize, values: &[F]) -> Result<(), RunnerError> { for (i, v) in values.iter().enumerate() { self.set(start + i, *v)?; @@ -77,7 +87,7 @@ impl MemoryAccess for Memory { impl Memory { pub fn new(public_memory: Vec) -> Self { - Self(public_memory.into_par_iter().map(Some).collect()) + Self(public_memory.into_iter().map(Some).collect()) } pub fn get(&self, index: usize) -> Result { diff --git a/crates/lean_vm/src/execution/runner.rs b/crates/lean_vm/src/execution/runner.rs index f00e04880..8a47e716d 100644 --- a/crates/lean_vm/src/execution/runner.rs +++ b/crates/lean_vm/src/execution/runner.rs @@ -333,7 +333,12 @@ fn execute_bytecode_helper( None }; let runtime_memory_size = memory.0.len() - PUBLIC_INPUT_LEN - witness.preamble_memory_len; - let used_memory_cells = memory.0.par_iter().filter(|&&x| x.is_some()).count(); + let used_memory_cells = parallel::map_reduce( + memory.0.len(), + || 0usize, + |i| usize::from(memory.0[i].is_some()), + |a, b| a + b, + ); let metadata = ExecutionMetadata { cycles: trace.pcs.len(), memory: memory.0.len(), @@ -432,13 +437,26 @@ fn handle_parallel_batch( let split_at = batch.batch_fp + stride; // end of iteration 0's frame let (left, right) = memory.0.split_at_mut(split_at); let shared: &[Option] = &*left; - let segment_slices: Vec<&mut [Option]> = right.chunks_mut(stride).take(n_par).collect(); + let mut segment_slices: Vec<&mut [Option]> = right.chunks_mut(stride).take(n_par).collect(); type SegResult = Result<(Trace, Vec<(usize, F)>), RunnerError>; - let results: Vec = segment_slices - .into_par_iter() - .enumerate() - .map(|(i, seg_slice)| { + + // Raw base pointer + length per disjoint segment, so the pool can run each segment + // on its own slice without moving `&mut` references through the `Fn` task closure. 
+ // SAFETY: segments are non-overlapping `chunks_mut` of `right`; task `i` touches only `i`. + let seg_info: Vec<(parallel::SendPtr>, usize)> = segment_slices + .iter_mut() + .map(|s| (parallel::SendPtr(s.as_mut_ptr()), s.len())) + .collect(); + // Release the `&mut` borrows so only the raw pointers alias the segments. + drop(segment_slices); + + let mut results: Vec> = (0..n_par).map(|_| None).collect(); + parallel::par_chunks_mut(&mut results, 1, |i, out| { + let (seg_ptr, seg_len) = &seg_info[i]; + // SAFETY: distinct `i` reconstruct disjoint segments of `right`, valid for the dispatch. + let seg_slice: &mut [Option] = unsafe { std::slice::from_raw_parts_mut(seg_ptr.0, *seg_len) }; + out[0] = Some((|| -> SegResult { let seg_start = split_at + i * stride; let mut seg_mem = SegmentMemory::new(shared, seg_slice, seg_start); let fp_i = batch.batch_fp + (i + 1) * stride; @@ -452,8 +470,10 @@ fn handle_parallel_batch( cursor.index += i * delta; } } - let seg_start_indices: HashMap<_, _> = - seg_named_hints.iter().map(|(name, c)| (name.clone(), c.index)).collect(); + let seg_start_indices: HashMap<_, _> = seg_named_hints + .iter() + .map(|(name, c)| (name.clone(), c.index)) + .collect(); let mut hints = HintState { diagnostics: None, named_hints: &mut seg_named_hints, @@ -478,8 +498,9 @@ fn handle_parallel_batch( } let deferred = seg_mem.into_deferred_writes(); Ok((seg_trace, deferred)) - }) - .collect(); + })()); + }); + let results: Vec = results.into_iter().map(Option::unwrap).collect(); for (idx, result) in results.into_iter().enumerate() { let (seg_trace, deferred) = result.map_err(|e| RunnerError::ParallelSegmentFailed(idx + 1, Box::new(e)))?; diff --git a/crates/lean_vm/src/tables/poseidon/mod.rs b/crates/lean_vm/src/tables/poseidon/mod.rs index fb1efcb8d..f0fff6dd0 100644 --- a/crates/lean_vm/src/tables/poseidon/mod.rs +++ b/crates/lean_vm/src/tables/poseidon/mod.rs @@ -241,14 +241,14 @@ impl TableT for Poseidon16Precompile { } else { arg_a_usize + 
HALF_DIGEST_LEN }; - let arg0_first = ctx.memory.get_slice(left_first_addr, HALF_DIGEST_LEN)?; - let arg0_second = ctx.memory.get_slice(left_second_addr, HALF_DIGEST_LEN)?; - let arg1 = ctx.memory.get_slice(arg_b.to_usize(), DIGEST_LEN)?; - + // Fill the Poseidon input array directly from memory — no per-call Vec allocation + // (this runs once per Poseidon instruction, the dominant small-alloc source). let mut input = [F::ZERO; DIGEST_LEN * 2]; - input[..HALF_DIGEST_LEN].copy_from_slice(&arg0_first); - input[HALF_DIGEST_LEN..DIGEST_LEN].copy_from_slice(&arg0_second); - input[DIGEST_LEN..].copy_from_slice(&arg1); + ctx.memory + .get_slice_into(left_first_addr, &mut input[..HALF_DIGEST_LEN])?; + ctx.memory + .get_slice_into(left_second_addr, &mut input[HALF_DIGEST_LEN..DIGEST_LEN])?; + ctx.memory.get_slice_into(arg_b.to_usize(), &mut input[DIGEST_LEN..])?; let res_addr = index_res_a.to_usize(); if permute { diff --git a/crates/lean_vm/src/tables/poseidon/trace_gen.rs b/crates/lean_vm/src/tables/poseidon/trace_gen.rs index 9022f6c33..dc3963b75 100644 --- a/crates/lean_vm/src/tables/poseidon/trace_gen.rs +++ b/crates/lean_vm/src/tables/poseidon/trace_gen.rs @@ -20,9 +20,10 @@ pub fn fill_trace_poseidon_16(trace: &mut [Vec]) { const N_COLS: usize = super::num_cols_poseidon_16(); - // fill the packed rows + // fill the packed rows. Bind a fixed-size array ref so the per-row `array::from_fn` + // indexing elides bounds checks (one length check here, none in the hot loop). 
let cols: &[&[FPacking]; N_COLS] = (&trace_packed[..N_COLS]).try_into().unwrap(); - (0..m / packing_width::()).into_par_iter().for_each(|i| { + parallel::for_each_index(m / packing_width::(), |i| { let ptrs: [*mut FPacking; N_COLS] = std::array::from_fn(|c| unsafe { (cols[c].as_ptr() as *mut FPacking).add(i) }); let perm: &mut Poseidon1Cols16<&mut FPacking> = diff --git a/crates/rec_aggregation/Cargo.toml b/crates/rec_aggregation/Cargo.toml index ac111ba3f..7cadde763 100644 --- a/crates/rec_aggregation/Cargo.toml +++ b/crates/rec_aggregation/Cargo.toml @@ -8,12 +8,14 @@ workspace = true [features] prox-gaps-conjecture = ["lean_prover/prox-gaps-conjecture"] +# Skip the zk-alloc bump-arena phase ceremony (use the plain system allocator). standard-alloc = [] [dependencies] utils.workspace = true xmss.workspace = true rand.workspace = true +zk-alloc.workspace = true tracing.workspace = true include_dir.workspace = true @@ -25,7 +27,6 @@ backend.workspace = true postcard.workspace = true lz4_flex.workspace = true serde.workspace = true -zk-alloc.workspace = true [target.'cfg(target_os = "macos")'.dependencies] objc2 = { version = "0.6.4", default-features = false, features = ["std"] } diff --git a/crates/rec_aggregation/src/bytecode_claims.rs b/crates/rec_aggregation/src/bytecode_claims.rs index 91c44b369..081347b8a 100644 --- a/crates/rec_aggregation/src/bytecode_claims.rs +++ b/crates/rec_aggregation/src/bytecode_claims.rs @@ -64,15 +64,11 @@ pub(crate) fn reduce_bytecode_claims(verified: &[InnerVerified]) -> ReducedBytec let alpha: EF = reduction_prover.sample(); let alpha_powers: Vec = alpha.powers().take(n_claims).collect(); - let weights_packed = claims - .par_iter() - .zip(&alpha_powers) - .map(|(eval, &alpha_i)| eval_eq_packed_scaled(&eval.point.0, alpha_i)) - .reduce_with(|mut acc, eq_i| { - acc.par_iter_mut().zip(&eq_i).for_each(|(w, e)| *w += *e); - acc - }) - .unwrap(); + let n_vars = claims[0].point.0.len(); + let mut weights_packed = 
EFPacking::::zero_vec(1 << (n_vars - packing_log_width::())); + for (claim, &alpha_pow) in claims.iter().zip(&alpha_powers) { + compute_eval_eq_packed::(&claim.point.0, &mut weights_packed, alpha_pow); + } let claimed_sum: EF = dot_product(claims.iter().map(|c| c.value), alpha_powers.iter().copied()); diff --git a/crates/sub_protocols/src/air_sumcheck.rs b/crates/sub_protocols/src/air_sumcheck.rs index 0f536d7fa..8600de649 100644 --- a/crates/sub_protocols/src/air_sumcheck.rs +++ b/crates/sub_protocols/src/air_sumcheck.rs @@ -89,22 +89,20 @@ where let _span = info_span!("chunk-bit-reversing columns").entered(); let chunk_size = 1usize << pivot; let shift = usize::BITS as usize - pivot; - let bit_reversed = cols - .par_iter() - .map(|&src| { - let mut dst: Vec> = unsafe { uninitialized_vec(src.len()) }; - let src_u = PFPacking::::unpack_slice(src); - let dst_u = PFPacking::::unpack_slice_mut(&mut dst); - for (src_chunk, dst_chunk) in - src_u.chunks_exact(chunk_size).zip(dst_u.chunks_exact_mut(chunk_size)) - { - for (p, slot) in dst_chunk.iter_mut().enumerate() { - *slot = src_chunk[p.reverse_bits() >> shift]; - } + let mut bit_reversed: Vec>> = (0..cols.len()).map(|_| Vec::new()).collect(); + parallel::par_chunks_mut(&mut bit_reversed, 1, |i, out_slot| { + let src = cols[i]; + let mut dst: Vec> = unsafe { uninitialized_vec(src.len()) }; + let src_u = PFPacking::::unpack_slice(src); + let dst_u = PFPacking::::unpack_slice_mut(&mut dst); + for (src_chunk, dst_chunk) in src_u.chunks_exact(chunk_size).zip(dst_u.chunks_exact_mut(chunk_size)) + { + for (p, slot) in dst_chunk.iter_mut().enumerate() { + *slot = src_chunk[p.reverse_bits() >> shift]; } - dst - }) - .collect(); + } + out_slot[0] = dst; + }); MleGroup::Owned(MleGroupOwned::BasePacked(bit_reversed)) } _ => unreachable!(), @@ -438,120 +436,112 @@ where let hi_zs_halved: Vec<_> = hi_zs.iter().map(|&tz| tz.halve()).collect(); let lagrange_coeffs = lagrange_basis_evals(&low_zs, &hi_zs); - let acc = 
(0..active_count_pairs) - .into_par_iter() - .fold( - || { - ( - vec![EFPacking::::ZERO; degree], - Vec::::with_capacity(n_cols), - Vec::::with_capacity(n_cols), - vec![EFPacking::::ZERO; n_full], - Vec::::new(), - Vec::::new(), - Vec::::new(), - ) - }, - |(mut acc, mut point, mut diff, mut low_evals, mut state_0, mut state_2, mut cached_buf), new_j| { - let i_hi = new_j >> fold_bit; - let i_lo = new_j & lo_mask; - let i0 = (i_hi << (fold_bit + 1)) | i_lo; - let i1 = i0 | stride; - let partial_eq = get_split_eq(new_j); - - // `point` holds column values at z=0; `diff[k] = col_k[i1] - col_k[i0]`. - // Invariant for the rest of this closure: `col_k(z) = point[k] + z · diff[k]`, - // so advancing z by 1 means `point[k] += diff[k]` for all k. - point.clear(); - diff.clear(); - for c in cols { - let lo = c[i0]; - let hi = c[i1]; - point.push(lo); - diff.push(hi - lo); - } + let acc = parallel::map_reduce_with_state( + active_count_pairs, + || { + ( + Vec::::with_capacity(n_cols), + Vec::::with_capacity(n_cols), + vec![EFPacking::::ZERO; n_full], + Vec::::new(), + Vec::::new(), + Vec::::new(), + ) + }, + || vec![EFPacking::::ZERO; degree], + |(point, diff, low_evals, state_0, state_2, cached_buf), acc, new_j| { + let i_hi = new_j >> fold_bit; + let i_lo = new_j & lo_mask; + let i0 = (i_hi << (fold_bit + 1)) | i_lo; + let i1 = i0 | stride; + let partial_eq = get_split_eq(new_j); + + // `point` holds column values at z=0; `diff[k] = col_k[i1] - col_k[i0]`. + // Invariant for the rest of this closure: `col_k(z) = point[k] + z · diff[k]`, + // so advancing z by 1 means `point[k] += diff[k]` for all k. + point.clear(); + diff.clear(); + for c in cols { + let lo = c[i0]; + let hi = c[i1]; + point.push(lo); + diff.push(hi - lo); + } - // Phase 1: full AIR constraints + // Phase 1: full AIR constraints - // z = 0: full eval, capture post-block state. 
- { - let mut folder = ConstraintFolderPacked::new(&point[..n_flat], &point[n_flat..], extra_data); - folder.cached_state = Some(state_0); - Air::eval(computation, &mut folder, extra_data); - acc[0] += folder.accumulator * partial_eq; - low_evals[0] = folder.accumulator_low; - state_0 = folder.cached_state.unwrap(); - } + // z = 0: full eval, capture post-block state. + { + let mut folder = ConstraintFolderPacked::new(&point[..n_flat], &point[n_flat..], extra_data); + folder.cached_state = Some(std::mem::take(state_0)); + Air::eval(computation, &mut folder, extra_data); + acc[0] += folder.accumulator * partial_eq; + low_evals[0] = folder.accumulator_low; + *state_0 = folder.cached_state.unwrap(); + } - // z = 2: advance `point` by 2·diff, full eval, capture post-block state. - // Together with `state_0` this pins down the linear `state(z)` (linear when we "omit" the low degree constraints of the block) + // z = 2: advance `point` by 2·diff, full eval, capture post-block state. + // Together with `state_0` this pins down the linear `state(z)` (linear when we "omit" the low degree constraints of the block) + for k in 0..n_cols { + point[k] += diff[k].double(); + } + { + let mut folder = ConstraintFolderPacked::new(&point[..n_flat], &point[n_flat..], extra_data); + folder.cached_state = Some(std::mem::take(state_2)); + Air::eval(computation, &mut folder, extra_data); + acc[1] += folder.accumulator * partial_eq; + low_evals[1] = folder.accumulator_low; + *state_2 = folder.cached_state.unwrap(); + } + + // z = 3, …, d_low+1: still doing full eval + for z_idx in 2..n_full { for k in 0..n_cols { - point[k] += diff[k].double(); - } - { - let mut folder = ConstraintFolderPacked::new(&point[..n_flat], &point[n_flat..], extra_data); - folder.cached_state = Some(state_2); - Air::eval(computation, &mut folder, extra_data); - acc[1] += folder.accumulator * partial_eq; - low_evals[1] = folder.accumulator_low; - state_2 = folder.cached_state.unwrap(); + point[k] += diff[k]; } + 
let mut folder = ConstraintFolderPacked::new(&point[..n_flat], &point[n_flat..], extra_data); + Air::eval(computation, &mut folder, extra_data); + acc[z_idx] += folder.accumulator * partial_eq; + low_evals[z_idx] = folder.accumulator_low; + } - // z = 3, …, d_low+1: still doing full eval - for z_idx in 2..n_full { - for k in 0..n_cols { - point[k] += diff[k]; - } - let mut folder = ConstraintFolderPacked::new(&point[..n_flat], &point[n_flat..], extra_data); - Air::eval(computation, &mut folder, extra_data); - acc[z_idx] += folder.accumulator * partial_eq; - low_evals[z_idx] = folder.accumulator_low; + // Phase 2: skip the low degree constraints of the block + // For each skipped point, assemble Constraints(z) = high(z) + low(z): + // -high(z): run folder with `skip_low = true` + // -low(z): deduce it via Lagrange-interpolation from previous computations + for t in 0..n_skip { + for k in 0..n_cols { + point[k] += diff[k]; } - // Phase 2: skip the low degree constraints of the block - // For each skipped point, assemble Constraints(z) = high(z) + low(z): - // -high(z): run folder with `skip_low = true` - // -low(z): deduce it via Lagrange-interpolation from previous computations - for t in 0..n_skip { - for k in 0..n_cols { - point[k] += diff[k]; - } - - cached_buf.clear(); - for i in 0..state_0.len() { - cached_buf - .push(state_0[i] + (state_2[i] - state_0[i]) * PFPacking::::from(hi_zs_halved[t])); - } - - let mut folder = ConstraintFolderPacked::new(&point[..n_flat], &point[n_flat..], extra_data); - folder.skip_low = true; - folder.cached_state = Some(cached_buf); - folder.low_ci_count = low_n_constraints; - Air::eval(computation, &mut folder, extra_data); - cached_buf = folder.cached_state.unwrap(); - - // low(hi_zs[t]) = Σ_i L_i(hi_zs[t]) · low(low_zs[i]) - let mut low_interpolated = EFPacking::::ZERO; - for (i, lc) in lagrange_coeffs[t].iter().enumerate() { - low_interpolated += low_evals[i] * PFPacking::::from(*lc); - } - - acc[n_full + t] += 
(folder.accumulator + low_interpolated) * partial_eq; + cached_buf.clear(); + for i in 0..state_0.len() { + cached_buf.push(state_0[i] + (state_2[i] - state_0[i]) * PFPacking::::from(hi_zs_halved[t])); } - (acc, point, diff, low_evals, state_0, state_2, cached_buf) - }, - ) - .map(|(acc, ..)| acc) - .reduce( - || vec![EFPacking::::ZERO; degree], - |mut a, b| { - for i in 0..degree { - a[i] += b[i]; + let mut folder = ConstraintFolderPacked::new(&point[..n_flat], &point[n_flat..], extra_data); + folder.skip_low = true; + folder.cached_state = Some(std::mem::take(cached_buf)); + folder.low_ci_count = low_n_constraints; + Air::eval(computation, &mut folder, extra_data); + *cached_buf = folder.cached_state.unwrap(); + + // low(hi_zs[t]) = Σ_i L_i(hi_zs[t]) · low(low_zs[i]) + let mut low_interpolated = EFPacking::::ZERO; + for (i, lc) in lagrange_coeffs[t].iter().enumerate() { + low_interpolated += low_evals[i] * PFPacking::::from(*lc); } - a - }, - ); + + acc[n_full + t] += (folder.accumulator + low_interpolated) * partial_eq; + } + }, + |mut a, b| { + for i in 0..degree { + a[i] += b[i]; + } + a + }, + ); acc.into_iter().map(&unpack_sum).collect() } @@ -581,54 +571,43 @@ where let stride = 1usize << fold_bit; let lo_mask = stride - 1; - let acc = (0..active_count_pairs) - .into_par_iter() - .fold( - || { - ( - vec![EFT::ZERO; degree], - Vec::::with_capacity(n_cols), - Vec::::with_capacity(n_cols), - ) - }, - |(mut acc, mut point, mut diff), new_j| { - let i_hi = new_j >> fold_bit; - let i_lo = new_j & lo_mask; - let i0 = (i_hi << (fold_bit + 1)) | i_lo; - let i1 = i0 | stride; - let partial_eq = get_split_eq(new_j); - point.clear(); - diff.clear(); - for c in cols { - let lo = c[i0]; - let hi = c[i1]; - point.push(lo); - diff.push(hi - lo); - } - // z = 0 then (skip z = 1) z = 2, 3, …, degree. 
- acc[0] += eval_fn(computation, &point, extra_data) * partial_eq; + let acc = parallel::map_reduce_with_state( + active_count_pairs, + || (Vec::::with_capacity(n_cols), Vec::::with_capacity(n_cols)), + || vec![EFT::ZERO; degree], + |(point, diff), acc, new_j| { + let i_hi = new_j >> fold_bit; + let i_lo = new_j & lo_mask; + let i0 = (i_hi << (fold_bit + 1)) | i_lo; + let i1 = i0 | stride; + let partial_eq = get_split_eq(new_j); + point.clear(); + diff.clear(); + for c in cols { + let lo = c[i0]; + let hi = c[i1]; + point.push(lo); + diff.push(hi - lo); + } + // z = 0 then (skip z = 1) z = 2, 3, …, degree. + acc[0] += eval_fn(computation, point, extra_data) * partial_eq; + for k in 0..n_cols { + point[k] += diff[k]; + } + for acc_z in &mut acc[1..] { for k in 0..n_cols { point[k] += diff[k]; } - for acc_z in &mut acc[1..] { - for k in 0..n_cols { - point[k] += diff[k]; - } - *acc_z += eval_fn(computation, &point, extra_data) * partial_eq; - } - (acc, point, diff) - }, - ) - .map(|(acc, _, _)| acc) - .reduce( - || vec![EFT::ZERO; degree], - |mut a, b| { - for i in 0..degree { - a[i] += b[i]; - } - a - }, - ); + *acc_z += eval_fn(computation, point, extra_data) * partial_eq; + } + }, + |mut a, b| { + for i in 0..degree { + a[i] += b[i]; + } + a + }, + ); acc.into_iter().map(unpack_sum).collect() } @@ -680,15 +659,15 @@ pub fn prove_batched_air_sumcheck<'a, EF: ExtensionField>>( pub fn compute_shifted_columns(n_shift_columns: usize, columns: &[&[F]]) -> Vec> { // Convention: the first `n_shift_columns` columns are the ones that get shifted. 
- columns[..n_shift_columns] - .par_iter() - .map(|column| { - let mut shifted = unsafe { uninitialized_vec(column.len()) }; - shifted[..column.len() - 1].copy_from_slice(&column[1..]); - shifted[column.len() - 1] = column[column.len() - 1]; - shifted - }) - .collect() + let mut out: Vec> = (0..n_shift_columns).map(|_| Vec::new()).collect(); + parallel::par_chunks_mut(&mut out, 1, |i, slot| { + let column = columns[i]; + let mut shifted = unsafe { uninitialized_vec(column.len()) }; + shifted[..column.len() - 1].copy_from_slice(&column[1..]); + shifted[column.len() - 1] = column[column.len() - 1]; + slot[0] = shifted; + }); + out } pub fn natural_ordering_point_for_session(sumcheck_air_point: &[EF], log_n_rows: usize) -> Vec { diff --git a/crates/sub_protocols/src/logup.rs b/crates/sub_protocols/src/logup.rs index 55af0a320..85578a1d7 100644 --- a/crates/sub_protocols/src/logup.rs +++ b/crates/sub_protocols/src/logup.rs @@ -72,15 +72,13 @@ pub fn prove_generic_logup( }; let fill_num_from = |dst: &mut [F], src: &[F], neg: bool| { - dst.par_chunks_exact_mut(chunk_size) - .enumerate() - .for_each(|(c, dst_chunk)| { - let src_chunk = &src[c * chunk_size..][..chunk_size]; - for (i, slot) in dst_chunk.iter_mut().enumerate() { - let v = src_chunk[i.reverse_bits() >> chunk_shift]; - *slot = if neg { -v } else { v }; - } - }); + parallel::par_chunks_mut(dst, chunk_size, |c, dst_chunk| { + let src_chunk = &src[c * chunk_size..][..chunk_size]; + for (i, slot) in dst_chunk.iter_mut().enumerate() { + let v = src_chunk[i.reverse_bits() >> chunk_shift]; + *slot = if neg { -v } else { v }; + } + }); }; let mut offset = 0; @@ -118,12 +116,14 @@ pub fn prove_generic_logup( ); if 1 << log_bytecode < max_table_height { // padding - numerators[offset + (1 << log_bytecode)..offset + max_table_height] - .par_iter_mut() - .for_each(|n| *n = F::ZERO); - denominators[(offset + (1 << log_bytecode)) / width..(offset + max_table_height) / width] - .par_iter_mut() - .for_each(|d| *d = 
EFPacking::::ONE); + par_fill( + &mut numerators[offset + (1 << log_bytecode)..offset + max_table_height], + |_| F::ZERO, + ); + par_fill( + &mut denominators[(offset + (1 << log_bytecode)) / width..(offset + max_table_height) / width], + |_| EFPacking::::ONE, + ); } offset += max_table_height.max(1 << log_bytecode); @@ -142,17 +142,15 @@ pub fn prove_generic_logup( let col_index = &trace.columns[group.idx_col]; let packed_chunk_size = (1 << log_n_rows) / width; - numerators[offset..][..group_len << log_n_rows] - .par_iter_mut() - .for_each(|n| *n = F::ONE); + par_fill(&mut numerators[offset..][..group_len << log_n_rows], |_| F::ONE); - denominators[offset / width..][..group_len * packed_chunk_size] - .par_chunks_exact_mut(packed_chunk_size) - .enumerate() - .for_each(|(i, denom_chunk)| { + parallel::par_chunks_mut( + &mut denominators[offset / width..][..group_len * packed_chunk_size], + packed_chunk_size, + |i, denom_chunk| { let i_field = F::from_usize(i); let col_value = &trace.columns[group.value_cols[i]]; - denom_chunk.par_iter_mut().enumerate().for_each(|(p, slot)| { + for (p, slot) in denom_chunk.iter_mut().enumerate() { *slot = c_packed - finger_print_packed::( memory_domainsep_packed, @@ -162,8 +160,9 @@ pub fn prove_generic_logup( ], &alphas_packed, ); - }); - }); + } + }, + ); offset += group_len << log_n_rows; bus_idx += group_len; next_group += 1; @@ -175,7 +174,7 @@ pub fn prove_generic_logup( match bus.multiplicity { BusMultiplicity::One => { let val = bus.direction.to_field_flag(); - slice.par_iter_mut().for_each(|n| *n = val); + par_fill(slice, |_| val); } BusMultiplicity::Column(col) => { fill_num_from(slice, &trace.columns[col], matches!(bus.direction, BusDirection::Pull)); @@ -532,5 +531,12 @@ fn fill_denoms(dst: &mut [EFPacking], build: Build) where Build: Fn(usize) -> EFPacking + Sync, { - dst.par_iter_mut().enumerate().for_each(|(p, slot)| *slot = build(p)); + par_fill(dst, build); +} + +/// Fill `dst` in parallel through the in-house pool, 
computing each slot from its +/// global index. Replaces the rayon `par_iter_mut().enumerate()` constant/index fills. +#[inline] +fn par_fill T + Sync>(dst: &mut [T], build: Build) { + parallel::par_for_each_mut(dst, |i, slot| *slot = build(i)); } diff --git a/crates/sub_protocols/src/quotient_gkr/layers.rs b/crates/sub_protocols/src/quotient_gkr/layers.rs index 0ff9e1663..b6c4502e1 100644 --- a/crates/sub_protocols/src/quotient_gkr/layers.rs +++ b/crates/sub_protocols/src/quotient_gkr/layers.rs @@ -111,13 +111,12 @@ pub(super) fn bit_reverse_chunks(v: &[T], chunk_log: usiz return out; } let shift = usize::BITS as usize - chunk_log; - out.par_chunks_exact_mut(chunk_size) - .zip(v.par_chunks_exact(chunk_size)) - .for_each(|(dst, src)| { - for (p, slot) in dst.iter_mut().enumerate() { - *slot = src[p.reverse_bits() >> shift]; - } - }); + parallel::par_chunks_mut(&mut out, chunk_size, |c, dst| { + let src = &v[c * chunk_size..][..chunk_size]; + for (p, slot) in dst.iter_mut().enumerate() { + *slot = src[p.reverse_bits() >> shift]; + } + }); out } @@ -130,18 +129,18 @@ fn sum_quotients_2_by_2>>(nums: &[EF], dens: &[EF]) -> let mut new_nums: Vec = unsafe { uninitialized_vec(new_active) }; let mut new_dens: Vec = unsafe { uninitialized_vec(new_active) }; - new_nums[..full_pairs] - .par_iter_mut() - .zip(new_dens[..full_pairs].par_iter_mut()) - .enumerate() - .for_each(|(i, (num, den))| { + { + let dp = parallel::SendPtr(new_dens.as_mut_ptr()); + parallel::par_for_each_mut(&mut new_nums[..full_pairs], |i, num| { let n0 = nums[2 * i]; let n1 = nums[2 * i + 1]; let d0 = dens[2 * i]; let d1 = dens[2 * i + 1]; *num = d1 * n0 + d0 * n1; - *den = d0 * d1; + // SAFETY: each `i` writes a distinct slot in `new_dens`, a separate buffer. + unsafe { *dp.add(i) = d0 * d1 }; }); + } // Boundary (at most one pair: a/b + 0/1 = a/b). 
if full_pairs < new_active { @@ -172,18 +171,18 @@ where let mut new_nums: Vec> = unsafe { uninitialized_vec(nums.len() >> 1) }; let mut new_dens: Vec> = unsafe { uninitialized_vec(nums.len() >> 1) }; - new_nums - .par_iter_mut() - .zip(new_dens.par_iter_mut()) - .enumerate() - .for_each(|(new_j, (num_out, den_out))| { + { + let dp = parallel::SendPtr(new_dens.as_mut_ptr()); + parallel::par_for_each_mut(&mut new_nums, |new_j, num_out| { let i_hi = new_j >> bit; let i_lo = new_j & lo_mask; let i0 = (i_hi << (bit + 1)) | i_lo; let i1 = i0 | stride; *num_out = dens[i1] * nums[i0] + dens[i0] * nums[i1]; - *den_out = dens[i0] * dens[i1]; + // SAFETY: each `new_j` writes a distinct slot in `new_dens`, a separate buffer. + unsafe { *dp.add(new_j) = dens[i0] * dens[i1] }; }); + } (new_nums, new_dens) } diff --git a/crates/sub_protocols/src/quotient_gkr/mod.rs b/crates/sub_protocols/src/quotient_gkr/mod.rs index 26fa25a65..c9e9554fc 100644 --- a/crates/sub_protocols/src/quotient_gkr/mod.rs +++ b/crates/sub_protocols/src/quotient_gkr/mod.rs @@ -207,7 +207,7 @@ mod tests { type EF = QuinticExtensionFieldKB; fn sum_all_quotients(nums: &[F], den: &[EF]) -> EF { - nums.par_iter().zip(den).map(|(&n, &d)| EF::from(n) / d).sum() + nums.iter().zip(den).map(|(&n, &d)| EF::from(n) / d).sum() } fn bit_reverse_chunks_and_pack_ext>>(v: &[EF], chunk_log: usize) -> Vec> { diff --git a/crates/sub_protocols/src/quotient_gkr/sumcheck_utils.rs b/crates/sub_protocols/src/quotient_gkr/sumcheck_utils.rs index 42f2c6a42..0f17c4035 100644 --- a/crates/sub_protocols/src/quotient_gkr/sumcheck_utils.rs +++ b/crates/sub_protocols/src/quotient_gkr/sumcheck_utils.rs @@ -256,8 +256,8 @@ pub(super) fn run_phase1_sumcheck<'a, EF: ExtensionField>>( if let Some(prev_r) = pending_r { let prev_bit = layer_chunk_log - 1 - w; let mul = |x: EFPacking, a: EF| x * a; - nums = Cow::Owned(fold_multilinear_at_bit(nums.as_ref(), prev_r, prev_bit, &mul)); - dens = Cow::Owned(fold_multilinear_at_bit(dens.as_ref(), prev_r, 
prev_bit, &mul)); + nums = Cow::Owned(fold_multilinear_at_bit(nums.as_ref(), prev_r, prev_bit, &mul, false)); + dens = Cow::Owned(fold_multilinear_at_bit(dens.as_ref(), prev_r, prev_bit, &mul, false)); } let nums_nat = unpack_and_unreverse_active::(nums.as_ref(), layer_chunk_log); @@ -329,10 +329,7 @@ pub(super) fn run_phase2_sumcheck>>( }; let acc: RoundCoeffs = if active_pairs > PARALLEL_THRESHOLD { - (0..active_pairs) - .into_par_iter() - .map(term) - .reduce(RoundCoeffs::zero, Add::add) + parallel::map_reduce(active_pairs, RoundCoeffs::zero, term, Add::add) } else { (0..active_pairs).map(term).fold(RoundCoeffs::::zero(), Add::add) }; @@ -363,7 +360,9 @@ pub(super) fn run_phase2_sumcheck>>( if new_eq_len > 0 { let fold_eq = |i: usize| eq_table[2 * i] + eq_table[2 * i + 1]; eq_table = if new_eq_len >= PARALLEL_THRESHOLD { - (0..new_eq_len).into_par_iter().map(fold_eq).collect() + let mut out: Vec = unsafe { uninitialized_vec(new_eq_len) }; + parallel::par_for_each_mut(&mut out, |i, slot| *slot = fold_eq(i)); + out } else { (0..new_eq_len).map(fold_eq).collect() }; @@ -393,10 +392,7 @@ fn fold_normal_with_padding>>(m: &[EF], r: EF, pad_val if new_active < PARALLEL_THRESHOLD { out.iter_mut().enumerate().for_each(compute); } else { - out.par_iter_mut() - .with_min_len(PARALLEL_THRESHOLD) - .enumerate() - .for_each(compute); + parallel::par_for_each_mut(&mut out, |i, slot| compute((i, slot))); } out } @@ -421,10 +417,13 @@ where debug_assert_eq!(dens.len(), nums.len()); debug_assert_eq!(eq_within.len(), quarter); - nums.par_chunks_exact(layer_packed) - .zip(dens.par_chunks_exact(layer_packed)) - .enumerate() - .fold(RoundCoeffs::zero, |mut acc, (c, (n_c, d_c))| { + let n_chunks = nums.len() / layer_packed; + parallel::map_reduce( + n_chunks, + RoundCoeffs::zero, + |c| { + let n_c = &nums[c * layer_packed..][..layer_packed]; + let d_c = &dens[c * layer_packed..][..layer_packed]; let eq_o: EF = eq_outer.get(c).copied().unwrap_or(EF::ONE); let mut local = 
RoundCoeffs::>::zero(); for inner in 0..quarter { @@ -436,10 +435,10 @@ where ); local += coeffs * eq_within[inner]; } - acc += local * eq_o; - acc - }) - .reduce(RoundCoeffs::zero, Add::add) + local * eq_o + }, + Add::add, + ) } #[allow(clippy::type_complexity)] @@ -473,13 +472,19 @@ where let mut new_dens: Vec> = unsafe { uninitialized_vec(active_out_packed) }; let prev_r_packed: EFPacking = as From>::from(prev_r); - let coeffs = nums - .par_chunks_exact(in_packed) - .zip(dens.par_chunks_exact(in_packed)) - .zip(new_nums.par_chunks_exact_mut(out_packed)) - .zip(new_dens.par_chunks_exact_mut(out_packed)) - .enumerate() - .fold(RoundCoeffs::zero, |mut acc, (c, (((n_c, d_c), nn_c), nd_c))| { + let n_chunks = nums.len() / in_packed; + let nn = parallel::SendPtr(new_nums.as_mut_ptr()); + let nd = parallel::SendPtr(new_dens.as_mut_ptr()); + let coeffs = parallel::map_reduce( + n_chunks, + RoundCoeffs::zero, + |c| { + let n_c = &nums[c * in_packed..][..in_packed]; + let d_c = &dens[c * in_packed..][..in_packed]; + // SAFETY: chunk `c` owns the disjoint `out_packed`-sized regions of the two + // output buffers at `c * out_packed`; no other task touches them. + let nn_c = unsafe { std::slice::from_raw_parts_mut(nn.add(c * out_packed), out_packed) }; + let nd_c = unsafe { std::slice::from_raw_parts_mut(nd.add(c * out_packed), out_packed) }; let eq_o: EF = eq_outer.get(c).copied().unwrap_or(EF::ONE); let mut local = RoundCoeffs::>::zero(); for i in 0..in_eighth { @@ -500,10 +505,10 @@ where ); local += round * eq_within[i]; } - acc += local * eq_o; - acc - }) - .reduce(RoundCoeffs::zero, Add::add); + local * eq_o + }, + Add::add, + ); (new_nums, new_dens, coeffs) } diff --git a/crates/utils/src/misc.rs b/crates/utils/src/misc.rs index ff317ae4c..09bad5162 100644 --- a/crates/utils/src/misc.rs +++ b/crates/utils/src/misc.rs @@ -7,15 +7,23 @@ pub fn from_end(slice: &[A], n: usize) -> &[A] { &slice[slice.len() - n..] 
} -pub fn transposed_par_iter_mut( - array: &mut [Vec; N], // all vectors must have the same length -) -> impl IndexedParallelIterator + '_ { +/// Run `g(i, row)` in parallel over `i in 0..len`, where `row` is `[&mut A; N]` holding +/// the `i`-th element of each of the `N` equal-length vectors (a transposed row). +/// Dispatched through the in-house [`parallel`] pool. +pub fn transposed_par_for_each_mut(array: &mut [Vec; N], g: G) +where + G: Fn(usize, [&mut A; N]) + Sync, +{ + // all vectors must have the same length let len = array[0].len(); let data_ptrs: [AtomicPtr; N] = array.each_mut().map(|v| AtomicPtr::new(v.as_mut_ptr())); - (0..len) - .into_par_iter() - .map(move |i| unsafe { std::array::from_fn(|j| &mut *data_ptrs[j].load(Ordering::Relaxed).add(i)) }) + parallel::for_each_index(len, |i| { + // SAFETY: distinct `i` access disjoint row `i` of each of the `N` vectors, and the + // arrays outlive the dispatch (the dispatcher blocks until all tasks complete). + let row: [&mut A; N] = unsafe { std::array::from_fn(|j| &mut *data_ptrs[j].load(Ordering::Relaxed).add(i)) }; + g(i, row); + }); } pub fn collect_refs(vecs: &[Vec]) -> Vec<&[T]> { diff --git a/crates/utils/src/multilinear.rs b/crates/utils/src/multilinear.rs index 2030ca5f4..0147d28a0 100644 --- a/crates/utils/src/multilinear.rs +++ b/crates/utils/src/multilinear.rs @@ -13,10 +13,12 @@ pub fn multilinears_linear_combination, P: Borro assert_eq!(pols.len(), scalars.len()); let n_vars = log2_strict_usize(pols[0].borrow().len()); assert!(pols.iter().all(|p| log2_strict_usize(p.borrow().len()) == n_vars)); - (0..1 << n_vars) - .into_par_iter() - .map(|i| dot_product(scalars.iter().copied(), pols.iter().map(|p| p.borrow()[i]))) - .collect::>() + let n = 1usize << n_vars; + let mut out: Vec = unsafe { uninitialized_vec(n) }; + parallel::par_for_each_mut(&mut out, |i, slot| { + *slot = dot_product(scalars.iter().copied(), pols.iter().map(|p| p.borrow()[i])); + }); + out } pub fn 
multilinear_eval_constants_at_right(limit: usize, point: &[F]) -> F { diff --git a/crates/whir/Cargo.toml b/crates/whir/Cargo.toml index 1c2a2b0a7..6845c34b4 100644 --- a/crates/whir/Cargo.toml +++ b/crates/whir/Cargo.toml @@ -14,7 +14,7 @@ symetric = { path = "../backend/symetric", package = "mt-symetric" } system-info.workspace = true itertools.workspace = true -rayon.workspace = true +parallel.workspace = true rand.workspace = true tracing.workspace = true diff --git a/crates/whir/src/dft.rs b/crates/whir/src/dft.rs index abe1d5a9a..91296c312 100644 --- a/crates/whir/src/dft.rs +++ b/crates/whir/src/dft.rs @@ -29,7 +29,6 @@ use field::PackedValue; use field::{BasedVectorSpace, Field, PackedField, TwoAdicField}; use itertools::Itertools; -use rayon::prelude::*; use tracing::instrument; use utils::{as_base_slice, log2_strict_usize}; @@ -164,7 +163,8 @@ where /// also divide by the height. #[inline] fn par_initial_layers(mat: &mut [F], chunk_size: usize, root_table: &[Vec], width: usize) { - mat.par_chunks_exact_mut(chunk_size).for_each(|chunk| { + let n_full = mat.len() / chunk_size * chunk_size; + parallel::par_chunks_mut(&mut mat[..n_full], chunk_size, |_, chunk| { initial_layers(chunk, root_table, width); }); } @@ -197,14 +197,21 @@ fn dft_layer>(vec: &mut [F], twiddles: &[B], width: us #[inline] fn dft_layer_par>(vec: &mut [F], twiddles: &[B], width: usize) { - vec.par_chunks_exact_mut(twiddles.len() * 2 * width).for_each(|block| { - let (left, right) = block.split_at_mut(twiddles.len() * width); - left.par_chunks_exact_mut(width) - .zip(right.par_chunks_exact_mut(width)) - .zip(twiddles.par_iter()) - .for_each(|((hi_chunk, lo_chunk), twiddle)| { - twiddle.apply_to_rows(hi_chunk, lo_chunk); - }); + let ts = twiddles.len(); + let block_size = 2 * ts * width; + debug_assert!(vec.len().is_multiple_of(block_size),); + let n_blocks = vec.len() / block_size; + // Flatten (block, group) into one parallel loop over `n_blocks * ts` groups so coarse + // layers (few 
blocks) still parallelize; guided scheduling keeps a worker's batch of + // consecutive groups within the same block, preserving the per-block cache locality. + let base = parallel::SendPtr(vec.as_mut_ptr()); + parallel::for_each_index(n_blocks * ts, |g| { + let block_base = (g / ts) * block_size; + let ind = g % ts; + // SAFETY: distinct `g` map to disjoint (hi, lo) `width`-rows. + let hi = unsafe { base.slice(block_base + ind * width, width) }; + let lo = unsafe { base.slice(block_base + (ts + ind) * width, width) }; + twiddles[ind].apply_to_rows(hi, lo); }); } @@ -234,40 +241,25 @@ fn dft_layer_par_double, M: MultiLayerButterfly> assert_eq!(twiddles_large.len(), twiddles_small.len() * 2); - // TODO optimal workload size with L1 cache - mat.values - .par_chunks_exact_mut(twiddles_large.len() * 2 * width) - .for_each(|block| { - // (0..twiddles_small.len()).into_par_iter().for_each(|ind| { - // let hi_hi = slice_ref_mut(block, ind * width, width); - // let hi_lo = slice_ref_mut(block, (ind + twiddles_small.len()) * width, width); - // let lo_hi = slice_ref_mut(block, (ind + 2 * twiddles_small.len()) * width, width); - // let lo_lo = slice_ref_mut(block, (ind + 3 * twiddles_small.len()) * width, width); - // multi_butterfly.apply_2_layers( - // ((hi_hi, hi_lo), (lo_hi, lo_lo)), - // ind, - // twiddles_small, - // twiddles_large, - // ); - // }); - let (hi_blocks, lo_blocks) = block.split_at_mut(twiddles_small.len() * width * 2); - let (hi_hi_blocks, hi_lo_blocks) = hi_blocks.split_at_mut(twiddles_small.len() * width); - let (lo_hi_blocks, lo_lo_blocks) = lo_blocks.split_at_mut(twiddles_small.len() * width); - hi_hi_blocks - .par_chunks_exact_mut(width) - .zip(hi_lo_blocks.par_chunks_exact_mut(width)) - .zip(lo_hi_blocks.par_chunks_exact_mut(width)) - .zip(lo_lo_blocks.par_chunks_exact_mut(width)) - .enumerate() - .for_each(|(ind, (((hi_hi, hi_lo), lo_hi), lo_lo))| { - multi_butterfly.apply_2_layers( - ((hi_hi, hi_lo), (lo_hi, lo_lo)), - ind, - twiddles_small, - 
twiddles_large, - ); - }); - }); + // Flatten (block, inner-group) into one parallel loop. A block is `4·ts` rows of + // `width`; group `ind` touches the 4 rows at sub-block offsets `k·ts + ind` (k=0..3). + // Coarse layers (few blocks) thus still parallelize over their `ts` inner groups, and + // guided scheduling keeps a worker's consecutive groups within one block (cache-local). + let ts = twiddles_small.len(); + let block_size = 4 * ts * width; // == twiddles_large.len() * 2 * width + let n_blocks = mat.values.len() / block_size; + let base = parallel::SendPtr(mat.values.as_mut_ptr()); + parallel::for_each_index(n_blocks * ts, |g| { + let block_base = (g / ts) * block_size; + let ind = g % ts; + let row = |k: usize| block_base + (k * ts + ind) * width; + // SAFETY: distinct `g` map to disjoint sets of 4 `width`-rows. + let hi_hi = unsafe { base.slice(row(0), width) }; + let hi_lo = unsafe { base.slice(row(1), width) }; + let lo_hi = unsafe { base.slice(row(2), width) }; + let lo_lo = unsafe { base.slice(row(3), width) }; + multi_butterfly.apply_2_layers(((hi_hi, hi_lo), (lo_hi, lo_lo)), ind, twiddles_small, twiddles_large); + }); } /// Applies three layers of a Radix-2 FFT butterfly network making use of parallelization. 
@@ -303,44 +295,38 @@ fn dft_layer_par_triple, M: MultiLayerButterfly> // let inner_chunk_size = // (workload_size::().next_power_of_two() / 8).min(eighth_outer_block_size); - mat.values - .par_chunks_exact_mut(twiddles_large.len() * 2 * width) - .for_each(|block| { - let (hi_blocks, lo_blocks) = block.split_at_mut(twiddles_small.len() * width * 4); - let (hi_hi_blocks, hi_lo_blocks) = hi_blocks.split_at_mut(twiddles_small.len() * width * 2); - let (lo_hi_blocks, lo_lo_blocks) = lo_blocks.split_at_mut(twiddles_small.len() * width * 2); - let (hi_hi_hi_blocks, hi_hi_lo_blocks) = hi_hi_blocks.split_at_mut(twiddles_small.len() * width); - let (hi_lo_hi_blocks, hi_lo_lo_blocks) = hi_lo_blocks.split_at_mut(twiddles_small.len() * width); - let (lo_hi_hi_blocks, lo_hi_lo_blocks) = lo_hi_blocks.split_at_mut(twiddles_small.len() * width); - let (lo_lo_hi_blocks, lo_lo_lo_blocks) = lo_lo_blocks.split_at_mut(twiddles_small.len() * width); - hi_hi_hi_blocks - .par_chunks_exact_mut(width) - .zip(hi_hi_lo_blocks.par_chunks_exact_mut(width)) - .zip(hi_lo_hi_blocks.par_chunks_exact_mut(width)) - .zip(hi_lo_lo_blocks.par_chunks_exact_mut(width)) - .zip(lo_hi_hi_blocks.par_chunks_exact_mut(width)) - .zip(lo_hi_lo_blocks.par_chunks_exact_mut(width)) - .zip(lo_lo_hi_blocks.par_chunks_exact_mut(width)) - .zip(lo_lo_lo_blocks.par_chunks_exact_mut(width)) - .enumerate() - .for_each( - |( - ind, - (((((((hi_hi_hi, hi_hi_lo), hi_lo_hi), hi_lo_lo), lo_hi_hi), lo_hi_lo), lo_lo_hi), lo_lo_lo), - )| { - multi_butterfly.apply_3_layers( - ( - ((hi_hi_hi, hi_hi_lo), (hi_lo_hi, hi_lo_lo)), - ((lo_hi_hi, lo_hi_lo), (lo_lo_hi, lo_lo_lo)), - ), - ind, - twiddles_small, - twiddles_med, - twiddles_large, - ); - }, - ); - }); + // Flatten (block, inner-group) into one parallel loop. A block is `8·ts` rows of + // `width`; group `ind` touches the 8 rows at sub-block offsets `k·ts + ind` (k=0..7). 
+ // Coarse layers still parallelize over their `ts` inner groups; guided scheduling keeps + // a worker's consecutive groups within one block (cache-local). + let ts = twiddles_small.len(); + let block_size = 8 * ts * width; // == twiddles_large.len() * 2 * width + let n_blocks = mat.values.len() / block_size; + let base = parallel::SendPtr(mat.values.as_mut_ptr()); + parallel::for_each_index(n_blocks * ts, |g| { + let block_base = (g / ts) * block_size; + let ind = g % ts; + let row = |k: usize| block_base + (k * ts + ind) * width; + // SAFETY: distinct `g` map to disjoint sets of 8 `width`-rows. + let hi_hi_hi = unsafe { base.slice(row(0), width) }; + let hi_hi_lo = unsafe { base.slice(row(1), width) }; + let hi_lo_hi = unsafe { base.slice(row(2), width) }; + let hi_lo_lo = unsafe { base.slice(row(3), width) }; + let lo_hi_hi = unsafe { base.slice(row(4), width) }; + let lo_hi_lo = unsafe { base.slice(row(5), width) }; + let lo_lo_hi = unsafe { base.slice(row(6), width) }; + let lo_lo_lo = unsafe { base.slice(row(7), width) }; + multi_butterfly.apply_3_layers( + ( + ((hi_hi_hi, hi_hi_lo), (hi_lo_hi, hi_lo_lo)), + ((lo_hi_hi, lo_hi_lo), (lo_lo_hi, lo_lo_lo)), + ), + ind, + twiddles_small, + twiddles_med, + twiddles_large, + ); + }); } /// Applies the remaining layers of the Radix-2 FFT butterfly network in parallel. 
diff --git a/crates/whir/src/merkle.rs b/crates/whir/src/merkle.rs index 4c61782eb..e6c6f5a79 100644 --- a/crates/whir/src/merkle.rs +++ b/crates/whir/src/merkle.rs @@ -12,7 +12,6 @@ use field::PrimeCharacteristicRing; use koala_bear::{KoalaBear, QuinticExtensionFieldKB, default_koalabear_poseidon1_16}; use poly::*; -use rayon::prelude::*; use symetric::merkle::unpack_array; use tracing::instrument; use utils::log2_ceil_usize; @@ -194,22 +193,20 @@ where let mut digests = unsafe { uninitialized_vec(height) }; - digests - .par_chunks_exact_mut(width) - .enumerate() - .for_each(|(i, digests_chunk)| { - let first_row = i * width; - let rtl_iter = matrix.vertically_packed_row_rtl::

(first_row, effective_base_width, n_pad); - let packed_digest: [P; DIGEST_ELEMS] = - symetric::hash_rtl_iter_with_initial_state::<_, _, _, WIDTH, RATE, DIGEST_ELEMS>( - perm, - rtl_iter, - packed_initial_state, - ); - for (dst, src) in digests_chunk.iter_mut().zip(unpack_array(packed_digest)) { - *dst = src; - } - }); + // `height` is a multiple of `width`, so every chunk is exactly `width` long. + parallel::par_chunks_mut(&mut digests, width, |i, digests_chunk| { + let first_row = i * width; + let rtl_iter = matrix.vertically_packed_row_rtl::

(first_row, effective_base_width, n_pad); + let packed_digest: [P; DIGEST_ELEMS] = + symetric::hash_rtl_iter_with_initial_state::<_, _, _, WIDTH, RATE, DIGEST_ELEMS>( + perm, + rtl_iter, + packed_initial_state, + ); + for (dst, src) in digests_chunk.iter_mut().zip(unpack_array(packed_digest)) { + *dst = src; + } + }); digests } diff --git a/crates/whir/src/open.rs b/crates/whir/src/open.rs index 6636b77c7..6b737b9c4 100644 --- a/crates/whir/src/open.rs +++ b/crates/whir/src/open.rs @@ -5,7 +5,6 @@ use fiat_shamir::{FSProver, MerklePath, ProofResult}; use field::PrimeCharacteristicRing; use field::{ExtensionField, Field, TwoAdicField}; use poly::*; -use rayon::prelude::*; use sumcheck::{ProductComputation, run_product_sumcheck, sumcheck_prove_many_rounds}; use tracing::{info_span, instrument}; @@ -594,17 +593,15 @@ where for (e, &scalar) in smt.values.iter().zip(&next_gamma_powers) { combined_sum += e.value * scalar; } - chunks_mut - .into_par_iter() - .zip(&indexed_smt_values) - .for_each(|(out_buff, &(origin_index, _))| { - out_buff[..1 << shift] - .par_iter_mut() - .zip(&inner_poly) - .for_each(|(out_elem, &poly_elem)| { - *out_elem += poly_elem * next_gamma_powers[origin_index]; - }); + // Few sparse statements (the outer chunks) but each inner accumulation can be + // large, so parallelize the inner loop per statement (the outer runs serial). 
+ for (out_buff, &(origin_index, _)) in chunks_mut.iter_mut().zip(&indexed_smt_values) { + let out = &mut out_buff[..1 << shift]; + let scalar = next_gamma_powers[origin_index]; + parallel::par_for_each_mut(out, |i, out_elem| { + *out_elem += inner_poly[i] * scalar; }); + } gamma_pow = *next_gamma_powers.last().unwrap() * gamma; } } diff --git a/crates/whir/src/utils.rs b/crates/whir/src/utils.rs index c59bed968..9a8ec359e 100644 --- a/crates/whir/src/utils.rs +++ b/crates/whir/src/utils.rs @@ -6,7 +6,6 @@ use field::Field; use field::PackedValue; use field::{ExtensionField, TwoAdicField}; use poly::*; -use rayon::prelude::*; use std::any::{Any, TypeId}; use std::collections::HashMap; use std::sync::{Arc, Mutex, OnceLock}; @@ -136,15 +135,28 @@ fn prepare_evals_for_fft_unpacked( let log_block_size = log2_strict_usize(block_size); let out_len = block_size * dft_n_cols; - (0..out_len) - .into_par_iter() - .map(|i| { - let block_index = i % dft_n_cols; - let offset_in_block = i / dft_n_cols; - let src_index = ((block_index << log_block_size) + offset_in_block) >> log_inv_rate; - unsafe { *evals.get_unchecked(src_index) } - }) - .collect() + let mut out: Vec = unsafe { uninitialized_vec(out_len) }; + if block_size == 0 || dft_n_cols == 0 { + return out; + } + + let rows_per_band = ((system_info::L1_CACHE_SIZE / 2) / (dft_n_cols * size_of::())).clamp(1, block_size); + let band_len = rows_per_band * dft_n_cols; + + parallel::par_chunks_mut(&mut out, band_len, |band_idx, band| { + let row0 = band_idx * rows_per_band; + let n_rows = band.len() / dft_n_cols; + for col in 0..dft_n_cols { + let col_base = col << log_block_size; + for r in 0..n_rows { + let src = (col_base + row0 + r) >> log_inv_rate; + unsafe { + *band.get_unchecked_mut(r * dft_n_cols + col) = *evals.get_unchecked(src); + } + } + } + }); + out } fn prepare_evals_for_fft_packed_extension>>( @@ -158,25 +170,38 @@ fn prepare_evals_for_fft_packed_extension>>( let full_len = evals.len() << (log_inv_rate + 
log_packing); let block_size = full_len / n_blocks; let log_block_size = log2_strict_usize(block_size); - let n_blocks_mask = n_blocks - 1; let packing_mask = (1 << log_packing) - 1; - (0..full_len) - .into_par_iter() - .map(|i| { - let block_index = i & n_blocks_mask; - let offset_in_block = i >> folding_factor; - let src_index = ((block_index << log_block_size) + offset_in_block) >> log_inv_rate; - let packed_src_index = src_index >> log_packing; - let offset_in_packing = src_index & packing_mask; - let packed = unsafe { evals.get_unchecked(packed_src_index) }; - let unpacked: &[PFPacking] = packed.as_basis_coefficients_slice(); - EF::from_basis_coefficients_fn(|i| unsafe { - let u: &PFPacking = unpacked.get_unchecked(i); - *u.as_slice().get_unchecked(offset_in_packing) - }) - }) - .collect() + let mut out: Vec = unsafe { uninitialized_vec(full_len) }; + if block_size == 0 || n_blocks == 0 { + return out; + } + + let rows_per_band = ((system_info::L1_CACHE_SIZE / 2) / (n_blocks * size_of::())).clamp(1, block_size); + let band_len = rows_per_band * n_blocks; + + parallel::par_chunks_mut(&mut out, band_len, |band_idx, band| { + let row0 = band_idx * rows_per_band; + let n_rows = band.len() / n_blocks; + for col in 0..n_blocks { + let col_base = col << log_block_size; + for r in 0..n_rows { + let src_index = (col_base + row0 + r) >> log_inv_rate; + let packed_src_index = src_index >> log_packing; + let offset_in_packing = src_index & packing_mask; + let packed = unsafe { evals.get_unchecked(packed_src_index) }; + let unpacked: &[PFPacking] = packed.as_basis_coefficients_slice(); + let val = EF::from_basis_coefficients_fn(|j| unsafe { + let u: &PFPacking = unpacked.get_unchecked(j); + *u.as_slice().get_unchecked(offset_in_packing) + }); + unsafe { + *band.get_unchecked_mut(r * n_blocks + col) = val; + } + } + } + }); + out } type CacheKey = TypeId; diff --git a/crates/xmss/src/signers_cache.rs b/crates/xmss/src/signers_cache.rs index 6e7a9956e..b843b9c28 100644 --- 
a/crates/xmss/src/signers_cache.rs +++ b/crates/xmss/src/signers_cache.rs @@ -65,7 +65,7 @@ fn compute_signer(index: usize) -> (XmssPublicKey, XmssSignature) { let mut rng = StdRng::seed_from_u64(index as u64); let key_start = BENCHMARK_SLOT; let key_end = BENCHMARK_SLOT + 1; - let (sk, pk) = xmss_key_gen(rng.random(), key_start, key_end).unwrap(); + let (sk, pk) = xmss_key_gen(rng.random(), key_start, key_end, true).unwrap(); let sig = xmss_sign(&mut rng, &sk, &message_for_benchmark(), BENCHMARK_SLOT).unwrap(); (pk, sig) } @@ -89,18 +89,16 @@ fn gen_benchmark_signers_cache() -> Vec<(XmssPublicKey, XmssSignature)> { let completed = AtomicUsize::new(1); let time = Instant::now(); - let rest: Vec<_> = (1..NUM_BENCHMARK_SIGNERS) - .into_par_iter() - .map(|index| { - let signer = compute_signer(index); - let done = completed.fetch_add(1, Ordering::Relaxed) + 1; - print!( - "\rPrecomputing benchmark signatures (cached after first run): {:.0}%", - 100.0 * done as f64 / NUM_BENCHMARK_SIGNERS as f64 - ); - signer - }) - .collect(); + let n_rest = NUM_BENCHMARK_SIGNERS - 1; + let rest = parallel::par_map_collect(n_rest, |i| { + let signer = compute_signer(1 + i); + let done = completed.fetch_add(1, Ordering::Relaxed) + 1; + print!( + "\rPrecomputing benchmark signatures (cached after first run): {:.0}%", + 100.0 * done as f64 / NUM_BENCHMARK_SIGNERS as f64 + ); + signer + }); println!( "\rGenerating signatures for benchmark (one-time operation): 100% - done ({:.2}s)", @@ -128,7 +126,8 @@ fn gen_benchmark_signers_cache() -> Vec<(XmssPublicKey, XmssSignature)> { #[test] fn test_signature_cache() { let signatures = get_benchmark_signatures(); - signatures.par_iter().enumerate().for_each(|(i, (pk, sig))| { + parallel::for_each_index(signatures.len(), |i| { + let (pk, sig) = &signatures[i]; xmss_verify(pk, &message_for_benchmark(), sig, BENCHMARK_SLOT) .unwrap_or_else(|_| panic!("Signature {} failed to verify", i)); }); diff --git a/crates/xmss/src/xmss.rs 
b/crates/xmss/src/xmss.rs index d5f69f445..9f77837af 100644 --- a/crates/xmss/src/xmss.rs +++ b/crates/xmss/src/xmss.rs @@ -74,23 +74,32 @@ pub enum XmssKeyGenError { InvalidRange, } +fn fill(sequential: bool, data: &mut [T], f: impl Fn(usize, &mut T) + Sync) { + if sequential { + data.iter_mut().enumerate().for_each(|(i, out)| f(i, out)); + } else { + parallel::par_for_each_mut(data, f); + } +} + pub fn xmss_key_gen( seed: [u8; 32], slot_start: u32, slot_end: u32, + sequential: bool, ) -> Result<(XmssSecretKey, XmssPublicKey), XmssKeyGenError> { if slot_start > slot_end || slot_end as u64 >= (1 << LOG_LIFETIME) { return Err(XmssKeyGenError::InvalidRange); } let public_param: PublicParam = gen_public_param(&seed); // Level 0: WOTS leaf hashes for slots in [slot_start, slot_end] - let leaves: Vec = (slot_start..=slot_end) - .into_par_iter() - .map(|slot| { - let wots = gen_wots_secret_key(&seed, slot, public_param); - wots.public_key().hash(public_param, slot) - }) - .collect(); + let n_leaves = (slot_end - slot_start + 1) as usize; + let mut leaves: Vec = unsafe { uninitialized_vec(n_leaves) }; + fill(sequential, &mut leaves, |i, out| { + let slot = slot_start + i as u32; + let wots = gen_wots_secret_key(&seed, slot, public_param); + *out = wots.public_key().hash(public_param, slot); + }); let mut merkle_tree = vec![leaves]; // Build levels 1..=LOG_LIFETIME. // At level l, we store nodes with index in [(slot_start >> l), (slot_end >> l)]. 
@@ -102,30 +111,31 @@ pub fn xmss_key_gen( let prev_top: u64 = (slot_end as u64) >> (level - 1); let nodes: Vec = { let prev = &merkle_tree[level - 1]; - (base..=top) - .into_par_iter() - .map(|i| { - let left_idx = 2 * i; - let right_idx = 2 * i + 1; - let left = if left_idx >= prev_base && left_idx <= prev_top { - prev[(left_idx - prev_base) as usize] - } else { - gen_random_node(&seed, level - 1, left_idx) - }; - let right = if right_idx >= prev_base && right_idx <= prev_top { - prev[(right_idx - prev_base) as usize] - } else { - gen_random_node(&seed, level - 1, right_idx) - }; - let merkle_data = build_merkle_data( - make_tweak(TWEAK_TYPE_MERKLE, level, i as u32), - &public_param, - &left, - &right, - ); - poseidon16_compress(merkle_data)[..XMSS_DIGEST_LEN].try_into().unwrap() - }) - .collect() + let n_nodes = (top - base + 1) as usize; + let mut nodes: Vec = unsafe { uninitialized_vec(n_nodes) }; + fill(sequential, &mut nodes, |k, out| { + let i = base + k as u64; + let left_idx = 2 * i; + let right_idx = 2 * i + 1; + let left = if left_idx >= prev_base && left_idx <= prev_top { + prev[(left_idx - prev_base) as usize] + } else { + gen_random_node(&seed, level - 1, left_idx) + }; + let right = if right_idx >= prev_base && right_idx <= prev_top { + prev[(right_idx - prev_base) as usize] + } else { + gen_random_node(&seed, level - 1, right_idx) + }; + let merkle_data = build_merkle_data( + make_tweak(TWEAK_TYPE_MERKLE, level, i as u32), + &public_param, + &left, + &right, + ); + *out = poseidon16_compress(merkle_data)[..XMSS_DIGEST_LEN].try_into().unwrap(); + }); + nodes }; merkle_tree.push(nodes); } diff --git a/crates/xmss/tests/xmss_tests.rs b/crates/xmss/tests/xmss_tests.rs index 0fb08e01d..6d2721df0 100644 --- a/crates/xmss/tests/xmss_tests.rs +++ b/crates/xmss/tests/xmss_tests.rs @@ -12,7 +12,7 @@ fn test_xmss_serialize_deserialize() { let slot_end = 115; let slot = 110; - let (sk, pk) = xmss_key_gen(keygen_seed, slot_start, slot_end).unwrap(); + let (sk, 
pk) = xmss_key_gen(keygen_seed, slot_start, slot_end, false).unwrap(); let sig = xmss_sign(&mut StdRng::seed_from_u64(slot as u64), &sk, &message, slot).unwrap(); let pk_bytes = postcard::to_allocvec(&pk).unwrap(); @@ -32,7 +32,7 @@ fn keygen_sign_verify() { let message: [F; MESSAGE_LEN_FE] = std::array::from_fn(|i| F::from_usize(i * 3 + 7)); for slot in [0, 1234, u32::MAX] { - let (sk, pk) = xmss_key_gen(keygen_seed, slot.saturating_sub(1), slot.saturating_add(2)).unwrap(); + let (sk, pk) = xmss_key_gen(keygen_seed, slot.saturating_sub(1), slot.saturating_add(2), false).unwrap(); let sig = xmss_sign(&mut StdRng::seed_from_u64(slot as u64), &sk, &message, slot).unwrap(); xmss_verify(&pk, &message, &sig, slot).unwrap(); } @@ -46,17 +46,19 @@ fn encoding_grinding_bits() { merkle_root: Default::default(), public_param: Default::default(), }; - let total_iters = (0..n) - .into_par_iter() - .map(|i| { + let total_iters = parallel::map_reduce( + n, + || 0usize, + |i| { let message: [F; MESSAGE_LEN_FE] = Default::default(); let slot = i as u32; let mut rng = StdRng::seed_from_u64(i as u64); let (_randomness, _encoding, num_iters) = find_randomness_for_wots_encoding(&message, slot, &xmss_pub_key, &mut rng); num_iters - }) - .sum::(); + }, + |a, b| a + b, + ); let grinding = ((total_iters as f64) / (n as f64)).log2(); println!("Average grinding bits: {:.1}", grinding); } diff --git a/src/lib.rs b/src/lib.rs index 577853996..eea73127b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -11,8 +11,10 @@ pub use xmss::{MESSAGE_LEN_FE, XmssPublicKey, XmssSecretKey, XmssSignature, xmss pub type F = KoalaBear; -/// Call once before proving. Compiles the aggregation program and precomputes DFT twiddles. +/// Call once before proving. Tunes the process memory policy (see [`tune_allocator`]), +/// compiles the aggregation program, and precomputes DFT twiddles. 
pub fn setup_prover() { + parallel::init(); rec_aggregation::init_aggregation_bytecode(); precompute_dft_twiddles::(1 << 24); } @@ -24,13 +26,10 @@ pub fn setup_verifier() { /// Bump-arena allocator. /// -/// **Optional.** -/// -/// To enable, set it as the `#[global_allocator]` in your binary and call -/// [`init_allocator`] once at startup. Then bracket each proving call with -/// [`begin_phase`] / [`end_phase`] and **clone the outputs after -/// [`end_phase`]** so the cloned copy lands in the system allocator before the -/// next [`begin_phase`] resets the arena slabs. +/// To enable, set it as the `#[global_allocator]` in your binary. Then bracket each proving +/// call with [`begin_phase`] / [`end_phase`] and **clone the outputs after [`end_phase`]** so +/// the cloned copy lands in the system allocator before the next [`begin_phase`] resets the +/// arena slabs. /// /// See `tests/test_zk_alloc.rs` for a runnable end-to-end example. -pub use zk_alloc::{ZkAllocator, begin_phase, end_phase, init as init_allocator}; +pub use zk_alloc::{ZkAllocator, begin_phase, end_phase}; diff --git a/src/main.rs b/src/main.rs index 646fc6f64..c77e12ea5 100644 --- a/src/main.rs +++ b/src/main.rs @@ -52,6 +52,7 @@ enum Cli { } fn run_with_warmup(topology: &AggregationTopology, tracing: bool, json: bool, repeat: usize) { + lean_multisig::setup_prover(); let warmup = biggest_leaf(topology).unwrap(); eprint!("warming up... 
"); let _ = run_aggregation_benchmark(&warmup, false, true, 1); @@ -67,9 +68,6 @@ fn run_with_warmup(topology: &AggregationTopology, tracing: bool, json: bool, re #[allow(clippy::too_many_lines)] fn main() { - #[cfg(not(feature = "standard-alloc"))] - zk_alloc::init(); - let cli = Cli::parse(); match cli { diff --git a/tests/test_multisignatures.rs b/tests/test_multisignatures.rs index c5ba89d4c..a7cd12ef9 100644 --- a/tests/test_multisignatures.rs +++ b/tests/test_multisignatures.rs @@ -23,7 +23,7 @@ fn test_xmss_signature() { let mut rng: StdRng = StdRng::seed_from_u64(0); let msg = rng.random(); - let (secret_key, pub_key) = xmss_key_gen(rng.random(), start_slot, end_slot).unwrap(); + let (secret_key, pub_key) = xmss_key_gen(rng.random(), start_slot, end_slot, false).unwrap(); let signature = xmss_sign(&mut rng, &secret_key, &msg, slot).unwrap(); xmss_verify(&pub_key, &msg, &signature, slot).unwrap(); } @@ -91,7 +91,7 @@ fn test_multi_message_aggregation() { let raws_b: Vec<_> = (0..2) .map(|_| { - let (sk, pk) = xmss_key_gen(rng_b.random(), slot_b, slot_b).unwrap(); + let (sk, pk) = xmss_key_gen(rng_b.random(), slot_b, slot_b, false).unwrap(); let sig = xmss_sign(&mut rng_b, &sk, &message_b, slot_b).unwrap(); (pk, sig) })