Skip to content

Commit e5cd331

Browse files
authored
Padding free AIR-sumcheck (#196)
* padding free sumcheck wip * w * w * clippy --------- Co-authored-by: Tom Wambsgans <TomWambsgans@users.noreply.github.com>
1 parent c3a5c2b commit e5cd331

11 files changed

Lines changed: 481 additions & 79 deletions

File tree

crates/backend/poly/src/evals.rs

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -346,6 +346,25 @@ where
346346
}
347347
}
348348

349+
/// evaluate the MLE of [0, 0, ..., 0, 1, 1, ..., 1] at `alphas` where there are t zeros
350+
pub fn evaluate_mle_of_zero_then_ones<EF: ExtensionField<PF<EF>>>(t: usize, alphas: &[EF]) -> EF {
351+
let n = alphas.len();
352+
if t == 0 {
353+
return EF::ONE;
354+
}
355+
if t >= 1usize << n {
356+
return EF::ZERO;
357+
}
358+
let half = 1usize << (n - 1);
359+
let alpha = alphas[0];
360+
let sub = &alphas[1..];
361+
if t < half {
362+
(EF::ONE - alpha) * evaluate_mle_of_zero_then_ones(t, sub) + alpha
363+
} else {
364+
alpha * evaluate_mle_of_zero_then_ones(t - half, sub)
365+
}
366+
}
367+
349368
#[cfg(test)]
350369
mod tests {
351370
use std::time::Instant;

crates/backend/poly/src/mle/mle_group.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,4 +75,10 @@ impl<'a, EF: ExtensionField<PF<EF>>> MleGroup<'a, EF> {
7575
Self::Ref(r) => r.clone_to_owned(),
7676
}
7777
}
78+
79+
pub fn unpack_if_needed(&mut self) {
80+
if self.is_packed() && must_unpack_multilinears::<EF>(self.n_vars()) {
81+
*self = self.by_ref().unpack().as_owned_or_clone().into();
82+
}
83+
}
7884
}

crates/backend/poly/src/mle/mle_group_ref.rs

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,24 @@ impl<'a, EF: ExtensionField<PF<EF>>> MleGroupRef<'a, EF> {
118118
}
119119
}
120120

121+
pub fn fold_at_bit(&self, alpha: EF, bit: usize) -> MleGroupOwned<EF> {
122+
match self {
123+
Self::Base(pols) => {
124+
MleGroupOwned::Extension(batch_fold_multilinears_at_bit(pols, alpha, bit, |a, b| b * a))
125+
}
126+
Self::Extension(pols) => {
127+
MleGroupOwned::Extension(batch_fold_multilinears_at_bit(pols, alpha, bit, |a, b| b * a))
128+
}
129+
Self::BasePacked(pols) => {
130+
let alpha_packed = EFPacking::<EF>::from(alpha);
131+
MleGroupOwned::ExtensionPacked(batch_fold_multilinears_at_bit(pols, alpha_packed, bit, |a, b| b * a))
132+
}
133+
Self::ExtensionPacked(pols) => {
134+
MleGroupOwned::ExtensionPacked(batch_fold_multilinears_at_bit(pols, alpha, bit, |a, b| a * b))
135+
}
136+
}
137+
}
138+
121139
pub fn merge(mles: &'a [&'a MleRef<'a, EF>]) -> Self {
122140
match &mles[0] {
123141
MleRef::Base(_) => Self::Base(mles.iter().map(|m| m.as_base().unwrap()).collect()),

crates/backend/poly/src/utils.rs

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,45 @@ pub fn batch_fold_multilinears<
8989
}
9090
}
9191

92+
pub fn fold_multilinear_at_bit<
93+
EF: PrimeCharacteristicRing + Copy + Send + Sync,
94+
IF: Copy + Sub<Output = IF> + Send + Sync,
95+
OF: Copy + Add<IF, Output = OF> + Send + Sync,
96+
Mul: Fn(IF, EF) -> OF + Sync + Send,
97+
>(
98+
m: &[IF],
99+
alpha: EF,
100+
bit: usize,
101+
mul_if_of: &Mul,
102+
) -> Vec<OF> {
103+
let new_size = m.len() / 2;
104+
assert!(m.len() >= 2 * (1 << bit), "bit out of range for slice length");
105+
let stride = 1usize << bit;
106+
let lo_mask = stride - 1;
107+
let mut res = unsafe { uninitialized_vec(new_size) };
108+
109+
let compute = |new_j: usize| {
110+
let i_hi = new_j >> bit;
111+
let i_lo = new_j & lo_mask;
112+
let i0 = (i_hi << (bit + 1)) | i_lo;
113+
let i1 = i0 | stride;
114+
mul_if_of(m[i1] - m[i0], alpha) + m[i0]
115+
};
116+
117+
if new_size < PARALLEL_THRESHOLD {
118+
for (new_j, res_v) in res.iter_mut().enumerate() {
119+
*res_v = compute(new_j);
120+
}
121+
} else {
122+
(0..new_size)
123+
.into_par_iter()
124+
.with_min_len(PARALLEL_THRESHOLD)
125+
.map(compute)
126+
.collect_into_vec(&mut res);
127+
}
128+
res
129+
}
130+
92131
pub fn fold_multilinear<
93132
EF: PrimeCharacteristicRing + Copy + Send + Sync,
94133
IF: Copy + Sub<Output = IF> + Send + Sync,
@@ -116,6 +155,31 @@ pub fn fold_multilinear<
116155
res
117156
}
118157

