Skip to content

Commit c3a5c2b

Browse files
committed
skip the last 5 sumchecks of logup-GKR (send data in clear instead)
1 parent 6834816 commit c3a5c2b

5 files changed

Lines changed: 56 additions & 29 deletions

File tree

crates/rec_aggregation/recursion.py

Lines changed: 22 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@
44

55
N_TABLES = N_TABLES_PLACEHOLDER
66

7+
LOGUP_GKR_N_VARS_TO_SEND_COEFFS = LOGUP_GKR_N_VARS_TO_SEND_COEFFS_PLACEHOLDER
8+
LOGUP_GKR_N_COEFFS_SENT = 2**LOGUP_GKR_N_VARS_TO_SEND_COEFFS
9+
710
MIN_LOG_N_ROWS_PER_TABLE = MIN_LOG_N_ROWS_PER_TABLE_PLACEHOLDER
811
MAX_LOG_N_ROWS_PER_TABLE = MAX_LOG_N_ROWS_PER_TABLE_PLACEHOLDER
912
MIN_LOG_MEMORY_SIZE = MIN_LOG_MEMORY_SIZE_PLACEHOLDER
@@ -673,27 +676,33 @@ def fingerprint_bytecode(instr_evals, eval_on_pc, logup_alphas_eq_poly):
673676

674677

675678
def verify_gkr_quotient(fs: Mut, n_vars):
676-
fs, nums = fs_receive_ef_inlined(fs, 2)
677-
fs, denoms = fs_receive_ef_inlined(fs, 2)
678-
679-
q1 = div_extension_ret(nums, denoms)
680-
q2 = div_extension_ret(nums + DIM, denoms + DIM)
681-
quotient = add_extension_ret(q1, q2)
679+
fs, nums = fs_receive_ef_inlined(fs, LOGUP_GKR_N_COEFFS_SENT)
680+
fs, denoms = fs_receive_ef_inlined(fs, LOGUP_GKR_N_COEFFS_SENT)
681+
682+
initial_quotients = Array(LOGUP_GKR_N_COEFFS_SENT * DIM)
683+
for k in unroll(0, LOGUP_GKR_N_COEFFS_SENT):
684+
div_extension(nums + k * DIM, denoms + k * DIM, initial_quotients + k * DIM)
685+
debug_assert(NUM_REPEATED_ONES <= LOGUP_GKR_N_COEFFS_SENT)
686+
debug_assert(LOGUP_GKR_N_COEFFS_SENT % NUM_REPEATED_ONES == 0)
687+
quotient: Mut = ZERO_VEC_PTR
688+
for k in unroll(0, LOGUP_GKR_N_COEFFS_SENT / NUM_REPEATED_ONES):
689+
quotient = add_extension_ret(quotient, sum_continuous_ef(initial_quotients + k * NUM_REPEATED_ONES * DIM, NUM_REPEATED_ONES))
682690

683691
points = Array(n_vars)
684692
claims_num = Array(n_vars)
685693
claims_den = Array(n_vars)
686694

687-
fs, points[0] = fs_sample_ef(fs)
695+
fs, initial_point = fs_sample_many_ef(fs, LOGUP_GKR_N_VARS_TO_SEND_COEFFS)
696+
points[LOGUP_GKR_N_VARS_TO_SEND_COEFFS - 1] = initial_point
688697

689-
point_poly_eq = poly_eq_extension(points[0], 1)
698+
point_poly_eq = poly_eq_extension(initial_point, LOGUP_GKR_N_VARS_TO_SEND_COEFFS)
690699

691-
first_claim_num = dot_product_ee_ret(nums, point_poly_eq, 2)
692-
first_claim_den = dot_product_ee_ret(denoms, point_poly_eq, 2)
693-
claims_num[0] = first_claim_num
694-
claims_den[0] = first_claim_den
700+
first_claim_num = dot_product_ee_ret(nums, point_poly_eq, LOGUP_GKR_N_COEFFS_SENT)
701+
first_claim_den = dot_product_ee_ret(denoms, point_poly_eq, LOGUP_GKR_N_COEFFS_SENT)
702+
claims_num[LOGUP_GKR_N_VARS_TO_SEND_COEFFS - 1] = first_claim_num
703+
claims_den[LOGUP_GKR_N_VARS_TO_SEND_COEFFS - 1] = first_claim_den
695704

