Skip to content

Commit baf649c

Browse files
committed
Merge branch 'main' into devnet4
2 parents 5bbaae5 + 9fbf13f commit baf649c

9 files changed

Lines changed: 916 additions & 250 deletions

File tree

crates/air/src/prove.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ where
5454
columns_up_down_group_packed,
5555
air,
5656
&extra_data,
57-
Some((zerocheck_challenges, None)),
57+
Some(zerocheck_challenges),
5858
prover_state,
5959
virtual_column_statement
6060
.as_ref()

crates/backend/sumcheck/src/lib.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
#![cfg_attr(not(test), warn(unused_crate_dependencies))]
22

3+
mod split_eq;
4+
pub use split_eq::*;
5+
36
mod prove;
47
pub use prove::*;
58

crates/backend/sumcheck/src/prove.rs

Lines changed: 21 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ pub fn sumcheck_prove<'a, EF, SC, M: Into<MleGroup<'a, EF>>>(
1111
multilinears_f: M,
1212
computation: &SC,
1313
extra_data: &SC::ExtraData,
14-
eq_factor: Option<(Vec<EF>, Option<MleOwned<EF>>)>, // (a, b, c ...), eq_poly(b, c, ...)
14+
eq_factor: Option<Vec<EF>>,
1515
prover_state: &mut impl FSProver<EF>,
1616
sum: EF,
1717
store_intermediate_foldings: bool,
@@ -39,7 +39,7 @@ pub fn sumcheck_fold_and_prove<'a, EF, SC, M: Into<MleGroup<'a, EF>>>(
3939
prev_folding_factor: Option<EF>,
4040
computation: &SC,
4141
extra_data: &SC::ExtraData,
42-
eq_factor: Option<(Vec<EF>, Option<MleOwned<EF>>)>, // (a, b, c ...), eq_poly(b, c, ...)
42+
eq_factor: Option<Vec<EF>>,
4343
prover_state: &mut impl FSProver<EF>,
4444
sum: EF,
4545
store_intermediate_foldings: bool,
@@ -88,7 +88,7 @@ pub fn sumcheck_prove_many_rounds<'a, EF, SC, M: Into<MleGroup<'a, EF>>>(
8888
mut prev_folding_factor: Option<EF>,
8989
computation: &SC,
9090
extra_data: &SC::ExtraData,
91-
mut eq_factor: Option<(Vec<EF>, Option<MleOwned<EF>>)>, // (a, b, c ...), eq_poly(b, c, ...)
91+
mut eq_factor: Option<Vec<EF>>,
9292
prover_state: &mut impl FSProver<EF>,
9393
mut sum: EF,
9494
mut missing_mul_factors: Option<EF>,
@@ -102,49 +102,31 @@ where
102102
SC::ExtraData: AlphaPowers<EF>,
103103
{
104104
let mut multilinears: MleGroup<'a, EF> = multilinears_f.into();
105-
106-
let mut eq_factor: Option<(Vec<EF>, MleOwned<EF>)> = eq_factor.take().map(|(eq_point, eq_mle)| {
107-
let eq_mle = eq_mle.unwrap_or_else(|| {
108-
let eval_eq_ext = eval_eq(&eq_point[1..]);
109-
if multilinears.by_ref().is_packed() {
110-
MleOwned::ExtensionPacked(pack_extension(&eval_eq_ext))
111-
} else {
112-
MleOwned::Extension(eval_eq_ext)
113-
}
114-
});
115-
(eq_point, eq_mle)
116-
});
117-
118105
let mut n_vars = multilinears.by_ref().n_vars();
119106
if prev_folding_factor.is_some() {
120107
n_vars -= 1;
121108
}
122-
if let Some((eq_point, eq_mle)) = &eq_factor {
109+
110+
let mut eq_factor_and_split: Option<(Vec<EF>, SplitEq<EF>)> = eq_factor.take().map(|eq_point| {
123111
assert_eq!(eq_point.len(), n_vars);
124-
assert_eq!(eq_mle.by_ref().n_vars(), eq_point.len() - 1);
125-
if eq_mle.by_ref().is_packed() && !multilinears.is_packed() {
126-
assert!(eq_point.len() < packing_log_width::<EF>());
127-
multilinears = multilinears.by_ref().unpack().as_owned_or_clone().into();
128-
}
129-
}
112+
let split_eq = SplitEq::new(&eq_point[1..]);
113+
(eq_point, split_eq)
114+
});
130115

131116
let mut challenges = Vec::new();
132117
for _ in 0..n_rounds {
133118
// If Packing is enabled, and there are too little variables, we unpack everything:
134119
if multilinears.by_ref().is_packed() && n_vars <= 1 + packing_log_width::<EF>() {
135120
// unpack
136121
multilinears = multilinears.by_ref().unpack().as_owned_or_clone().into();
137-
138-
if let Some((_, eq_mle)) = &mut eq_factor {
139-
*eq_mle = eq_mle.by_ref().unpack().as_owned_or_clone();
140-
}
122+
// SplitEq handles unpacking transparently via get_unpacked
141123
}
142124

143125
let ps = compute_and_send_polynomial(
144126
&mut multilinears,
145127
prev_folding_factor,
146128
computation,
147-
&eq_factor,
129+
&eq_factor_and_split,
148130
extra_data,
149131
prover_state,
150132
sum,
@@ -157,7 +139,7 @@ where
157139
prev_folding_factor = on_challenge_received(
158140
&mut multilinears,
159141
&mut n_vars,
160-
&mut eq_factor,
142+
&mut eq_factor_and_split,
161143
&mut sum,
162144
&mut missing_mul_factors,
163145
challenge,
@@ -178,7 +160,7 @@ fn compute_and_send_polynomial<'a, EF, SC>(
178160
multilinears: &mut MleGroup<'a, EF>,
179161
prev_folding_factor: Option<EF>,
180162
computation: &SC,
181-
eq_factor: &Option<(Vec<EF>, MleOwned<EF>)>, // (a, b, c ...), eq_poly(b, c, ...)
163+
eq_factor_and_split: &Option<(Vec<EF>, SplitEq<EF>)>,
182164
extra_data: &SC::ExtraData,
183165
prover_state: &mut impl FSProver<EF>,
184166
sum: EF,
@@ -196,8 +178,10 @@ where
196178
let computation_degree = computation.degree();
197179

198180
let sc_params = SumcheckComputeParams {
199-
eq_mle: eq_factor.as_ref().map(|(_, eq_mle)| eq_mle),
200-
first_eq_factor: eq_factor.as_ref().map(|(first_eq_factor, _)| first_eq_factor[0]),
181+
split_eq: eq_factor_and_split.as_ref().map(|(_, split_eq)| split_eq),
182+
first_eq_factor: eq_factor_and_split
183+
.as_ref()
184+
.map(|(first_eq_factor, _)| first_eq_factor[0]),
201185
computation,
202186
extra_data,
203187
missing_mul_factor,
@@ -217,7 +201,7 @@ where
217201
None => sumcheck_compute(&multilinears.by_ref(), sc_params, computation_degree),
218202
});
219203

220-
let p_at_1 = if let Some((eq_factor, _)) = eq_factor {
204+
let p_at_1 = if let Some((eq_factor, _)) = eq_factor_and_split {
221205
(sum - (EF::ONE - eq_factor[0]) * p_evals[0]) / eq_factor[0]
222206
} else {
223207
sum - p_evals[0]
@@ -232,7 +216,7 @@ where
232216
.collect::<Vec<_>>(),
233217
)
234218
.unwrap();
235-
let eq_alpha = eq_factor.as_ref().map(|(p, _)| p[0]);
219+
let eq_alpha = eq_factor_and_split.as_ref().map(|(p, _)| p[0]);
236220
prover_state.add_sumcheck_polynomial(&poly.coeffs, eq_alpha);
237221
poly
238222
}
@@ -241,7 +225,7 @@ where
241225
fn on_challenge_received<'a, EF: ExtensionField<PF<EF>>>(
242226
multilinears: &mut MleGroup<'a, EF>,
243227
n_vars: &mut usize,
244-
eq_factor: &mut Option<(Vec<EF>, MleOwned<EF>)>, // (a, b, c ...), eq_poly(b, c, ...)
228+
eq_factor: &mut Option<(Vec<EF>, SplitEq<EF>)>,
245229
sum: &mut EF,
246230
missing_mul_factor: &mut Option<EF>,
247231
challenge: EF,
@@ -253,7 +237,7 @@ fn on_challenge_received<'a, EF: ExtensionField<PF<EF>>>(
253237
*sum = p.evaluate(challenge);
254238
*n_vars -= 1;
255239

256-
if let Some((eq_factor, eq_mle)) = eq_factor {
240+
if let Some((eq_factor, split_eq)) = eq_factor {
257241
// Multiply sum by eq(α_i, r_i) since the polynomial doesn't include the eq linear factor
258242
let eq_eval = (EF::ONE - eq_factor[0]) * (EF::ONE - challenge) + eq_factor[0] * challenge;
259243
*sum *= eq_eval;
@@ -262,7 +246,7 @@ fn on_challenge_received<'a, EF: ExtensionField<PF<EF>>>(
262246
eq_eval * missing_mul_factor.unwrap_or(EF::ONE) / (EF::ONE - eq_factor.get(1).copied().unwrap_or_default()),
263247
);
264248
eq_factor.remove(0);
265-
eq_mle.truncate(eq_mle.by_ref().packed_len() / 2);
249+
split_eq.truncate_half();
266250
}
267251

268252
if store_intermediate_foldings {

0 commit comments

Comments
 (0)