158+
pub fn batch_fold_multilinears_at_bit<
159+
EF: PrimeCharacteristicRing + Copy + Send + Sync,
160+
IF: Copy + Sub<Output = IF> + Send + Sync,
161+
OF: Copy + Add<IF, Output = OF> + Send + Sync,
162+
F: Fn(IF, EF) -> OF + Sync + Send,
163+
>(
164+
polys: &[&[IF]],
165+
alpha: EF,
166+
bit: usize,
167+
mul_if_of: F,
168+
) -> Vec<Vec<OF>> {
169+
let total_size: usize = polys.iter().map(|p| p.len()).sum();
170+
if total_size < PARALLEL_THRESHOLD {
171+
polys
172+
.iter()
173+
.map(|poly| fold_multilinear_at_bit(poly, alpha, bit, &mul_if_of))
174+
.collect()
175+
} else {
176+
polys
177+
.par_iter()
178+
.map(|poly| fold_multilinear_at_bit(poly, alpha, bit, &mul_if_of))
179+
.collect()
180+
}
181+
}
182+
119183
/// Returns a vector of uninitialized elements of type `A` with the specified length.
120184
/// # Safety
121185
/// Entries should be overwritten before use.

crates/backend/sumcheck/src/prove.rs

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -115,9 +115,7 @@ where
115115