696-
for i in range(1, n_vars):
705+
for i in range(LOGUP_GKR_N_VARS_TO_SEND_COEFFS, n_vars):
697706
fs, points[i], claims_num[i], claims_den[i] = verify_gkr_quotient_step(fs, i, points[i - 1], claims_num[i - 1], claims_den[i - 1])
698707

699708
return (

crates/rec_aggregation/src/compilation.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ use lean_vm::*;
88
use std::collections::{BTreeMap, HashMap};
99
use std::path::Path;
1010
use std::sync::OnceLock;
11-
use sub_protocols::{min_stacked_n_vars, total_whir_statements};
11+
use sub_protocols::{N_VARS_TO_SEND_GKR_COEFFS, min_stacked_n_vars, total_whir_statements};
1212
use tracing::instrument;
1313
use utils::Counter;
1414
use xmss::{LOG_LIFETIME, MESSAGE_LEN_FE, RANDOMNESS_LEN_FE, TARGET_SUM, V, V_GRINDING, W};
@@ -183,6 +183,10 @@ fn build_replacements(
183183
"MAX_NUM_VARIABLES_TO_SEND_COEFFS_PLACEHOLDER".to_string(),
184184
MAX_NUM_VARIABLES_TO_SEND_COEFFS.to_string(),
185185
);
186+
replacements.insert(
187+
"LOGUP_GKR_N_VARS_TO_SEND_COEFFS_PLACEHOLDER".to_string(),
188+
N_VARS_TO_SEND_GKR_COEFFS.to_string(),
189+
);
186190
replacements.insert(
187191
"WHIR_INITIAL_FOLDING_FACTOR_PLACEHOLDER".to_string(),
188192
WHIR_INITIAL_FOLDING_FACTOR.to_string(),

crates/rec_aggregation/utils.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -304,9 +304,14 @@ def mul_base_extension_ret(a, b):
304304
@inline
305305
def div_extension_ret(n, d):
306306
quotient = Array(DIM)
307-
dot_product_ee(d, quotient, n)
307+
div_extension(n, d, quotient)
308308
return quotient
309309

310+
@inline
311+
def div_extension(n, d, res):
312+
dot_product_ee(d, res, n)
313+
return
314+
310315

311316
@inline
312317
def sub_extension(a, b, c):

crates/sub_protocols/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,4 @@ mod quotient_gkr;
1111
pub use quotient_gkr::*;
1212

1313
pub(crate) const MIN_VARS_FOR_PACKING: usize = 8;
14+
pub const N_VARS_TO_SEND_GKR_COEFFS: usize = 5;

crates/sub_protocols/src/quotient_gkr.rs

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ use std::ops::Mul;
22

33
use backend::*;
44

5-
use crate::MIN_VARS_FOR_PACKING;
5+
use crate::{MIN_VARS_FOR_PACKING, N_VARS_TO_SEND_GKR_COEFFS};
66

77
/*
88
GKR to compute sum of fractions.
@@ -13,6 +13,8 @@ pub fn prove_gkr_quotient<EF: ExtensionField<PF<EF>>>(
1313
numerators: &MleRef<'_, EF>,
1414
denominators: &MleRef<'_, EF>,
1515
) -> (EF, MultilinearPoint<EF>, EF, EF) {
16+
assert!(numerators.n_vars() == denominators.n_vars());
17+
assert!(numerators.n_vars() > N_VARS_TO_SEND_GKR_COEFFS);
1618
assert!(numerators.is_packed() == denominators.is_packed());
1719
let mut layers: Vec<(Mle<'_, EF>, Mle<'_, EF>)> =
1820
vec![(numerators.soft_clone().into(), denominators.soft_clone().into())];
@@ -26,23 +28,23 @@ pub fn prove_gkr_quotient<EF: ExtensionField<PF<EF>>>(
2628
prev_denominators.unpack().as_owned_or_clone().into(),
2729
)
2830
}
29-
if prev_numerators.n_vars() == 1 {
31+
if prev_numerators.n_vars() <= N_VARS_TO_SEND_GKR_COEFFS {
3032
break;
3133
}
3234
let (new_numerators, new_denominators) = sum_quotients(prev_numerators.by_ref(), prev_denominators.by_ref());
3335
layers.push((new_numerators.into(), new_denominators.into()));
3436
}
3537

36-
let (last_numerators, last_denominators) = layers.pop().unwrap();
37-
let last_numerators = last_numerators.as_owned().unwrap();
38-
let last_numerators = last_numerators.as_extension().unwrap();
39-
let last_denominators = last_denominators.as_owned().unwrap();
40-
let last_denominators = last_denominators.as_extension().unwrap();
38+
let (last_numerators_mle, last_denominators_mle) = layers.pop().unwrap();
39+
let last_numerators_owned = last_numerators_mle.unpack().as_owned_or_clone();
40+
let last_denominators_owned = last_denominators_mle.unpack().as_owned_or_clone();
41+
let last_numerators = last_numerators_owned.as_extension().unwrap();
42+
let last_denominators = last_denominators_owned.as_extension().unwrap();
4143
prover_state.add_extension_scalars(last_numerators);
4244
prover_state.add_extension_scalars(last_denominators);
43-
let quotient = last_numerators[0] / last_denominators[0] + last_numerators[1] / last_denominators[1];
45+
let quotient = compute_quotient(last_numerators, last_denominators);
4446

45-
let mut point = MultilinearPoint(vec![prover_state.sample()]);
47+
let mut point = MultilinearPoint(prover_state.sample_vec(N_VARS_TO_SEND_GKR_COEFFS));
4648
let mut claims = vec![last_numerators.evaluate(&point), last_denominators.evaluate(&point)];
4749

4850
for (nums, denoms) in layers.iter().rev() {
@@ -208,13 +210,15 @@ pub fn verify_gkr_quotient<EF: ExtensionField<PF<EF>>>(
208210
verifier_state: &mut impl FSVerifier<EF>,
209211
n_vars: usize,
210212
) -> Result<(EF, MultilinearPoint<EF>, EF, EF), ProofError> {
211-
let last_nums = verifier_state.next_extension_scalars_vec(2)?;
212-
let last_dens = verifier_state.next_extension_scalars_vec(2)?;
213-
let quotient = last_nums[0] / last_dens[0] + last_nums[1] / last_dens[1];
214-
let mut point = MultilinearPoint(vec![verifier_state.sample()]);
213+
assert!(n_vars > N_VARS_TO_SEND_GKR_COEFFS);
214+
let send_len = 1 << N_VARS_TO_SEND_GKR_COEFFS;
215+
let last_nums = verifier_state.next_extension_scalars_vec(send_len)?;
216+
let last_dens = verifier_state.next_extension_scalars_vec(send_len)?;
217+
let quotient: EF = compute_quotient(&last_nums, &last_dens);
218+
let mut point = MultilinearPoint(verifier_state.sample_vec(N_VARS_TO_SEND_GKR_COEFFS));
215219
let mut claims_num = last_nums.evaluate(&point);
216220
let mut claims_den = last_dens.evaluate(&point);
217-
for i in 1..n_vars {
221+
for i in N_VARS_TO_SEND_GKR_COEFFS..n_vars {
218222
(point, claims_num, claims_den) = verify_gkr_quotient_step(verifier_state, i, &point, claims_num, claims_den)?;
219223
}
220224
Ok((quotient, point, claims_num, claims_den))
@@ -291,6 +295,10 @@ where
291295
(new_numerators, new_denominators)
292296
}
293297

298+
fn compute_quotient<EF: ExtensionField<PF<EF>>>(numerators: &[EF], denominators: &[EF]) -> EF {
299+
numerators.iter().zip(denominators).map(|(&n, &d)| n / d).sum()
300+
}
301+
294302
#[cfg(test)]
295303
mod tests {
296304
use super::*;

0 commit comments

Comments
 (0)