116116
let mut challenges = Vec::new();
117117
for _ in 0..n_rounds {
118-
if multilinears.by_ref().is_packed() && must_unpack_multilinears::<EF>(n_vars) {
119-
multilinears = multilinears.by_ref().unpack().as_owned_or_clone().into();
120-
}
118+
multilinears.unpack_if_needed();
121119

122120
let ps = compute_and_send_polynomial(
123121
&mut multilinears,

crates/lean_prover/src/prove_execution.rs

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -175,9 +175,11 @@ pub fn prove_execution(
175175
up_down.extend(shifted_rows[idx].iter().map(Vec::as_slice));
176176
let packed = MleGroupRef::<EF>::Base(up_down).pack();
177177

178+
let non_padded = traces[table].non_padded_n_rows;
179+
178180
macro_rules! make_session {
179181
($t:expr) => {{
180-
let session = AirSumcheckSession::new(packed, eq_suffix, bus_final_value, *$t, extra_data);
182+
let session = AirSumcheckSession::new(packed, eq_suffix, bus_final_value, *$t, extra_data, non_padded);
181183
Box::new(session) as Box<dyn OuterSumcheckSession<EF> + '_>
182184
}};
183185
}
@@ -187,12 +189,14 @@ pub fn prove_execution(
187189
let sumcheck_air_point = info_span!("batched AIR sumcheck")
188190
.in_scope(|| prove_batched_air_sumcheck(&mut prover_state, &mut sessions, air_eta));
189191

190-
for (idx, (table, log_n_rows)) in tables_sorted.iter().enumerate() {
192+
for (idx, (table, _)) in tables_sorted.iter().enumerate() {
191193
let col_evals = sessions[idx].final_column_evals();
192194
prover_state.add_extension_scalars(&col_evals);
193195

196+
let natural_ordering_point =
197+
natural_ordering_point_for_session(&sumcheck_air_point.0, traces[table].log_n_rows);
194198
macro_rules! split {
195-
($t:expr) => {{ columns_evals_up_and_down($t, &col_evals, &sumcheck_air_point.0, *log_n_rows) }};
199+
($t:expr) => {{ columns_evals_up_and_down($t, &col_evals, &natural_ordering_point) }};
196200
}
197201
let claim = delegate_to_inner!(table => split);
198202
committed_statements.get_mut(table).unwrap().push(claim);

crates/lean_prover/src/verify_execution.rs

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -149,11 +149,17 @@ pub fn verify_execution(
149149
let constraint_eval = delegate_to_inner!(&vd.table => eval_constraint);
150150

151151
let bus_point = from_end(gkr_point, table_n_vars[&vd.table]);
152-
my_air_final_value +=
153-
back_loaded_table_contribution(bus_point, &sumcheck_air_point.0, constraint_eval, vd.eta_power);
152+
let natural_ordering_point = natural_ordering_point_for_session(&sumcheck_air_point.0, table_n_vars[&vd.table]);
153+
my_air_final_value += back_loaded_table_contribution(
154+
bus_point,
155+
&sumcheck_air_point.0,
156+
&natural_ordering_point,
157+
constraint_eval,
158+
vd.eta_power,
159+
);
154160

155161
macro_rules! split {
156-
($t:expr) => {{ columns_evals_up_and_down($t, &col_evals, &sumcheck_air_point.0, table_n_vars[&vd.table]) }};
162+
($t:expr) => {{ columns_evals_up_and_down($t, &col_evals, &natural_ordering_point) }};
157163
}
158164
let claim = delegate_to_inner!(&vd.table => split);
159165

@@ -218,3 +224,20 @@ pub fn verify_execution(
218224
verifier_state.into_raw_proof(),
219225
))
220226
}
227+
228+
/// Computes one table's term of the expected final value of the batched AIR sumcheck.
///
/// The result is `eta_power * k_t * eq_val * constraint_eval`, where
/// - `k_t` is the product of the first `sumcheck_air_point.len() - bus_point.len()` challenges,
/// - `eq_val` is `eq_poly_outside` evaluated between `bus_point` and `natural_ordering_point`.
///
/// # Panics
/// Panics if `natural_ordering_point` and `bus_point` have different lengths, or if
/// `bus_point` is longer than `sumcheck_air_point`.
fn back_loaded_table_contribution<EF: ExtensionField<PF<EF>>>(
    bus_point: &[EF],
    sumcheck_air_point: &[EF],
    natural_ordering_point: &[EF],
    constraint_eval: EF,
    eta_power: EF,
) -> EF {
    let table_n_vars = bus_point.len();
    // Number of leading challenges that are not part of this table's own point.
    let prefix_len = sumcheck_air_point.len() - table_n_vars;
    assert_eq!(natural_ordering_point.len(), table_n_vars);

    let bus = MultilinearPoint(bus_point.to_vec());
    let natural = MultilinearPoint(natural_ordering_point.to_vec());
    let eq_val = bus.eq_poly_outside(&natural);

    let leading_challenges = &sumcheck_air_point[..prefix_len];
    let k_t: EF = leading_challenges.iter().copied().product();

    eta_power * k_t * eq_val * constraint_eval
}

crates/rec_aggregation/recursion.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -432,8 +432,8 @@ def continue_recursion_ordered(
432432

433433
bus_point = pcs_points[table_index][0]
434434
suffix_start = n_max - log_n_rows
435-
challenge_suffix = all_challenges + suffix_start * DIM
436-
eq_val = eq_mle_extension(bus_point, challenge_suffix, log_n_rows)
435+
natural_ordering_point = natural_ordering_point_for_session(all_challenges, suffix_start, log_n_rows)
436+
eq_val = eq_mle_extension(bus_point, natural_ordering_point, log_n_rows)
437437

438438
k_t = product_first_n(all_challenges, suffix_start)
439439

@@ -443,7 +443,7 @@ def continue_recursion_ordered(
443443
)
444444
check_sum = add_extension_ret(check_sum, contribution)
445445

446-
pcs_points[table_index].push(challenge_suffix)
446+
pcs_points[table_index].push(natural_ordering_point)
447447
pcs_values[table_index].push(DynArray([]))
448448
pcs_values_down[table_index].push(DynArray([]))
449449
last_index = len(pcs_values[table_index]) - 1

crates/rec_aggregation/utils.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,21 @@ def product_first_n(values, n):
7171
return res
7272

7373

74+
# Builds and returns a new array of `log_n_rows * DIM` cells holding the challenges
# `all_challenges[suffix_start .. suffix_start + log_n_rows]`, reordered by
# `_natural_ordering_point_const`.
# NOTE(review): DIM is presumably the number of cells per extension-field element — confirm.
def natural_ordering_point_for_session(all_challenges, suffix_start, log_n_rows):
75+
res = Array(log_n_rows * DIM)
76+
# Dispatch on the runtime value of `log_n_rows` so the helper receives it as a constant `m`.
# NOTE(review): presumably only values in range(1, 33), i.e. 1..=32, are supported — confirm match_range semantics.
match_range(log_n_rows, range(1, 33), lambda m: _natural_ordering_point_const(all_challenges, suffix_start, res, m))
77+
return res
78+
79+
80+
# Writes into `dst` the `m` challenges starting at index `suffix_start` of `all_challenges`,
# with `m` known as a constant so that both loops can be unrolled.
# The first ceil(m / 2) challenges are written in reverse order; the remaining ones keep their position.
# NOTE(review): each copy_5 call presumably moves one element of DIM (= 5?) cells — confirm.
def _natural_ordering_point_const(all_challenges, suffix_start, dst, m: Const):
81+
half = div_ceil(m, 2)
82+
# First half: source element t goes to destination slot half - 1 - t (reversal).
for t in unroll(0, half):
83+
copy_5(all_challenges + (suffix_start + t) * DIM, dst + (half - 1 - t) * DIM)
84+
# Second half: source element t goes to destination slot t (unchanged position).
for t in unroll(half, m):
85+
copy_5(all_challenges + (suffix_start + t) * DIM, dst + t * DIM)
86+
return
87+
88+
7489
@inline
7590
def product_first_n_const(values, n):
7691
debug_assert(n != 0)

0 commit comments

Comments
 (0)