From db482f809109c6857e6e3c481625a41d82eb5948 Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Mon, 1 Jun 2026 03:54:40 +0400 Subject: [PATCH 01/31] naming --- crates/lean_prover/python-verifier/verifier.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/crates/lean_prover/python-verifier/verifier.py b/crates/lean_prover/python-verifier/verifier.py index b61aeb5d..4e5c45e5 100644 --- a/crates/lean_prover/python-verifier/verifier.py +++ b/crates/lean_prover/python-verifier/verifier.py @@ -86,7 +86,7 @@ def n_columns(self) -> int: return len(self.columns) @property - def n_buses(self) -> int: + def n_bus_interractions(self) -> int: return sum(b[3] if b[0] == BusInteraction.MEMORY else 1 for b in self.buses) @property @@ -608,7 +608,7 @@ def verify_generic_logup( tallest_h = tables_sorted[0][1] total_active_len = ( - (1 << log_memory) + max(1 << log_bytecode, 1 << tallest_h) + sum(t.n_buses << h for t, h in tables_sorted) + (1 << log_memory) + max(1 << log_bytecode, 1 << tallest_h) + sum(t.n_bus_interractions << h for t, h in tables_sorted) ) total_gkr_n_vars = log2_ceil(total_active_len) @@ -654,7 +654,7 @@ def pref_at(offset: int, log_height: int) -> EF: table_offsets: dict[str, int] = {} for table, log_n_rows in tables_sorted: table_offsets[table.name] = offset - offset += table.n_buses << log_n_rows + offset += table.n_bus_interractions << log_n_rows final_offset = offset bus_num_vals: dict[str, EF] = {} From b060daed030915a8ccb5928f7043b2a156d720d7 Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Mon, 1 Jun 2026 04:50:18 +0400 Subject: [PATCH 02/31] simplify bus --- .../lean_prover/python-verifier/verifier.py | 100 ++++++++---------- 1 file changed, 45 insertions(+), 55 deletions(-) diff --git a/crates/lean_prover/python-verifier/verifier.py b/crates/lean_prover/python-verifier/verifier.py index 4e5c45e5..c06d5b80 100644 --- a/crates/lean_prover/python-verifier/verifier.py +++ b/crates/lean_prover/python-verifier/verifier.py @@ -64,17 +64,19 @@ class BusDirection(IntEnum): PULL = -1 -class BusInteraction(IntEnum): - PRECOMPILE = 0 - BYTECODE = 1 - MEMORY = 2 +@dataclass(frozen=True) +class BusInteraction: + direction: BusDirection + domain_sep: int = 0 + cols: tuple[str, ...] = () # committed columns forming σ (address column first, for memory) + n_terms: int = 1 # number of logup terms (memory groups: consecutive cells sharing cols[0]) @dataclass(frozen=True) class Table: name: str columns: tuple[str, ...] - buses: tuple + buses: tuple[BusInteraction, ...] air_degree: int n_constraints: int n_shift: int # shift (next-row) columns are always the first ones @@ -87,11 +89,11 @@ def n_columns(self) -> int: @property def n_bus_interractions(self) -> int: - return sum(b[3] if b[0] == BusInteraction.MEMORY else 1 for b in self.buses) + return sum(b.n_terms for b in self.buses) @property def precompile_bus_interraction_sign(self) -> EF: - return EF(self.buses[0][1]) # precompile interraction is the first, by convention + return EF(self.buses[0].direction) # precompile interraction is the first, by convention def col(self, name: str) -> int: return self.columns.index(name) @@ -608,7 +610,9 @@ def verify_generic_logup( tallest_h = tables_sorted[0][1] total_active_len = ( - (1 << log_memory) + max(1 << log_bytecode, 1 << tallest_h) + sum(t.n_bus_interractions << h for t, h in tables_sorted) + (1 << log_memory) + + max(1 << log_bytecode, 1 << tallest_h) + + sum(t.n_bus_interractions << h for t, h in tables_sorted) ) total_gkr_n_vars = log2_ceil(total_active_len) @@ -665,48 +669,34 @@ def pref_at(offset: int, log_height: int) -> EF: name = table.name log_n_rows = heights[name] row_stride = 1 << log_n_rows - offset_within_table = table_offsets[name] - table_values: dict[int, EF] = {} + offset = table_offsets[name] + vals: dict[int, EF] = {} - def read_fresh(cols: list[int]) -> None: - """Read one extension scalar per column not yet in `table_values`, in order.""" - missing = [c for c in cols if c not in table_values] + def read(cols: Sequence[int]) -> list[EF]: + """Evals of `cols`, batch-reading any not-yet-seen column from the transcript, in order.""" + missing = [c for c in cols if c not in vals] for c, e in zip(missing, fiat_shamir.next_extension_scalars_vec(len(missing))): - table_values[c] = e + vals[c] = e + return [vals[c] for c in cols] for bus in table.buses: - pref = pref_at(offset_within_table, log_n_rows) - kind = bus[0] - if kind == BusInteraction.PRECOMPILE: + if not bus.cols: + pref = pref_at(offset, log_n_rows) bus_num_vals[name] = fiat_shamir.next_extension_scalar() bus_den_vals[name] = fiat_shamir.next_extension_scalar() num += pref * bus_num_vals[name] den += pref * bus_den_vals[name] - n_sub = 1 - elif kind == BusInteraction.BYTECODE: - cols = list(range(N_RUNTIME_COLUMNS, N_RUNTIME_COLUMNS + N_INSTRUCTION_COLUMNS)) + [table.col("pc")] - read_fresh(cols) - evals = [table_values[c] for c in cols] + offset += row_stride + continue + sep, base = Fp(bus.domain_sep), [table.col(c) for c in bus.cols] # memory / bytecode + for i in range(bus.n_terms): # term i: σ = (m[base[0]] + i, m[base[1:] + i]) + pref = pref_at(offset, log_n_rows) + d = read([base[0], *(c + i for c in base[1:])]) num += pref - den += pref * (gamma - finger_print(ds_byte, evals, beta_eq)) - n_sub = 1 - elif kind == BusInteraction.MEMORY: - _, idx_ref, vals_ref, n_sub = bus - idx_col, vals_start = table.col(idx_ref), table.col(vals_ref) - # One sub-bus per cell in the group; the prover sends only the not-yet-seen - # columns per row (idx_col is shared across all n_sub rows). - for i in range(n_sub): - val_col = vals_start + i - read_fresh([idx_col, val_col]) - pref = pref_at(offset_within_table + i * row_stride, log_n_rows) - fp = finger_print(ds_mem, [table_values[idx_col] + i, table_values[val_col]], beta_eq) - num += pref - den += pref * (gamma - fp) - else: - raise ProofError(f"unknown bus kind: {kind}") - offset_within_table += n_sub * row_stride - - columns_values[name] = table_values + den += pref * (gamma - finger_print(sep, [d[0] + i, *d[1:]], beta_eq)) + offset += row_stride + + columns_values[name] = vals den += mle_of_zeros_then_ones(final_offset, point_gkr) if num != claim_num: @@ -981,11 +971,11 @@ def eval_air_poseidon16(folder: ConstraintFolder, logup_beta_eq: list[EF]) -> No name="execution", columns=EXECUTION_COLUMNS, buses=( - (BusInteraction.PRECOMPILE, BusDirection.PUSH), - (BusInteraction.BYTECODE,), - (BusInteraction.MEMORY, "addr_a", "value_a", 1), - (BusInteraction.MEMORY, "addr_b", "value_b", 1), - (BusInteraction.MEMORY, "addr_c", "value_c", 1), + BusInteraction(BusDirection.PUSH), + BusInteraction(BusDirection.PULL, LOGUP_BYTECODE_DOMAINSEP, (*EXECUTION_COLUMNS[N_RUNTIME_COLUMNS:], "pc")), + BusInteraction(BusDirection.PULL, LOGUP_MEMORY_DOMAINSEP, ("addr_a", "value_a")), + BusInteraction(BusDirection.PULL, LOGUP_MEMORY_DOMAINSEP, ("addr_b", "value_b")), + BusInteraction(BusDirection.PULL, LOGUP_MEMORY_DOMAINSEP, ("addr_c", "value_c")), ), air_degree=5, n_constraints=14, @@ -997,10 +987,10 @@ def eval_air_poseidon16(folder: ConstraintFolder, logup_beta_eq: list[EF]) -> No name="extension", columns=EXTENSION_COLUMNS, buses=( - (BusInteraction.PRECOMPILE, BusDirection.PULL), - (BusInteraction.MEMORY, "idx_a", "v_a_0", 5), - (BusInteraction.MEMORY, "idx_b", "v_b_0", 5), - (BusInteraction.MEMORY, "idx_r", "res_0", 5), + BusInteraction(BusDirection.PULL), + BusInteraction(BusDirection.PULL, LOGUP_MEMORY_DOMAINSEP, ("idx_a", "v_a_0"), 5), + BusInteraction(BusDirection.PULL, LOGUP_MEMORY_DOMAINSEP, ("idx_b", "v_b_0"), 5), + BusInteraction(BusDirection.PULL, LOGUP_MEMORY_DOMAINSEP, ("idx_r", "res_0"), 5), ), air_degree=6, n_constraints=35, @@ -1012,11 +1002,11 @@ def eval_air_poseidon16(folder: ConstraintFolder, logup_beta_eq: list[EF]) -> No name="poseidon", columns=POSEIDON_COLUMNS, buses=( - (BusInteraction.PRECOMPILE, BusDirection.PULL), - (BusInteraction.MEMORY, "addr_left_lo", "input_0", 4), - (BusInteraction.MEMORY, "addr_left_hi", "input_4", 4), - (BusInteraction.MEMORY, "nu_b", "input_8", 8), - (BusInteraction.MEMORY, "nu_c", "out_lo_0", 16), + BusInteraction(BusDirection.PULL), + BusInteraction(BusDirection.PULL, LOGUP_MEMORY_DOMAINSEP, ("addr_left_lo", "input_0"), 4), + BusInteraction(BusDirection.PULL, LOGUP_MEMORY_DOMAINSEP, ("addr_left_hi", "input_4"), 4), + BusInteraction(BusDirection.PULL, LOGUP_MEMORY_DOMAINSEP, ("nu_b", "input_8"), 8), + BusInteraction(BusDirection.PULL, LOGUP_MEMORY_DOMAINSEP, ("nu_c", "out_lo_0"), 16), ), air_degree=10, n_constraints=101, From b16e852cacc350a4b148d9edf36457b9c0958249 Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Mon, 1 Jun 2026 04:57:26 +0400 Subject: [PATCH 03/31] wip --- crates/lean_prover/python-verifier/verifier.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/crates/lean_prover/python-verifier/verifier.py b/crates/lean_prover/python-verifier/verifier.py index c06d5b80..37e15dbb 100644 --- a/crates/lean_prover/python-verifier/verifier.py +++ b/crates/lean_prover/python-verifier/verifier.py @@ -754,10 +754,7 @@ def eval_precompile_bus_virtual_columns( def eval_air_execution(folder: ConstraintFolder, logup_beta_eq: list[EF]) -> None: c, n = folder.flat, folder.next (pc, fp, addr_a, addr_b, addr_c, value_a, value_b, value_c, operand_a, operand_b, operand_c, - flag_a, flag_b, flag_c, flag_c_fp, flag_ab_fp, flag_mul, flag_jump, aux_1, aux_2) = (c[k] for k in ( - "pc", "fp", "addr_a", "addr_b", "addr_c", "value_a", "value_b", "value_c", - "operand_a", "operand_b", "operand_c", "flag_a", "flag_b", "flag_c", "flag_c_fp", - "flag_ab_fp", "flag_mul", "flag_jump", "aux_1", "aux_2")) # fmt: skip + flag_a, flag_b, flag_c, flag_c_fp, flag_ab_fp, flag_mul, flag_jump, aux_1, aux_2) = (c[k] for k in EXECUTION_COLUMNS) # fmt: skip pc_shift, fp_shift = n["pc"], n["fp"] # nu_x = flag·operand + (1 − flag − flag_ab_fp)·value + flag_ab_fp·(fp + operand) From 71a82d1f52721918898b301ceda73f7e883d34dd Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Mon, 1 Jun 2026 04:58:30 +0400 Subject: [PATCH 04/31] wip --- crates/lean_prover/python-verifier/verifier.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/crates/lean_prover/python-verifier/verifier.py b/crates/lean_prover/python-verifier/verifier.py index 37e15dbb..5b56e547 100644 --- a/crates/lean_prover/python-verifier/verifier.py +++ b/crates/lean_prover/python-verifier/verifier.py @@ -399,11 +399,7 @@ def verify_stir_challenges( op = fiat_shamir.next_merkle_opening() merkle_verify_path(commitment.root, log_height, idx, op.leaf_data, op.path) # Round 0 leaves are raw base-field elements; later rounds pack DIM Fp values per EF element. - leaf = op.leaf_data - if round_index == 0: - packed = leaf - else: - packed = pack_ef(leaf) + packed = op.leaf_data if round_index == 0 else pack_ef(op.leaf_data) fold = eval_multilinear_evals(packed, folding_randomness) ef_pt = EF(pow(int(gen.value), idx, P)) pt = expand_from_univariate(ef_pt, num_variables) From f0d18cc588f07ab8b2ebc469e95e7c329d2111e3 Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Mon, 1 Jun 2026 04:59:43 +0400 Subject: [PATCH 05/31] wip --- crates/lean_prover/python-verifier/verifier.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/crates/lean_prover/python-verifier/verifier.py b/crates/lean_prover/python-verifier/verifier.py index 5b56e547..8bf6d99a 100644 --- a/crates/lean_prover/python-verifier/verifier.py +++ b/crates/lean_prover/python-verifier/verifier.py @@ -88,12 +88,12 @@ def n_columns(self) -> int: return len(self.columns) @property - def n_bus_interractions(self) -> int: + def n_bus_interactions(self) -> int: return sum(b.n_terms for b in self.buses) @property - def precompile_bus_interraction_sign(self) -> EF: - return EF(self.buses[0].direction) # precompile interraction is the first, by convention + def precompile_bus_interaction_sign(self) -> EF: + return EF(self.buses[0].direction) # precompile interaction is the first, by convention def col(self, name: str) -> int: return self.columns.index(name) @@ -608,7 +608,7 @@ def verify_generic_logup( total_active_len = ( (1 << log_memory) + max(1 << log_bytecode, 1 << tallest_h) - + sum(t.n_bus_interractions << h for t, h in tables_sorted) + + sum(t.n_bus_interactions << h for t, h in tables_sorted) ) total_gkr_n_vars = log2_ceil(total_active_len) @@ -654,7 +654,7 @@ def pref_at(offset: int, log_height: int) -> EF: table_offsets: dict[str, int] = {} for table, log_n_rows in tables_sorted: table_offsets[table.name] = offset - offset += table.n_bus_interractions << log_n_rows + offset += table.n_bus_interactions << log_n_rows final_offset = offset bus_num_vals: dict[str, EF] = {} @@ -1081,7 +1081,7 @@ def verify_execution( initial_sum, offset = ZERO, 0 for table in TABLES: - initial_sum += alpha_powers[offset] * (logup["bus_num"][table.name] * table.precompile_bus_interraction_sign) + initial_sum += alpha_powers[offset] * (logup["bus_num"][table.name] * table.precompile_bus_interaction_sign) initial_sum += alpha_powers[offset + 1] * (logup_gamma - logup["bus_den"][table.name]) offset += table.n_constraints sc_point, sc_value = verify_sumcheck(state, initial_sum, n_max, max(t.air_degree + 1 for t in TABLES)) From 0238171db64b8a7dc20e300ab38f11afbde0bf44 Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Mon, 1 Jun 2026 05:01:18 +0400 Subject: [PATCH 06/31] wip --- crates/lean_prover/python-verifier/verifier.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/lean_prover/python-verifier/verifier.py b/crates/lean_prover/python-verifier/verifier.py index 8bf6d99a..9da013cb 100644 --- a/crates/lean_prover/python-verifier/verifier.py +++ b/crates/lean_prover/python-verifier/verifier.py @@ -1172,7 +1172,7 @@ def main() -> int: print(f"FAIL: {e}") return 1 - print(f"Proof successfully verified") + print("Proof successfully verified") return 0 From c4097a54685632c98a5c506a15e224f559c947c7 Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Mon, 1 Jun 2026 05:04:24 +0400 Subject: [PATCH 07/31] wip --- crates/lean_prover/python-verifier/verifier.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/lean_prover/python-verifier/verifier.py b/crates/lean_prover/python-verifier/verifier.py index 9da013cb..6041e215 100644 --- a/crates/lean_prover/python-verifier/verifier.py +++ b/crates/lean_prover/python-verifier/verifier.py @@ -1014,7 +1014,7 @@ def verify_execution( public_input: Sequence[Fp], proof: Proof, bytecode_multilinear: list[int], -): +) -> None: bytecode_log_size = log2_strict(len(bytecode_multilinear)) - log2_ceil(N_INSTRUCTION_COLUMNS) ending_pc = (1 << bytecode_log_size) - 1 bytecode_hash = sponge_hash([Fp(v) for v in bytecode_multilinear]) From 36aa38f35cf779ef80b10b687cf9a93cc12043f6 Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Mon, 1 Jun 2026 05:16:17 +0400 Subject: [PATCH 08/31] wip --- .../lean_prover/python-verifier/verifier.py | 23 +++++++++---------- 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/crates/lean_prover/python-verifier/verifier.py b/crates/lean_prover/python-verifier/verifier.py index 6041e215..c8be9541 100644 --- a/crates/lean_prover/python-verifier/verifier.py +++ b/crates/lean_prover/python-verifier/verifier.py @@ -358,6 +358,15 @@ class ParsedCommitment: ood_points: list[EF] ood_answers: list[EF] + @classmethod + def read(cls, fs: "FiatShamir", num_variables: int, n_ood: int) -> "ParsedCommitment": + return cls( + num_variables, + fs.next_base_scalars_vec(DIGEST_ELEMS), + fs.sample_many_ef(n_ood), + fs.next_extension_scalars_vec(n_ood), + ) + def oods_constraints(self) -> list[SparseStatements]: return [ SparseStatements(self.num_variables, expand_from_univariate(p, self.num_variables), [(0, ev)]) @@ -446,12 +455,7 @@ def step(constraints: list[SparseStatements], n_fold: int, pow_bits: int) -> Non round_params = cfg["rounds"][r] current_vars -= whir_folding_factor_at_round(r) n_ood_samples = round_params["ood_samples"] - new_commitment = ParsedCommitment( - current_vars, - fiat_shamir.next_base_scalars_vec(DIGEST_ELEMS), - fiat_shamir.sample_many_ef(n_ood_samples), - fiat_shamir.next_extension_scalars_vec(n_ood_samples), - ) + new_commitment = ParsedCommitment.read(fiat_shamir, current_vars, n_ood_samples) stir = verify_stir_challenges( fiat_shamir, r, @@ -1053,12 +1057,7 @@ def verify_execution( raise ProofError("InvalidProof: stacked_n_vars exceeds WHIR domain bound") cfg = WHIR_CONFIGS[(log_inv_rate, stacked_n_vars)] nood = cfg["commitment_ood_samples"] - parsed_commitment = ParsedCommitment( - stacked_n_vars, - state.next_base_scalars_vec(DIGEST_ELEMS), - state.sample_many_ef(nood), - state.next_extension_scalars_vec(nood), - ) + parsed_commitment = ParsedCommitment.read(state, stacked_n_vars, nood) logup_gamma = state.sample_ef() # the quotient denominator state.duplex() From 634d60b784118d052040754ec94f5b238706acff Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Mon, 1 Jun 2026 18:58:49 +0400 Subject: [PATCH 09/31] wip --- .../lean_prover/python-verifier/verifier.py | 100 +++++++++--------- 1 file changed, 50 insertions(+), 50 deletions(-) diff --git a/crates/lean_prover/python-verifier/verifier.py b/crates/lean_prover/python-verifier/verifier.py index c8be9541..beff3e6a 100644 --- a/crates/lean_prover/python-verifier/verifier.py +++ b/crates/lean_prover/python-verifier/verifier.py @@ -81,7 +81,7 @@ class Table: n_constraints: int n_shift: int # shift (next-row) columns are always the first ones max_log_height: int - air_constraints_fn: object # (folder, logup_beta_eq) -> None + air_constraints_fn: object # (constraint_evaluator, logup_beta_eq) -> None @property def n_columns(self) -> int: @@ -99,9 +99,9 @@ def col(self, name: str) -> int: return self.columns.index(name) def eval_air(self, col_evals: Sequence[EF], alpha_powers: Sequence[EF], logup_beta_eq: list[EF]) -> EF: - folder = ConstraintFolder(col_evals[: self.n_columns], col_evals[self.n_columns :], alpha_powers, self.columns) - self.air_constraints_fn(folder, logup_beta_eq) - return folder.accumulator + constraint_evaluator = ConstraintEvaluator(col_evals[: self.n_columns], col_evals[self.n_columns :], alpha_powers, self.columns) + self.air_constraints_fn(constraint_evaluator, logup_beta_eq) + return constraint_evaluator.accumulator def boundary_statements( self, stacked_n_vars: int, offset: int, n_vars: int, ending_pc: int @@ -716,7 +716,7 @@ def arr(self, prefix: str, n: int) -> list: return [self[f"{prefix}_{i}"] for i in range(n)] -class ConstraintFolder: +class ConstraintEvaluator: def __init__( self, flat: Sequence[EF], shift: Sequence[EF], alpha_powers: Sequence[EF], columns: Sequence[str] ) -> None: @@ -741,18 +741,18 @@ def assert_bool(self, x: EF) -> None: def eval_precompile_bus_virtual_columns( - folder: "ConstraintFolder", + evaluator: "ConstraintEvaluator", logup_beta_eq: list[EF], multiplicity: EF, domainsep: EF, data: Sequence[EF], ) -> None: - folder.assert_zero(multiplicity) - folder.assert_zero(finger_print(domainsep, data, logup_beta_eq)) + evaluator.assert_zero(multiplicity) + evaluator.assert_zero(finger_print(domainsep, data, logup_beta_eq)) -def eval_air_execution(folder: ConstraintFolder, logup_beta_eq: list[EF]) -> None: - c, n = folder.flat, folder.next +def eval_air_execution(evaluator: ConstraintEvaluator, logup_beta_eq: list[EF]) -> None: + c, n = evaluator.flat, evaluator.next (pc, fp, addr_a, addr_b, addr_c, value_a, value_b, value_c, operand_a, operand_b, operand_c, flag_a, flag_b, flag_c, flag_c_fp, flag_ab_fp, flag_mul, flag_jump, aux_1, aux_2) = (c[k] for k in EXECUTION_COLUMNS) # fmt: skip pc_shift, fp_shift = n["pc"], n["fp"] @@ -770,25 +770,25 @@ def eval_air_execution(folder: ConstraintFolder, logup_beta_eq: list[EF]) -> Non flag_deref = aux_1 * (aux_1 - ONE) * ((P + 1) // 2) # (P+1)/2 is the inverse of 2 mod P flag_precompile = ONE - flag_add - flag_mul - flag_deref - flag_jump - eval_precompile_bus_virtual_columns(folder, logup_beta_eq, flag_precompile, aux_2, [nu_a, nu_b, nu_c]) - folder.assert_zero(nfa * (addr_a - (fp + operand_a))) - folder.assert_zero(nfb * (addr_b - (fp + operand_b))) - folder.assert_zero(nfc * (addr_c - (fp + operand_c))) - folder.assert_zero(flag_add * (nu_b - (nu_a + nu_c))) - folder.assert_zero(flag_mul * (nu_b - nu_a * nu_c)) - folder.assert_zero(flag_deref * (addr_b - (value_a + operand_b))) - folder.assert_zero(flag_deref * (value_b - nu_c)) + eval_precompile_bus_virtual_columns(evaluator, logup_beta_eq, flag_precompile, aux_2, [nu_a, nu_b, nu_c]) + evaluator.assert_zero(nfa * (addr_a - (fp + operand_a))) + evaluator.assert_zero(nfb * (addr_b - (fp + operand_b))) + evaluator.assert_zero(nfc * (addr_c - (fp + operand_c))) + evaluator.assert_zero(flag_add * (nu_b - (nu_a + nu_c))) + evaluator.assert_zero(flag_mul * (nu_b - nu_a * nu_c)) + evaluator.assert_zero(flag_deref * (addr_b - (value_a + operand_b))) + evaluator.assert_zero(flag_deref * (value_b - nu_c)) jc = flag_jump * nu_a - folder.assert_zero(jc * (nu_a - ONE)) - folder.assert_zero(jc * (pc_shift - nu_b)) - folder.assert_zero(jc * (fp_shift - nu_c)) + evaluator.assert_zero(jc * (nu_a - ONE)) + evaluator.assert_zero(jc * (pc_shift - nu_b)) + evaluator.assert_zero(jc * (fp_shift - nu_c)) not_jc = ONE - jc - folder.assert_zero(not_jc * (pc_shift - (pc + ONE))) - folder.assert_zero(not_jc * (fp_shift - fp)) + evaluator.assert_zero(not_jc * (pc_shift - (pc + ONE))) + evaluator.assert_zero(not_jc * (fp_shift - fp)) -def eval_air_extension(folder: ConstraintFolder, logup_beta_eq: list[EF]) -> None: - c, n = folder.flat, folder.next +def eval_air_extension(evaluator: ConstraintEvaluator, logup_beta_eq: list[EF]) -> None: + c, n = evaluator.flat, evaluator.next flag_be, flag_start, len_col = c["flag_be"], c["flag_start"], c["len"] flag_add, flag_dot_product, flag_eq = c["flag_add"], c["flag_dot_product"], c["flag_eq"] idx_a, idx_b, idx_r = c["idx_a"], c["idx_b"], c["idx_r"] @@ -806,11 +806,11 @@ def eval_air_extension(folder: ConstraintFolder, logup_beta_eq: list[EF]) -> Non + len_col * EXT_OP_LEN_MULTIPLIER ) eval_precompile_bus_virtual_columns( - folder, logup_beta_eq, flag_start * (flag_add + flag_dot_product + flag_eq), aux_2, [idx_a, idx_b, idx_r] + evaluator, logup_beta_eq, flag_start * (flag_add + flag_dot_product + flag_eq), aux_2, [idx_a, idx_b, idx_r] ) for x in (flag_be, flag_start, flag_add, flag_dot_product, flag_eq): - folder.assert_bool(x) + evaluator.assert_bool(x) is_ee, not_start_sh = ONE - flag_be, ONE - flag_start_sh v_a_tilde = [v_a[0]] + [v_a[k] * is_ee for k in range(1, 5)] @@ -818,18 +818,18 @@ def eval_air_extension(folder: ConstraintFolder, logup_beta_eq: list[EF]) -> Non v_a_v_b = quintic_mul(v_a_tilde, v_b, ZERO) for k in range(5): - folder.assert_zero((acc[k] - (v_a_tilde[k] + v_b[k] + acc_tail[k])) * flag_add) + evaluator.assert_zero((acc[k] - (v_a_tilde[k] + v_b[k] + acc_tail[k])) * flag_add) for k in range(5): - folder.assert_zero((acc[k] - (v_a_v_b[k] + acc_tail[k])) * flag_dot_product) + evaluator.assert_zero((acc[k] - (v_a_v_b[k] + acc_tail[k])) * flag_dot_product) # eq: acc ← (2·v_a·v_b − v_a − v_b + 1) · (acc_tail or 1 at group end). e_eq = [2 * v_a_v_b[k] - v_a_tilde[k] - v_b[k] + (ONE if k == 0 else ZERO) for k in range(5)] acc_tail_or_one = [acc_sh[0] * not_start_sh + flag_start_sh] + [acc_sh[k] * not_start_sh for k in range(1, 5)] eq_result = quintic_mul(e_eq, acc_tail_or_one, ZERO) for k in range(5): - folder.assert_zero((acc[k] - eq_result[k]) * flag_eq) + evaluator.assert_zero((acc[k] - eq_result[k]) * flag_eq) for k in range(5): - folder.assert_zero((acc[k] - res[k]) * flag_start) + evaluator.assert_zero((acc[k] - res[k]) * flag_start) for x, y in [ (len_col, len_sh + ONE), @@ -838,11 +838,11 @@ def eval_air_extension(folder: ConstraintFolder, logup_beta_eq: list[EF]) -> Non (flag_dot_product, flag_dot_product_sh), (flag_eq, flag_eq_sh), ]: - folder.assert_zero(not_start_sh * (x - y)) + evaluator.assert_zero(not_start_sh * (x - y)) - folder.assert_zero(not_start_sh * (idx_a_sh - idx_a - (flag_be + is_ee * 5))) - folder.assert_zero(not_start_sh * (idx_b_sh - idx_b - 5)) - folder.assert_zero(flag_start_sh * (len_col - ONE)) + evaluator.assert_zero(not_start_sh * (idx_a_sh - idx_a - (flag_be + is_ee * 5))) + evaluator.assert_zero(not_start_sh * (idx_b_sh - idx_b - 5)) + evaluator.assert_zero(flag_start_sh * (len_col - ONE)) def _full_round(state: list[EF], rc1: list[Fp], rc2: list[Fp]) -> list[EF]: @@ -853,8 +853,8 @@ def _full_round(state: list[EF], rc1: list[Fp], rc2: list[Fp]) -> list[EF]: return state -def eval_air_poseidon16(folder: ConstraintFolder, logup_beta_eq: list[EF]) -> None: - c = folder.flat +def eval_air_poseidon16(evaluator: ConstraintEvaluator, logup_beta_eq: list[EF]) -> None: + c = evaluator.flat half_pairs = POSEIDON_HALF_FULL_ROUNDS // 2 multiplicity = c["multiplicity"] @@ -880,14 +880,14 @@ def eval_air_poseidon16(folder: ConstraintFolder, logup_beta_eq: list[EF]) -> No not_flag_left = ONE - flag_left nu_a = addr_left_hi - not_flag_left * (DIGEST_ELEMS // 2) - eval_precompile_bus_virtual_columns(folder, logup_beta_eq, multiplicity, domainsep, [nu_a, nu_b, nu_c]) + eval_precompile_bus_virtual_columns(evaluator, logup_beta_eq, multiplicity, domainsep, [nu_a, nu_b, nu_c]) for f in (multiplicity, flag_out4, flag_out8, flag_left, flag_permute): - folder.assert_bool(f) - folder.assert_zero(flag_permute * flag_out4) - folder.assert_zero(flag_out8 * flag_out4) - folder.assert_zero((ONE - flag_permute) * (ONE - flag_out8) * (ONE - flag_out4)) - folder.assert_zero(flag_left * (offset_left - addr_left_lo)) - folder.assert_zero(not_flag_left * (nu_a - addr_left_lo)) + evaluator.assert_bool(f) + evaluator.assert_zero(flag_permute * flag_out4) + evaluator.assert_zero(flag_out8 * flag_out4) + evaluator.assert_zero((ONE - flag_permute) * (ONE - flag_out8) * (ONE - flag_out4)) + evaluator.assert_zero(flag_left * (offset_left - addr_left_lo)) + evaluator.assert_zero(not_flag_left * (nu_a - addr_left_lo)) # --- Poseidon1-16 permutation AIR: each committed `post` row pins the intermediate # state then re-binds it, capping polynomial degree across the long round sequence. @@ -897,7 +897,7 @@ def eval_air_poseidon16(folder: ConstraintFolder, logup_beta_eq: list[EF]) -> No for r in range(half_pairs): state = _full_round(state, POSEIDON_AIR_INITIAL_CONSTANTS[2 * r], POSEIDON_AIR_INITIAL_CONSTANTS[2 * r + 1]) for i, post in enumerate(beginning_full_rounds[r]): - folder.assert_eq(state[i], post) + evaluator.assert_eq(state[i], post) state[i] = post # Transition into sparse partial-round form. @@ -906,7 +906,7 @@ def eval_air_poseidon16(folder: ConstraintFolder, logup_beta_eq: list[EF]) -> No # Partial rounds: one sbox on lane 0, then sparse mat-vec. for r in range(POSEIDON_PARTIAL_ROUNDS): - folder.assert_eq(state[0].cube(), partial_cols[r]) + evaluator.assert_eq(state[0].cube(), partial_cols[r]) state[0] = partial_cols[r] if r < POSEIDON_PARTIAL_ROUNDS - 1: state[0] += POSEIDON_AIR_SPARSE_SCALAR_RC[r] @@ -919,7 +919,7 @@ def eval_air_poseidon16(folder: ConstraintFolder, logup_beta_eq: list[EF]) -> No for r in range(half_pairs - 1): state = _full_round(state, POSEIDON_AIR_FINAL_CONSTANTS[2 * r], POSEIDON_AIR_FINAL_CONSTANTS[2 * r + 1]) for i, post in enumerate(ending_full_rounds[r]): - folder.assert_eq(state[i], post) + evaluator.assert_eq(state[i], post) state[i] = post # Last full round: compression feeds `inputs` forward into out_lo (permute does not). @@ -933,10 +933,10 @@ def eval_air_poseidon16(folder: ConstraintFolder, logup_beta_eq: list[EF]) -> No for i in range(POSEIDON_WIDTH // 2): value = state[i] + not_permute * inputs[i] if i < (DIGEST_ELEMS // 2): - folder.assert_zero(value - out_lo[i]) + evaluator.assert_zero(value - out_lo[i]) else: - folder.assert_zero(gate_lo_8 * (value - out_lo[i])) - folder.assert_zero(gate_hi * (state[i + POSEIDON_WIDTH // 2] - out_hi[i])) + evaluator.assert_zero(gate_lo_8 * (value - out_lo[i])) + evaluator.assert_zero(gate_hi * (state[i + POSEIDON_WIDTH // 2] - out_hi[i])) EXECUTION_COLUMNS = ( From a90200d23a0baaccb1c07a8d4380508ab6677cc3 Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Mon, 1 Jun 2026 19:24:42 +0400 Subject: [PATCH 10/31] wip --- .../lean_prover/python-verifier/verifier.py | 56 +++++++++---------- 1 file changed, 27 insertions(+), 29 deletions(-) diff --git a/crates/lean_prover/python-verifier/verifier.py b/crates/lean_prover/python-verifier/verifier.py index beff3e6a..53db3f44 100644 --- a/crates/lean_prover/python-verifier/verifier.py +++ b/crates/lean_prover/python-verifier/verifier.py @@ -99,7 +99,9 @@ def col(self, name: str) -> int: return self.columns.index(name) def eval_air(self, col_evals: Sequence[EF], alpha_powers: Sequence[EF], logup_beta_eq: list[EF]) -> EF: - constraint_evaluator = ConstraintEvaluator(col_evals[: self.n_columns], col_evals[self.n_columns :], alpha_powers, self.columns) + constraint_evaluator = ConstraintEvaluator( + col_evals[: self.n_columns], col_evals[self.n_columns :], alpha_powers, self.columns + ) self.air_constraints_fn(constraint_evaluator, logup_beta_eq) return constraint_evaluator.accumulator @@ -343,14 +345,6 @@ def whir_folding_factor_at_round(r: int) -> int: return WHIR_INITIAL_FOLDING_FACTOR if r == 0 else WHIR_SUBSEQUENT_FOLDING_FACTOR -def whir_n_rounds_and_final_sumcheck(num_variables: int) -> tuple[int, int]: - nv = num_variables - WHIR_INITIAL_FOLDING_FACTOR - if nv < WHIR_MAX_NUM_VARIABLES_TO_SEND_COEFFS: - return 0, nv - n = div_ceil(nv - WHIR_MAX_NUM_VARIABLES_TO_SEND_COEFFS, WHIR_SUBSEQUENT_FOLDING_FACTOR) - return n, nv - n * WHIR_SUBSEQUENT_FOLDING_FACTOR - - @dataclass class ParsedCommitment: num_variables: int @@ -380,19 +374,19 @@ def verify_sumcheck( point: list[EF] = [] for _ in range(n_rounds): coeffs = fiat_shamir.next_extension_scalars_vec(degree + 1) - s = coeffs[0] + sum(coeffs) + s = coeffs[0] + sum(coeffs) # s = h(0) + h(1) if s != target: raise ProofError("Sumcheck identity failed: h(0) + h(1) != target") fiat_shamir.check_pow_grinding(pow_bits) - r = fiat_shamir.sample_ef() - point.append(r) - target = eval_univariate_polynomial(coeffs, r) + challenge = fiat_shamir.sample_ef() + point.append(challenge) + target = eval_univariate_polynomial(coeffs, challenge) return point, target def verify_stir_challenges( fiat_shamir: FiatShamir, - round_index: int, + is_first_round: int, log_height: int, num_variables: int, num_queries: int, @@ -408,7 +402,7 @@ def verify_stir_challenges( op = fiat_shamir.next_merkle_opening() merkle_verify_path(commitment.root, log_height, idx, op.leaf_data, op.path) # Round 0 leaves are raw base-field elements; later rounds pack DIM Fp values per EF element. - packed = op.leaf_data if round_index == 0 else pack_ef(op.leaf_data) + packed = op.leaf_data if is_first_round else pack_ef(op.leaf_data) fold = eval_multilinear_evals(packed, folding_randomness) ef_pt = EF(pow(int(gen.value), idx, P)) pt = expand_from_univariate(ef_pt, num_variables) @@ -422,13 +416,14 @@ def whir_verify( parsed_commitment: ParsedCommitment, statements: list[SparseStatements], ) -> list[EF]: - n_rounds, final_sumcheck_rounds = whir_n_rounds_and_final_sumcheck(cfg["num_variables"]) + nv = cfg["num_variables"] - WHIR_INITIAL_FOLDING_FACTOR + assert nv >= WHIR_MAX_NUM_VARIABLES_TO_SEND_COEFFS + n_rounds = div_ceil(nv - WHIR_MAX_NUM_VARIABLES_TO_SEND_COEFFS, WHIR_SUBSEQUENT_FOLDING_FACTOR) + final_sumcheck_rounds = nv - n_rounds * WHIR_SUBSEQUENT_FOLDING_FACTOR round_constraints: list[tuple[list[EF], list[SparseStatements]]] = [] round_folding: list[list[EF]] = [] - target = ZERO - def step(constraints: list[SparseStatements], n_fold: int, pow_bits: int) -> None: - nonlocal target + def step(target: EF, constraints: list[SparseStatements], n_fold: int, pow_bits: int) -> EF: fiat_shamir.duplex() gamma = fiat_shamir.sample_ef() combo: list[EF] = [] @@ -441,8 +436,10 @@ def step(constraints: list[SparseStatements], n_fold: int, pow_bits: int) -> Non round_constraints.append((combo, constraints)) sc_point, target = verify_sumcheck(fiat_shamir, target, n_fold, 2, pow_bits) round_folding.append(sc_point) + return target - step( + target = step( + ZERO, parsed_commitment.oods_constraints() + statements, whir_folding_factor_at_round(0), cfg["starting_folding_pow_bits"], @@ -451,34 +448,35 @@ def step(constraints: list[SparseStatements], n_fold: int, pow_bits: int) -> Non prev_commitment = parsed_commitment current_vars = cfg["num_variables"] log_domain = cfg["num_variables"] + cfg["log_inv_rate"] - for r in range(n_rounds): - round_params = cfg["rounds"][r] - current_vars -= whir_folding_factor_at_round(r) + for round in range(n_rounds): + round_params = cfg["rounds"][round] + current_vars -= whir_folding_factor_at_round(round) n_ood_samples = round_params["ood_samples"] new_commitment = ParsedCommitment.read(fiat_shamir, current_vars, n_ood_samples) stir = verify_stir_challenges( fiat_shamir, - r, - log_domain - whir_folding_factor_at_round(r), + round == 0, + log_domain - whir_folding_factor_at_round(round), current_vars, round_params["num_queries"], round_params["query_pow_bits"], prev_commitment, round_folding[-1], ) - step( + target = step( + target, new_commitment.oods_constraints() + stir, - whir_folding_factor_at_round(r + 1), + whir_folding_factor_at_round(round + 1), round_params["folding_pow_bits"], ) - log_domain -= RS_DOMAIN_INITIAL_REDUCTION_FACTOR if r == 0 else 1 + log_domain -= RS_DOMAIN_INITIAL_REDUCTION_FACTOR if round == 0 else 1 prev_commitment = new_commitment n_vars_final = current_vars - whir_folding_factor_at_round(n_rounds) final_coeffs = fiat_shamir.next_extension_scalars_vec(1 << n_vars_final) final_stir = verify_stir_challenges( fiat_shamir, - n_rounds, + False, log_domain - whir_folding_factor_at_round(n_rounds), n_vars_final, cfg["final_queries"], From 9857573d2a4726deed73441e6dca70137605014a Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Mon, 1 Jun 2026 20:58:10 +0400 Subject: [PATCH 11/31] wip --- .../lean_prover/python-verifier/primitives.py | 2 +- .../lean_prover/python-verifier/verifier.py | 45 +++++++++---------- 2 files changed, 22 insertions(+), 25 deletions(-) diff --git a/crates/lean_prover/python-verifier/primitives.py b/crates/lean_prover/python-verifier/primitives.py index 0e1c8b95..ec293371 100644 --- a/crates/lean_prover/python-verifier/primitives.py +++ b/crates/lean_prover/python-verifier/primitives.py @@ -142,7 +142,7 @@ def ef_powers(x: EF, n: int) -> list[EF]: return list(accumulate(repeat(x, n), lambda a, _: a * x, initial=ONE))[:n] -def pack_ef(flat: Sequence[Fp]) -> list[EF]: +def embed_ef(flat: Sequence[Fp]) -> list[EF]: """Pack a length-(n·DIM) Fp vector into n EF elements (5 Fp coordinates per EF).""" return [EF(flat[i : i + EF.DIMENSION]) for i in range(0, len(flat), EF.DIMENSION)] diff --git a/crates/lean_prover/python-verifier/verifier.py b/crates/lean_prover/python-verifier/verifier.py index 53db3f44..77194304 100644 --- a/crates/lean_prover/python-verifier/verifier.py +++ b/crates/lean_prover/python-verifier/verifier.py @@ -163,7 +163,7 @@ def _sample_many(self, n: int) -> list[Fp]: def sample_many_ef(self, n: int) -> list[EF]: flat = self._sample_many(div_ceil(n * EF.DIMENSION, SPONGE_RATE))[: n * EF.DIMENSION] - return pack_ef(flat) + return embed_ef(flat) def sample_ef(self) -> EF: return self.sample_many_ef(1)[0] @@ -212,7 +212,7 @@ def next_base_scalars_vec(self, n: int) -> list[Fp]: def next_extension_scalars_vec(self, n: int) -> list[EF]: flat = self.next_base_scalars_vec(n * EF.DIMENSION) - return pack_ef(flat) + return embed_ef(flat) def next_extension_scalar(self) -> EF: return self.next_extension_scalars_vec(1)[0] @@ -332,17 +332,17 @@ def eval_eq(point: Sequence[EF]) -> list[EF]: @dataclass class SparseStatements: total_num_variables: int - point: list[EF] - values: list[tuple[int, EF]] - is_next: bool = False + point: list[EF] # low-bits variables (suffix), shared by every entry in `values` + values: list[tuple[int, EF]] # (selector_index, eval): poly(high bits = selector_index, low bits = point) == eval + is_next: bool = False # if set, the low-variable part uses the shifted "next-row" MLE instead of plain eq @property def selector_num_variables(self) -> int: - return self.total_num_variables - len(self.point) + return self.total_num_variables - len(self.point) # count of high/selector bits that selector_index spans -def whir_folding_factor_at_round(r: int) -> int: - return WHIR_INITIAL_FOLDING_FACTOR if r == 0 else WHIR_SUBSEQUENT_FOLDING_FACTOR +def whir_folding_factor_at_round(round: int) -> int: + return WHIR_INITIAL_FOLDING_FACTOR if round == 0 else WHIR_SUBSEQUENT_FOLDING_FACTOR @dataclass @@ -401,12 +401,11 @@ def verify_stir_challenges( for idx in indices: op = fiat_shamir.next_merkle_opening() merkle_verify_path(commitment.root, log_height, idx, op.leaf_data, op.path) - # Round 0 leaves are raw base-field elements; later rounds pack DIM Fp values per EF element. - packed = op.leaf_data if is_first_round else pack_ef(op.leaf_data) + # Round 0 leaves are raw base-field elements; later rounds embed DIM Fp values per EF element. + packed = op.leaf_data if is_first_round else embed_ef(op.leaf_data) fold = eval_multilinear_evals(packed, folding_randomness) - ef_pt = EF(pow(int(gen.value), idx, P)) - pt = expand_from_univariate(ef_pt, num_variables) - constraints.append(SparseStatements(num_variables, pt, [(0, fold)])) + point = expand_from_univariate(EF(pow(int(gen.value), idx, P)), num_variables) + constraints.append(SparseStatements(num_variables, point, [(0, fold)])) return constraints @@ -420,20 +419,18 @@ def whir_verify( assert nv >= WHIR_MAX_NUM_VARIABLES_TO_SEND_COEFFS n_rounds = div_ceil(nv - WHIR_MAX_NUM_VARIABLES_TO_SEND_COEFFS, WHIR_SUBSEQUENT_FOLDING_FACTOR) final_sumcheck_rounds = nv - n_rounds * WHIR_SUBSEQUENT_FOLDING_FACTOR - round_constraints: list[tuple[list[EF], list[SparseStatements]]] = [] + round_constraints: list[tuple[EF, list[SparseStatements]]] = [] round_folding: list[list[EF]] = [] def step(target: EF, constraints: list[SparseStatements], n_fold: int, pow_bits: int) -> EF: fiat_shamir.duplex() gamma = fiat_shamir.sample_ef() - combo: list[EF] = [] - g = ONE + gamma_power = ONE for smt in constraints: for _, value in smt.values: - target += g * value - combo.append(g) - g *= gamma - round_constraints.append((combo, constraints)) + target += gamma_power * value + gamma_power *= gamma + round_constraints.append((gamma, constraints)) sc_point, target = verify_sumcheck(fiat_shamir, target, n_fold, 2, pow_bits) round_folding.append(sc_point) return target @@ -497,18 +494,18 @@ def step(target: EF, constraints: list[SparseStatements], n_fold: int, pow_bits: eval_weights = ZERO pt = folding_flat - for round_idx, (randomness, smts) in enumerate(round_constraints): + for round_idx, (gamma, smts) in enumerate(round_constraints): if round_idx > 0: pt = pt[whir_folding_factor_at_round(round_idx - 1) :] - i = 0 + gamma_power = ONE for smt in smts: inner_pt = pt[len(pt) - len(smt.point) :] common = next_mle(smt.point, inner_pt) if smt.is_next else eq_poly(smt.point, inner_pt) sel_n = smt.selector_num_variables for v in smt.values: lagrange = eq_at_index(pt, v[0], sel_n) - eval_weights += lagrange * common * randomness[i] - i += 1 + eval_weights += lagrange * common * gamma_power + gamma_power *= gamma final_value = eval_multilinear_coeffs(final_coeffs, list(reversed(final_sc_point))) if final_sc_value != eval_weights * final_value: raise ProofError("WHIR final sumcheck check failed") From 527793818a67d892673e2ddab2c294bfaf67eedf Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Tue, 2 Jun 2026 03:11:05 +0400 Subject: [PATCH 12/31] wip --- .../lean_prover/python-verifier/verifier.py | 74 +++++++------------ 1 file changed, 27 insertions(+), 47 deletions(-) diff --git a/crates/lean_prover/python-verifier/verifier.py b/crates/lean_prover/python-verifier/verifier.py index 77194304..48204975 100644 --- a/crates/lean_prover/python-verifier/verifier.py +++ b/crates/lean_prover/python-verifier/verifier.py @@ -31,11 +31,10 @@ "num_variables": c[1], "commitment_ood_samples": c[2], "starting_folding_pow_bits": c[3], - "final_queries": c[4], - "final_query_pow_bits": c[5], "rounds": [ {"num_queries": r[0], "ood_samples": r[1], "query_pow_bits": r[2], "folding_pow_bits": r[3]} for r in c[6] - ], + ] + + [{"num_queries": c[4], "query_pow_bits": c[5]}], } for c in _WHIR_CONFIGS } @@ -422,7 +421,14 @@ def whir_verify( round_constraints: list[tuple[EF, list[SparseStatements]]] = [] round_folding: list[list[EF]] = [] - def step(target: EF, constraints: list[SparseStatements], n_fold: int, pow_bits: int) -> EF: + current_vars = cfg["num_variables"] + log_domain = current_vars + cfg["log_inv_rate"] + target = ZERO + constraints = parsed_commitment.oods_constraints() + statements + fold_pow_bits = cfg["starting_folding_pow_bits"] + for round in range(n_rounds + 1): + round_params = cfg["rounds"][round] + folding_factor = whir_folding_factor_at_round(round) fiat_shamir.duplex() gamma = fiat_shamir.sample_ef() gamma_power = ONE @@ -431,58 +437,32 @@ def step(target: EF, constraints: list[SparseStatements], n_fold: int, pow_bits: target += gamma_power * value gamma_power *= gamma round_constraints.append((gamma, constraints)) - sc_point, target = verify_sumcheck(fiat_shamir, target, n_fold, 2, pow_bits) + sc_point, target = verify_sumcheck(fiat_shamir, target, folding_factor, 2, fold_pow_bits) round_folding.append(sc_point) - return target - - target = step( - ZERO, - parsed_commitment.oods_constraints() + statements, - whir_folding_factor_at_round(0), - cfg["starting_folding_pow_bits"], - ) - - prev_commitment = parsed_commitment - current_vars = cfg["num_variables"] - log_domain = cfg["num_variables"] + cfg["log_inv_rate"] - for round in range(n_rounds): - round_params = cfg["rounds"][round] - current_vars -= whir_folding_factor_at_round(round) - n_ood_samples = round_params["ood_samples"] - new_commitment = ParsedCommitment.read(fiat_shamir, current_vars, n_ood_samples) - stir = verify_stir_challenges( + current_vars -= folding_factor + is_final = round == n_rounds + if is_final: + final_coeffs = fiat_shamir.next_extension_scalars_vec(1 << current_vars) + else: + new_commitment = ParsedCommitment.read(fiat_shamir, current_vars, round_params["ood_samples"]) + stir_constraints = verify_stir_challenges( fiat_shamir, round == 0, - log_domain - whir_folding_factor_at_round(round), + log_domain - folding_factor, current_vars, round_params["num_queries"], round_params["query_pow_bits"], - prev_commitment, + parsed_commitment, round_folding[-1], ) - target = step( - target, - new_commitment.oods_constraints() + stir, - whir_folding_factor_at_round(round + 1), - round_params["folding_pow_bits"], - ) + if is_final: + final_stir_constraints = stir_constraints + break + constraints = new_commitment.oods_constraints() + stir_constraints + fold_pow_bits = round_params["folding_pow_bits"] log_domain -= RS_DOMAIN_INITIAL_REDUCTION_FACTOR if round == 0 else 1 - prev_commitment = new_commitment - - n_vars_final = current_vars - whir_folding_factor_at_round(n_rounds) - final_coeffs = fiat_shamir.next_extension_scalars_vec(1 << n_vars_final) - final_stir = verify_stir_challenges( - fiat_shamir, - False, - log_domain - whir_folding_factor_at_round(n_rounds), - n_vars_final, - cfg["final_queries"], - cfg["final_query_pow_bits"], - prev_commitment, - round_folding[-1], - ) - # Each STIR constraint's point is `expand_from_univariate(α, n)` = [α, α², α⁴, …]. We check that `Σ coeffs[i]·α^i == value` for each smt - for smt in final_stir: + parsed_commitment = new_commitment + for smt in final_stir_constraints: univ_eval = eval_univariate_polynomial(final_coeffs, smt.point[0]) if any(univ_eval != v[1] for v in smt.values): raise ProofError("Final STIR constraint mismatch") From e69b48ce3d77eeb624cea5d40e1e317667b5ed23 Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Tue, 2 Jun 2026 16:52:44 +0400 Subject: [PATCH 13/31] w --- .../lean_prover/python-verifier/verifier.py | 30 ++++++++----------- 1 file changed, 13 insertions(+), 17 deletions(-) diff --git a/crates/lean_prover/python-verifier/verifier.py b/crates/lean_prover/python-verifier/verifier.py index 48204975..a4c16b8d 100644 --- a/crates/lean_prover/python-verifier/verifier.py +++ b/crates/lean_prover/python-verifier/verifier.py @@ -413,13 +413,13 @@ def whir_verify( cfg: dict, parsed_commitment: ParsedCommitment, statements: list[SparseStatements], -) -> list[EF]: +): nv = cfg["num_variables"] - WHIR_INITIAL_FOLDING_FACTOR assert nv >= WHIR_MAX_NUM_VARIABLES_TO_SEND_COEFFS n_rounds = div_ceil(nv - WHIR_MAX_NUM_VARIABLES_TO_SEND_COEFFS, WHIR_SUBSEQUENT_FOLDING_FACTOR) final_sumcheck_rounds = nv - n_rounds * WHIR_SUBSEQUENT_FOLDING_FACTOR round_constraints: list[tuple[EF, list[SparseStatements]]] = [] - round_folding: list[list[EF]] = [] + round_folding_challenges: list[list[EF]] = [] current_vars = cfg["num_variables"] log_domain = current_vars + cfg["log_inv_rate"] @@ -438,7 +438,7 @@ def whir_verify( gamma_power *= gamma round_constraints.append((gamma, constraints)) sc_point, target = verify_sumcheck(fiat_shamir, target, folding_factor, 2, fold_pow_bits) - round_folding.append(sc_point) + round_folding_challenges.append(sc_point) current_vars -= folding_factor is_final = round == n_rounds if is_final: @@ -453,7 +453,7 @@ def whir_verify( round_params["num_queries"], round_params["query_pow_bits"], parsed_commitment, - round_folding[-1], + round_folding_challenges[-1], ) if is_final: final_stir_constraints = stir_constraints @@ -468,30 +468,26 @@ def whir_verify( raise ProofError("Final STIR constraint mismatch") final_sc_point, final_sc_value = verify_sumcheck(fiat_shamir, target, final_sumcheck_rounds, 2) - round_folding.append(final_sc_point) - - folding_flat = [r for chunk in round_folding for r in chunk] + round_folding_challenges.append(final_sc_point) eval_weights = ZERO - pt = folding_flat - for round_idx, (gamma, smts) in enumerate(round_constraints): - if round_idx > 0: - pt = pt[whir_folding_factor_at_round(round_idx - 1) :] + folding_challenges = [r for chunk in round_folding_challenges for r in chunk] + for round, (gamma, smts) in enumerate(round_constraints): + if round > 0: + folding_challenges = folding_challenges[whir_folding_factor_at_round(round - 1) :] gamma_power = ONE for smt in smts: - inner_pt = pt[len(pt) - len(smt.point) :] - common = next_mle(smt.point, inner_pt) if smt.is_next else eq_poly(smt.point, inner_pt) + point_suffix = folding_challenges[len(folding_challenges) - len(smt.point) :] # dense part of the point + eval_suffix = next_mle(smt.point, point_suffix) if smt.is_next else eq_poly(smt.point, point_suffix) sel_n = smt.selector_num_variables for v in smt.values: - lagrange = eq_at_index(pt, v[0], sel_n) - eval_weights += lagrange * common * gamma_power + eval_prefix = eq_at_index(folding_challenges, v[0], sel_n) # sparse part of the point + eval_weights += eval_prefix * eval_suffix * gamma_power gamma_power *= gamma final_value = eval_multilinear_coeffs(final_coeffs, list(reversed(final_sc_point))) if final_sc_value != eval_weights * final_value: raise ProofError("WHIR final sumcheck check failed") - return folding_flat - def stacked_pcs_global_statements( stacked_n_vars: int, From 87b82725da0f1d65cd9149ee695607fea3877342 Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Tue, 2 Jun 2026 17:41:41 +0400 Subject: [PATCH 14/31] w --- crates/lean_prover/python-verifier/verifier.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/crates/lean_prover/python-verifier/verifier.py b/crates/lean_prover/python-verifier/verifier.py index a4c16b8d..d19fcea7 100644 --- a/crates/lean_prover/python-verifier/verifier.py +++ b/crates/lean_prover/python-verifier/verifier.py @@ -105,14 +105,14 @@ def eval_air(self, col_evals: Sequence[EF], alpha_powers: Sequence[EF], logup_be return constraint_evaluator.accumulator def boundary_statements( - self, stacked_n_vars: int, offset: int, n_vars: int, ending_pc: int + self, stacked_n_vars: int, offset: int, log_n_rows: int, ending_pc: int ) -> list["SparseStatements"]: if self.name != "execution": return [] - pc_col_offset = offset + (self.col("pc") << n_vars) + pc_col_offset = offset + (self.col("pc") << log_n_rows) return [ SparseStatements(stacked_n_vars, [], [(pc_col_offset + idx, EF(pc))]) - for idx, pc in [(0, STARTING_PC), ((1 << n_vars) - 1, ending_pc)] + for idx, pc in [(0, STARTING_PC), ((1 << log_n_rows) - 1, ending_pc)] ] From 6ad48893a0426339c091a9fdf1fe27677c64943e Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Wed, 3 Jun 2026 15:24:33 +0400 Subject: [PATCH 15/31] wip --- .../lean_prover/python-verifier/verifier.py | 34 +++++++++---------- 1 file changed, 16 insertions(+), 18 deletions(-) diff --git a/crates/lean_prover/python-verifier/verifier.py b/crates/lean_prover/python-verifier/verifier.py index d19fcea7..d1c785be 100644 --- a/crates/lean_prover/python-verifier/verifier.py +++ b/crates/lean_prover/python-verifier/verifier.py @@ -345,14 +345,14 @@ def whir_folding_factor_at_round(round: int) -> int: @dataclass -class ParsedCommitment: +class WhirCommitment: num_variables: int root: list[Fp] ood_points: list[EF] ood_answers: list[EF] @classmethod - def read(cls, fs: "FiatShamir", num_variables: int, n_ood: int) -> "ParsedCommitment": + def read(cls, fs: "FiatShamir", num_variables: int, n_ood: int) -> "WhirCommitment": return cls( num_variables, fs.next_base_scalars_vec(DIGEST_ELEMS), @@ -390,7 +390,7 @@ def verify_stir_challenges( num_variables: int, num_queries: int, query_pow_bits: int, - commitment: ParsedCommitment, + commitment: WhirCommitment, folding_randomness: list[EF], ) -> list[SparseStatements]: gen = Fp(KB_TWO_ADIC_GENERATORS[log_height]) @@ -408,23 +408,21 @@ def verify_stir_challenges( return constraints -def whir_verify( +def verify_whir( fiat_shamir: FiatShamir, cfg: dict, - parsed_commitment: ParsedCommitment, + commitment: WhirCommitment, statements: list[SparseStatements], ): - nv = cfg["num_variables"] - WHIR_INITIAL_FOLDING_FACTOR - assert nv >= WHIR_MAX_NUM_VARIABLES_TO_SEND_COEFFS - n_rounds = div_ceil(nv - WHIR_MAX_NUM_VARIABLES_TO_SEND_COEFFS, WHIR_SUBSEQUENT_FOLDING_FACTOR) - final_sumcheck_rounds = nv - n_rounds * WHIR_SUBSEQUENT_FOLDING_FACTOR + current_vars = cfg["num_variables"] + num_vars_after_1_round = current_vars - WHIR_INITIAL_FOLDING_FACTOR + assert num_vars_after_1_round >= WHIR_MAX_NUM_VARIABLES_TO_SEND_COEFFS + n_rounds = div_ceil(num_vars_after_1_round - WHIR_MAX_NUM_VARIABLES_TO_SEND_COEFFS, WHIR_SUBSEQUENT_FOLDING_FACTOR) round_constraints: list[tuple[EF, list[SparseStatements]]] = [] round_folding_challenges: list[list[EF]] = [] - - current_vars = cfg["num_variables"] log_domain = current_vars + cfg["log_inv_rate"] target = ZERO - constraints = parsed_commitment.oods_constraints() + statements + constraints = commitment.oods_constraints() + statements fold_pow_bits = cfg["starting_folding_pow_bits"] for round in range(n_rounds + 1): round_params = cfg["rounds"][round] @@ -444,7 +442,7 @@ def whir_verify( if is_final: final_coeffs = fiat_shamir.next_extension_scalars_vec(1 << current_vars) else: - new_commitment = ParsedCommitment.read(fiat_shamir, current_vars, round_params["ood_samples"]) + new_commitment = WhirCommitment.read(fiat_shamir, current_vars, round_params["ood_samples"]) stir_constraints = verify_stir_challenges( fiat_shamir, round == 0, @@ -452,7 +450,7 @@ def whir_verify( current_vars, round_params["num_queries"], round_params["query_pow_bits"], - parsed_commitment, + commitment, round_folding_challenges[-1], ) if is_final: @@ -461,13 +459,13 @@ def whir_verify( constraints = new_commitment.oods_constraints() + stir_constraints fold_pow_bits = round_params["folding_pow_bits"] log_domain -= RS_DOMAIN_INITIAL_REDUCTION_FACTOR if round == 0 else 1 - parsed_commitment = new_commitment + commitment = new_commitment for smt in final_stir_constraints: univ_eval = eval_univariate_polynomial(final_coeffs, smt.point[0]) if any(univ_eval != v[1] for v in smt.values): raise ProofError("Final STIR constraint mismatch") - final_sc_point, final_sc_value = verify_sumcheck(fiat_shamir, target, final_sumcheck_rounds, 2) + final_sc_point, final_sc_value = verify_sumcheck(fiat_shamir, target, current_vars, 2) round_folding_challenges.append(final_sc_point) eval_weights = ZERO @@ -1028,7 +1026,7 @@ def verify_execution( raise ProofError("InvalidProof: stacked_n_vars exceeds WHIR domain bound") cfg = WHIR_CONFIGS[(log_inv_rate, stacked_n_vars)] nood = cfg["commitment_ood_samples"] - parsed_commitment = ParsedCommitment.read(state, stacked_n_vars, nood) + parsed_commitment = WhirCommitment.read(state, stacked_n_vars, nood) logup_gamma = state.sample_ef() # the quotient denominator state.duplex() @@ -1100,7 +1098,7 @@ def verify_execution( committed, ending_pc, ) - whir_verify(state, cfg, parsed_commitment, global_statements) + verify_whir(state, cfg, parsed_commitment, global_statements) if state.offset != len(state.transcript): raise ProofError( From 8f8fcc09a63cd6c976e7a50d31c0c5f6e04d7949 Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Wed, 3 Jun 2026 15:28:06 +0400 Subject: [PATCH 16/31] wip --- crates/lean_prover/python-verifier/verifier.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/crates/lean_prover/python-verifier/verifier.py b/crates/lean_prover/python-verifier/verifier.py index d1c785be..d5fa7ada 100644 --- a/crates/lean_prover/python-verifier/verifier.py +++ b/crates/lean_prover/python-verifier/verifier.py @@ -419,7 +419,7 @@ def verify_whir( assert num_vars_after_1_round >= WHIR_MAX_NUM_VARIABLES_TO_SEND_COEFFS n_rounds = div_ceil(num_vars_after_1_round - WHIR_MAX_NUM_VARIABLES_TO_SEND_COEFFS, WHIR_SUBSEQUENT_FOLDING_FACTOR) round_constraints: list[tuple[EF, list[SparseStatements]]] = [] - round_folding_challenges: list[list[EF]] = [] + folding_challenges: list[EF] = [] log_domain = current_vars + cfg["log_inv_rate"] target = ZERO constraints = commitment.oods_constraints() + statements @@ -436,7 +436,7 @@ def verify_whir( gamma_power *= gamma round_constraints.append((gamma, constraints)) sc_point, target = verify_sumcheck(fiat_shamir, target, folding_factor, 2, fold_pow_bits) - round_folding_challenges.append(sc_point) + folding_challenges += sc_point current_vars -= folding_factor is_final = round == n_rounds if is_final: @@ -451,7 +451,7 @@ def verify_whir( round_params["num_queries"], round_params["query_pow_bits"], commitment, - round_folding_challenges[-1], + folding_challenges[-folding_factor:], ) if is_final: final_stir_constraints = stir_constraints @@ -466,10 +466,9 @@ def verify_whir( raise ProofError("Final STIR constraint mismatch") final_sc_point, final_sc_value = verify_sumcheck(fiat_shamir, target, current_vars, 2) - round_folding_challenges.append(final_sc_point) + folding_challenges += final_sc_point eval_weights = ZERO - folding_challenges = [r for chunk in round_folding_challenges for r in chunk] for round, (gamma, smts) in enumerate(round_constraints): if round > 0: folding_challenges = folding_challenges[whir_folding_factor_at_round(round - 1) :] From 11fae2086a7c491c4b5edd877655546dc1213f80 Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Wed, 3 Jun 2026 15:54:12 +0400 Subject: [PATCH 17/31] wip --- .../lean_prover/python-verifier/verifier.py | 8 +++--- .../lean_prover/tests/check_whir_configs.rs | 27 +++++++++++++------ 2 files changed, 22 insertions(+), 13 deletions(-) diff --git a/crates/lean_prover/python-verifier/verifier.py b/crates/lean_prover/python-verifier/verifier.py index d5fa7ada..b9489f22 100644 --- a/crates/lean_prover/python-verifier/verifier.py +++ b/crates/lean_prover/python-verifier/verifier.py @@ -24,17 +24,16 @@ WHIR_INITIAL_FOLDING_FACTOR, WHIR_SUBSEQUENT_FOLDING_FACTOR, WHIR_MAX_NUM_VARIABLES_TO_SEND_COEFFS = 7, 5, 8 MIN_WHIR_LOG_INV_RATE, MAX_WHIR_LOG_INV_RATE, RS_DOMAIN_INITIAL_REDUCTION_FACTOR = 1, 4, 5 -_WHIR_CONFIGS = ((1,7,1,10,220,16,()),(1,8,1,11,220,16,()),(1,9,1,12,220,16,()),(1,10,1,13,220,16,()),(1,11,1,14,220,16,()),(1,12,1,15,220,16,()),(1,13,1,16,220,16,()),(1,14,1,15,221,16,()),(1,15,1,16,221,16,()),(1,16,1,16,73,16,((222,1,16,11),)),(1,17,1,16,73,16,((223,1,16,12),)),(1,18,1,16,73,16,((224,1,16,13),)),(1,19,1,16,73,16,((225,1,16,14),)),(1,20,1,16,73,16,((227,1,16,15),)),(1,21,2,16,32,16,((229,1,16,16),(73,1,16,9))),(1,22,2,16,32,16,((230,1,16,12),(74,1,16,10))),(1,23,2,16,32,16,((234,1,16,13),(74,1,16,11))),(1,24,2,16,32,16,((235,1,16,14),(74,1,16,12))),(1,25,2,16,32,16,((241,2,16,15),(74,2,16,13))),(1,26,2,16,21,14,((243,2,16,16),(74,2,16,14),(32,2,16,14))),(1,27,2,16,21,14,((248,2,16,15),(75,2,16,15),(32,2,16,15))),(1,28,2,16,21,14,((256,2,16,16),(75,2,16,16),(32,2,16,16))),(1,29,2,16,21,14,((262,2,16,15),(76,2,16,12),(33,2,16,17))),(1,30,2,16,21,14,((270,2,16,16),(76,2,16,13),(33,2,16,18))),(2,7,1,13,109,16,()),(2,8,1,14,109,16,()),(2,9,1,15,109,16,()),(2,10,1,16,109,16,()),(2,11,1,12,110,16,()),(2,12,1,13,110,16,()),(2,13,1,14,110,16,()),(2,14,1,15,110,16,()),(2,15,1,16,110,16,()),(2,16,1,14,55,16,((111,1,16,10),)),(2,17,1,15,55,16,((111,1,16,11),)),(2,18,1,16,55,16,((111,1,16,12),)),(2,19,1,15,55,16,((112,1,16,13),)),(2,20,2,16,55,16,((112,1,16,14),)),(2,21,2,16,28,16,((113,1,16,15),(55,1,16,10))),(2,22,2,15,28,16,((114,1,16,16),(55,1,16,11))),(2,23,2,16,28,16,((114,1,16,13),(56,1,16,12))),(2,24,2,16,28,16,((115,1,16,14),(56,2,16,13))),(2,25,2,15,28,16,((118,2,16,15),(56,2,16,14))),(2,26,2,16,19,15,((118,2,16,16),(56,2,16,15),(28,2,16,17))),(2,27,2,16,19,15,((119,2,16,13),(57,2,16,16),(28,2,16,18))),(2,28,2,16,19,15,((120,2,16,14),(57,2,16,14),(29,2,15,19))),(2,29,2,16,19,15,((123,2,16,15),(57,2,16,15),(29,2,15,20))),(3,7,1,9,73,16,()),(3,8,1,10,73,16,()),(3,9,1,11,73,16,()),(3,10,1,12,73,16,()),(3,11,1,13,73,16,()),(3,12,1,14,73,16,()),(3,13,1,15,73,16,()),(3,14,1,16,73,16,()),(3,15,1,12,74,16,()),(3,16,1,13,44,16,((74,1,16,11),)),(3,17,1,14,44,16,((74,1,16,12),)),(3,18,2,15,44,16,((74,1,16,13),)),(3,19,2,16,44,16,((74,1,16,14),)),(3,20,2,15,44,16,((75,1,16,15),)),(3,21,2,16,25,16,((75,1,16,16),(44,1,16,11))),(3,22,2,15,25,16,((76,1,16,11),(45,1,16,12))),(3,23,2,16,25,16,((76,1,16,12),(45,2,16,13))),(3,24,2,16,25,16,((77,2,16,13),(45,2,16,14))),(3,25,2,16,25,16,((78,2,15,14),(45,2,16,15))),(3,26,2,16,18,12,((79,2,15,15),(45,2,16,16),(25,2,16,19))),(3,27,2,16,18,12,((80,2,16,16),(45,2,16,15),(26,2,13,20))),(3,28,2,15,18,12,((82,2,15,15),(46,2,16,16),(26,2,13,21))),(4,7,1,8,55,16,()),(4,8,1,9,55,16,()),(4,9,1,10,55,16,()),(4,10,1,11,55,16,()),(4,11,1,12,55,16,()),(4,12,1,13,55,16,()),(4,13,1,14,55,16,()),(4,14,1,15,55,16,()),(4,15,1,16,55,16,()),(4,16,1,13,37,16,((56,1,16,9),)),(4,17,1,14,37,16,((56,1,16,10),)),(4,18,2,15,37,16,((56,1,16,11),)),(4,19,2,16,37,16,((56,1,16,12),)),(4,20,2,13,37,16,((57,1,16,13),)),(4,21,2,14,23,15,((57,2,16,14),(37,2,16,12))),(4,22,2,15,23,15,((57,2,16,15),(37,2,16,13))),(4,23,2,16,23,15,((57,2,16,16),(37,2,16,14))),(4,24,2,15,23,15,((58,2,16,13),(38,2,16,15))),(4,25,2,16,23,15,((58,2,16,14),(38,2,16,16))),(4,26,2,16,16,16,((60,2,15,15),(38,2,16,17),(23,2,15,22))),(4,27,2,15,16,16,((61,2,16,16),(38,2,16,18),(23,2,15,23)))) # fmt: skip +_WHIR_CONFIGS = WHIR_CONFIGS = [(1,7,1,10,220,16,[]), (1,8,1,11,220,16,[]), (1,9,1,12,220,16,[]), (1,10,1,13,220,16,[]), (1,11,1,14,220,16,[]), (1,12,1,15,220,16,[]), (1,13,1,16,220,16,[]), (1,14,1,15,221,16,[]), (1,15,1,16,221,16,[]), (1,16,1,11,73,16,[(222,1,16,16), ]), (1,17,1,12,73,16,[(223,1,16,16), ]), (1,18,1,13,73,16,[(224,1,16,16), ]), (1,19,1,14,73,16,[(225,1,16,16), ]), (1,20,1,15,73,16,[(227,1,16,16), ]), (1,21,2,9,32,16,[(229,1,16,16), (73,1,16,16)]), (1,22,2,10,32,16,[(230,1,16,16), (74,1,16,12)]), (1,23,2,11,32,16,[(234,1,16,16), (74,1,16,13)]), (1,24,2,12,32,16,[(235,1,16,16), (74,1,16,14)]), (1,25,2,13,32,16,[(241,2,16,16), (74,2,16,15)]), (1,26,2,14,21,14,[(243,2,16,16), (74,2,16,16), (32,2,16,14)]), (1,27,2,15,21,14,[(248,2,16,16), (75,2,16,15), (32,2,16,15)]), (1,28,2,16,21,14,[(256,2,16,16), (75,2,16,16), (32,2,16,16)]), (1,29,2,17,21,14,[(262,2,16,16), (76,2,16,15), (33,2,16,12)]), (1,30,2,18,21,14,[(270,2,16,16), (76,2,16,16), (33,2,16,13)]), (2,7,1,13,109,16,[]), (2,8,1,14,109,16,[]), (2,9,1,15,109,16,[]), (2,10,1,16,109,16,[]), (2,11,1,12,110,16,[]), (2,12,1,13,110,16,[]), (2,13,1,14,110,16,[]), (2,14,1,15,110,16,[]), (2,15,1,16,110,16,[]), (2,16,1,10,55,16,[(111,1,16,14), ]), (2,17,1,11,55,16,[(111,1,16,15), ]), (2,18,1,12,55,16,[(111,1,16,16), ]), (2,19,1,13,55,16,[(112,1,16,15), ]), (2,20,2,14,55,16,[(112,1,16,16), ]), (2,21,2,10,28,16,[(113,1,16,16), (55,1,16,15)]), (2,22,2,11,28,16,[(114,1,16,15), (55,1,16,16)]), (2,23,2,12,28,16,[(114,1,16,16), (56,1,16,13)]), (2,24,2,13,28,16,[(115,1,16,16), (56,2,16,14)]), (2,25,2,14,28,16,[(118,2,16,15), (56,2,16,15)]), (2,26,2,17,19,15,[(118,2,16,16), (56,2,16,16), (28,2,16,15)]), (2,27,2,18,19,15,[(119,2,16,16), (57,2,16,13), (28,2,16,16)]), (2,28,2,19,19,15,[(120,2,16,16), (57,2,16,14), (29,2,15,14)]), (2,29,2,20,19,15,[(123,2,16,16), (57,2,16,15), (29,2,15,15)]), (3,7,1,9,73,16,[]), (3,8,1,10,73,16,[]), (3,9,1,11,73,16,[]), (3,10,1,12,73,16,[]), (3,11,1,13,73,16,[]), (3,12,1,14,73,16,[]), (3,13,1,15,73,16,[]), (3,14,1,16,73,16,[]), (3,15,1,12,74,16,[]), (3,16,1,11,44,16,[(74,1,16,13), ]), (3,17,1,12,44,16,[(74,1,16,14), ]), (3,18,2,13,44,16,[(74,1,16,15), ]), (3,19,2,14,44,16,[(74,1,16,16), ]), (3,20,2,15,44,16,[(75,1,16,15), ]), (3,21,2,11,25,16,[(75,1,16,16), (44,1,16,16)]), (3,22,2,12,25,16,[(76,1,16,15), (45,1,16,11)]), (3,23,2,13,25,16,[(76,1,16,16), (45,2,16,12)]), (3,24,2,14,25,16,[(77,2,16,16), (45,2,16,13)]), (3,25,2,15,25,16,[(78,2,15,16), (45,2,16,14)]), (3,26,2,19,18,12,[(79,2,15,16), (45,2,16,15), (25,2,16,16)]), (3,27,2,20,18,12,[(80,2,16,16), (45,2,16,16), (26,2,13,15)]), (3,28,2,21,18,12,[(82,2,15,15), (46,2,16,15), (26,2,13,16)]), (4,7,1,8,55,16,[]), (4,8,1,9,55,16,[]), (4,9,1,10,55,16,[]), (4,10,1,11,55,16,[]), (4,11,1,12,55,16,[]), (4,12,1,13,55,16,[]), (4,13,1,14,55,16,[]), (4,14,1,15,55,16,[]), (4,15,1,16,55,16,[]), (4,16,1,9,37,16,[(56,1,16,13), ]), (4,17,1,10,37,16,[(56,1,16,14), ]), (4,18,2,11,37,16,[(56,1,16,15), ]), (4,19,2,12,37,16,[(56,1,16,16), ]), (4,20,2,13,37,16,[(57,1,16,13), ]), (4,21,2,12,23,15,[(57,2,16,14), (37,2,16,14)]), (4,22,2,13,23,15,[(57,2,16,15), (37,2,16,15)]), (4,23,2,14,23,15,[(57,2,16,16), (37,2,16,16)]), (4,24,2,15,23,15,[(58,2,16,15), (38,2,16,13)]), (4,25,2,16,23,15,[(58,2,16,16), (38,2,16,14)]), (4,26,2,22,16,16,[(60,2,15,16), (38,2,16,15), (23,2,15,17)]), (4,27,2,23,16,16,[(61,2,16,15), (38,2,16,16), (23,2,15,18)])] # fmt: skip WHIR_CONFIGS = { (c[0], c[1]): { "log_inv_rate": c[0], "num_variables": c[1], "commitment_ood_samples": c[2], - "starting_folding_pow_bits": c[3], "rounds": [ {"num_queries": r[0], "ood_samples": r[1], "query_pow_bits": r[2], "folding_pow_bits": r[3]} for r in c[6] ] - + [{"num_queries": c[4], "query_pow_bits": c[5]}], + + [{"num_queries": c[4], "query_pow_bits": c[5], "folding_pow_bits": c[3]}], } for c in _WHIR_CONFIGS } @@ -423,9 +422,9 @@ def verify_whir( log_domain = current_vars + cfg["log_inv_rate"] target = ZERO constraints = commitment.oods_constraints() + statements - fold_pow_bits = cfg["starting_folding_pow_bits"] for round in range(n_rounds + 1): round_params = cfg["rounds"][round] + fold_pow_bits = round_params["folding_pow_bits"] folding_factor = whir_folding_factor_at_round(round) fiat_shamir.duplex() gamma = fiat_shamir.sample_ef() @@ -457,7 +456,6 @@ def verify_whir( final_stir_constraints = stir_constraints break constraints = new_commitment.oods_constraints() + stir_constraints - fold_pow_bits = round_params["folding_pow_bits"] log_domain -= RS_DOMAIN_INITIAL_REDUCTION_FACTOR if round == 0 else 1 commitment = new_commitment for smt in final_stir_constraints: diff --git a/crates/lean_prover/tests/check_whir_configs.rs b/crates/lean_prover/tests/check_whir_configs.rs index 60266ef2..530e7a10 100644 --- a/crates/lean_prover/tests/check_whir_configs.rs +++ b/crates/lean_prover/tests/check_whir_configs.rs @@ -17,29 +17,40 @@ fn expected_whir_configs_line() -> String { for num_variables in first_ff..=max_nv { let cfg: WhirConfig = WhirConfig::new(&builder, num_variables); - let mut rounds = String::from("("); - for (i, r) in cfg.round_parameters.iter().enumerate() { + let mut rounds = String::from("["); + for i in 0..cfg.n_rounds() { + let r = &cfg.round_parameters[i]; + let folding_pow_bits = if i == 0 { + cfg.starting_folding_pow_bits + } else { + cfg.round_parameters[i - 1].folding_pow_bits + }; if i > 0 { - rounds.push(','); + rounds += ", "; } write!( rounds, "({},{},{},{})", - r.num_queries, r.ood_samples, r.query_pow_bits, r.folding_pow_bits + r.num_queries, r.ood_samples, r.query_pow_bits, folding_pow_bits ) .unwrap(); } if cfg.round_parameters.len() == 1 { - rounds.push(','); + rounds += ", "; } - rounds.push(')'); + rounds.push(']'); + let final_pow_bits = cfg + .round_parameters + .last() + .map(|r| r.folding_pow_bits) + .unwrap_or(cfg.starting_folding_pow_bits); entries.push(format!( "({},{},{},{},{},{},{})", log_inv_rate, num_variables, cfg.commitment_ood_samples, - cfg.starting_folding_pow_bits, + final_pow_bits, cfg.final_queries, cfg.final_query_pow_bits, rounds, @@ -47,7 +58,7 @@ fn expected_whir_configs_line() -> String { } } - format!("WHIR_CONFIGS = ({})", entries.join(",")) + format!("WHIR_CONFIGS = [{}]", entries.join(", ")) } fn strip_ws(s: &str) -> String { From cfdb1ce9237d69f1ce8ea47d1c50b4a335d17988 Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Wed, 3 Jun 2026 16:57:18 +0400 Subject: [PATCH 18/31] wip --- .../lean_prover/python-verifier/verifier.py | 68 +++++++------------ 1 file changed, 24 insertions(+), 44 deletions(-) diff --git a/crates/lean_prover/python-verifier/verifier.py b/crates/lean_prover/python-verifier/verifier.py index b9489f22..6dac65b0 100644 --- a/crates/lean_prover/python-verifier/verifier.py +++ b/crates/lean_prover/python-verifier/verifier.py @@ -273,7 +273,7 @@ def next_mle(x: Sequence[EF], y: Sequence[EF]) -> EF: return s + math.prod([*x, *y]) -def eval_multilinear_evals(evals: Sequence[Fp | EF], point: Sequence[EF]) -> EF: +def eval_multilinear_by_evals(evals: Sequence[Fp | EF], point: Sequence[EF]) -> EF: """Evaluate a multilinear in evaluation form at `point`.""" assert len(evals) == 1 << len(point) cur: Sequence = evals @@ -282,14 +282,14 @@ def eval_multilinear_evals(evals: Sequence[Fp | EF], point: Sequence[EF]) -> EF: return cur[0] -def eval_multilinear_coeffs(coeffs: Sequence[EF], point: Sequence[EF]) -> EF: +def eval_multilinear_by_coeffs(coeffs: Sequence[EF], point: Sequence[EF]) -> EF: """Evaluate a multilinear in coefficient form at `point`.""" assert len(coeffs) == 1 << len(point) if not point: return coeffs[0] half = len(coeffs) // 2 - lo = eval_multilinear_coeffs(coeffs[:half], point[1:]) - hi = eval_multilinear_coeffs(coeffs[half:], point[1:]) + lo = eval_multilinear_by_coeffs(coeffs[:half], point[1:]) + hi = eval_multilinear_by_coeffs(coeffs[half:], point[1:]) return lo + hi * point[0] @@ -382,31 +382,6 @@ def verify_sumcheck( return point, target -def verify_stir_challenges( - fiat_shamir: FiatShamir, - is_first_round: int, - log_height: int, - num_variables: int, - num_queries: int, - query_pow_bits: int, - commitment: WhirCommitment, - folding_randomness: list[EF], -) -> list[SparseStatements]: - gen = Fp(KB_TWO_ADIC_GENERATORS[log_height]) - fiat_shamir.check_pow_grinding(query_pow_bits) - indices = fiat_shamir.sample_in_range(log_height, num_queries) - constraints: list[SparseStatements] = [] - for idx in indices: - op = fiat_shamir.next_merkle_opening() - merkle_verify_path(commitment.root, log_height, idx, op.leaf_data, op.path) - # Round 0 leaves are raw base-field elements; later rounds embed DIM Fp values per EF element. - packed = op.leaf_data if is_first_round else embed_ef(op.leaf_data) - fold = eval_multilinear_evals(packed, folding_randomness) - point = expand_from_univariate(EF(pow(int(gen.value), idx, P)), num_variables) - constraints.append(SparseStatements(num_variables, point, [(0, fold)])) - return constraints - - def verify_whir( fiat_shamir: FiatShamir, cfg: dict, @@ -442,16 +417,21 @@ def verify_whir( final_coeffs = fiat_shamir.next_extension_scalars_vec(1 << current_vars) else: new_commitment = WhirCommitment.read(fiat_shamir, current_vars, round_params["ood_samples"]) - stir_constraints = verify_stir_challenges( - fiat_shamir, - round == 0, - log_domain - folding_factor, - current_vars, - round_params["num_queries"], - round_params["query_pow_bits"], - commitment, - folding_challenges[-folding_factor:], - ) + + log_height = log_domain - folding_factor + gen = Fp(KB_TWO_ADIC_GENERATORS[log_height]) + fiat_shamir.check_pow_grinding(round_params["query_pow_bits"]) + indices = fiat_shamir.sample_in_range(log_height, round_params["num_queries"]) + stir_constraints: list[SparseStatements] = [] + for idx in indices: + op = fiat_shamir.next_merkle_opening() + merkle_verify_path(commitment.root, log_height, idx, op.leaf_data, op.path) + # Round 0 leaves are raw base-field elements; later rounds embed DIM Fp values per EF element. + packed = op.leaf_data if round == 0 else embed_ef(op.leaf_data) + fold = eval_multilinear_by_evals(packed, folding_challenges[-folding_factor:]) + point = expand_from_univariate(EF(pow(int(gen.value), idx, P)), current_vars) + stir_constraints.append(SparseStatements(current_vars, point, [(0, fold)])) + if is_final: final_stir_constraints = stir_constraints break @@ -479,7 +459,7 @@ def verify_whir( eval_prefix = eq_at_index(folding_challenges, v[0], sel_n) # sparse part of the point eval_weights += eval_prefix * eval_suffix * gamma_power gamma_power *= gamma - final_value = eval_multilinear_coeffs(final_coeffs, list(reversed(final_sc_point))) + final_value = eval_multilinear_by_coeffs(final_coeffs, list(reversed(final_sc_point))) if final_sc_value != eval_weights * final_value: raise ProofError("WHIR final sumcheck check failed") @@ -527,8 +507,8 @@ def verify_gkr_quotient(fiat_shamir: FiatShamir, n_vars: int) -> tuple[EF, list[ quotient = sum(n * d.inv() for n, d in zip(nums, dens)) point = fiat_shamir.sample_many_ef(N_VARS_TO_SEND_GKR_COEFFS) - claim_num = eval_multilinear_evals(nums, point) - claim_den = eval_multilinear_evals(dens, point) + claim_num = eval_multilinear_by_evals(nums, point) + claim_den = eval_multilinear_by_evals(dens, point) for layer_n_vars in range(N_VARS_TO_SEND_GKR_COEFFS, n_vars): fiat_shamir.duplex() @@ -609,7 +589,7 @@ def pref_at(offset: int, log_height: int) -> EF: pref = pref_at(offset, log_bytecode) pref_pad = pref_at(offset, log_byte_pad) value_bytecode_acc = fiat_shamir.next_extension_scalar() - bytecode_value = eval_multilinear_evals([Fp(v) for v in bytecode_multilinear], byte_pt + beta[-log_instr:]) + bytecode_value = eval_multilinear_by_evals([Fp(v) for v in bytecode_multilinear], byte_pt + beta[-log_instr:]) correction = math.prod(ONE - a for a in beta[: len(beta) - log_instr]) fp_byte = ( bytecode_value * correction @@ -1071,7 +1051,7 @@ def verify_execution( raise ProofError("AIR sumcheck: claimed value mismatch") pm_point = state.sample_many_ef(log2_strict(PUBLIC_INPUT_SIZE)) - pm_eval = eval_multilinear_evals(public_input, pm_point) + pm_eval = eval_multilinear_by_evals(public_input, pm_point) bytecode_acc_idx = (2 << log_memory) >> bytecode_log_size previous_statements = [ From ed2941b0fa5115e5929c1b584eaf7ef4554d3f82 Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Wed, 3 Jun 2026 17:04:05 +0400 Subject: [PATCH 19/31] wip --- crates/lean_prover/python-verifier/verifier.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/crates/lean_prover/python-verifier/verifier.py b/crates/lean_prover/python-verifier/verifier.py index 6dac65b0..8987a24e 100644 --- a/crates/lean_prover/python-verifier/verifier.py +++ b/crates/lean_prover/python-verifier/verifier.py @@ -427,10 +427,10 @@ def verify_whir( op = fiat_shamir.next_merkle_opening() merkle_verify_path(commitment.root, log_height, idx, op.leaf_data, op.path) # Round 0 leaves are raw base-field elements; later rounds embed DIM Fp values per EF element. - packed = op.leaf_data if round == 0 else embed_ef(op.leaf_data) - fold = eval_multilinear_by_evals(packed, folding_challenges[-folding_factor:]) + leaf = op.leaf_data if round == 0 else embed_ef(op.leaf_data) + leaf_eval = eval_multilinear_by_evals(leaf, folding_challenges[-folding_factor:]) point = expand_from_univariate(EF(pow(int(gen.value), idx, P)), current_vars) - stir_constraints.append(SparseStatements(current_vars, point, [(0, fold)])) + stir_constraints.append(SparseStatements(current_vars, point, [(0, leaf_eval)])) if is_final: final_stir_constraints = stir_constraints @@ -481,7 +481,7 @@ def stacked_pcs_global_statements( table_offsets[table.name] = layout_offset layout_offset += table.n_columns << n_vars - out = list(previous_statements) + res = list(previous_statements) def values_at(d: dict[int, EF], col_base: int) -> list[tuple[int, EF]]: return [(col_base + i, v) for i, v in sorted(d.items())] @@ -490,13 +490,13 @@ def values_at(d: dict[int, EF], col_base: int) -> list[tuple[int, EF]]: n_vars = heights[table.name] offset = table_offsets[table.name] col_base = offset >> n_vars - out.extend(table.boundary_statements(stacked_n_vars, offset, n_vars, ending_pc)) + res.extend(table.boundary_statements(stacked_n_vars, offset, n_vars, ending_pc)) for point, eq_values, next_values in committed_statements[table.name]: if next_values: - out.append(SparseStatements(stacked_n_vars, list(point), values_at(next_values, col_base), True)) - out.append(SparseStatements(stacked_n_vars, list(point), values_at(eq_values, col_base))) + res.append(SparseStatements(stacked_n_vars, point, values_at(next_values, col_base), True)) + res.append(SparseStatements(stacked_n_vars, point, values_at(eq_values, col_base))) - return out + return res def verify_gkr_quotient(fiat_shamir: FiatShamir, n_vars: int) -> tuple[EF, list[EF], EF, EF]: From 4bef008bd0f51b3beefa82dd298210c5e3c74139 Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Wed, 3 Jun 2026 17:22:36 +0400 Subject: [PATCH 20/31] wip --- .../lean_prover/python-verifier/verifier.py | 76 +++++++++---------- 1 file changed, 34 insertions(+), 42 deletions(-) diff --git a/crates/lean_prover/python-verifier/verifier.py b/crates/lean_prover/python-verifier/verifier.py index 8987a24e..b0142b68 100644 --- a/crates/lean_prover/python-verifier/verifier.py +++ b/crates/lean_prover/python-verifier/verifier.py @@ -513,8 +513,8 @@ def verify_gkr_quotient(fiat_shamir: FiatShamir, n_vars: int) -> tuple[EF, list[ for layer_n_vars in range(N_VARS_TO_SEND_GKR_COEFFS, n_vars): fiat_shamir.duplex() alpha = fiat_shamir.sample_ef() - raw_pt, sc_value = verify_sumcheck(fiat_shamir, claim_num + alpha * claim_den, layer_n_vars, 3) - sc_point = list(reversed(raw_pt)) + sc_point, sc_value = verify_sumcheck(fiat_shamir, claim_num + alpha * claim_den, layer_n_vars, 3) + sc_point = list(reversed(sc_point)) nl, nr, dl, dr = fiat_shamir.next_extension_scalars_vec(4) if sc_value != eq_poly(point, sc_point) * (alpha * dl * dr + nl * dr + nr * dl): raise ProofError("GKR step: postponed value mismatch") @@ -540,19 +540,17 @@ def sort_tables_by_height(tables: Sequence[Table], heights: dict[str, int]) -> l def verify_generic_logup( fiat_shamir: FiatShamir, gamma: EF, # quotient denominator challenge - beta: list[EF], # bus-tuple hashing seeds + beta: list[EF], # bus-tuple hashing seed beta_eq: list[EF], # eq(beta, ·) evaluation table log_memory: int, bytecode_multilinear: list[int], tables: Sequence[Table], - heights: dict[str, int], + table_heights: dict[str, int], ) -> dict: - ds_mem = Fp(LOGUP_MEMORY_DOMAINSEP) - ds_byte = Fp(LOGUP_BYTECODE_DOMAINSEP) log_instr = log2_ceil(N_INSTRUCTION_COLUMNS) log_bytecode = log2_strict(len(bytecode_multilinear)) - log_instr - tables_sorted = sort_tables_by_height(tables, heights) + tables_sorted = sort_tables_by_height(tables, table_heights) tallest_h = tables_sorted[0][1] total_active_len = ( @@ -560,15 +558,14 @@ def verify_generic_logup( + max(1 << log_bytecode, 1 << tallest_h) + sum(t.n_bus_interactions << h for t, h in tables_sorted) ) - total_gkr_n_vars = log2_ceil(total_active_len) + logup_n_vars = log2_ceil(total_active_len) - quotient, point_gkr, claim_num, claim_den = verify_gkr_quotient(fiat_shamir, total_gkr_n_vars) + quotient, point_gkr, claim_num, claim_den = verify_gkr_quotient(fiat_shamir, logup_n_vars) if quotient != ZERO: - raise ProofError("logup: GKR sum != 0") + raise ProofError("imbalanced logup bus") def pref_at(offset: int, log_height: int) -> EF: - """Lagrange weight for the layout-offset of a section of height 2^log_height.""" - n_missing = total_gkr_n_vars - log_height + n_missing = logup_n_vars - log_height return eq_at_index(point_gkr, offset >> log_height, n_missing) num = den = ZERO @@ -578,27 +575,26 @@ def pref_at(offset: int, log_height: int) -> EF: pref = pref_at(0, log_memory) value_memory_acc = fiat_shamir.next_extension_scalar() value_memory = fiat_shamir.next_extension_scalar() - fp_mem = finger_print(ds_mem, [mle_of_01234567_etc(mem_pt), value_memory], beta_eq) num -= pref * value_memory_acc - den += pref * (gamma - fp_mem) + den += pref * (gamma - finger_print(Fp(LOGUP_MEMORY_DOMAINSEP), [mle_of_01234567_etc(mem_pt), value_memory], beta_eq)) offset = 1 << log_memory # Bytecode section (padded to the tallest table) - log_byte_pad = max(log_bytecode, tallest_h) - byte_pt = point_gkr[-log_bytecode:] + log_bytecode_padded = max(log_bytecode, tallest_h) + bytecode_pt = point_gkr[-log_bytecode:] pref = pref_at(offset, log_bytecode) - pref_pad = pref_at(offset, log_byte_pad) + pref_padded = pref_at(offset, log_bytecode_padded) value_bytecode_acc = fiat_shamir.next_extension_scalar() - bytecode_value = eval_multilinear_by_evals([Fp(v) for v in bytecode_multilinear], byte_pt + beta[-log_instr:]) + value_bytecode = eval_multilinear_by_evals([Fp(v) for v in bytecode_multilinear], bytecode_pt + beta[-log_instr:]) correction = math.prod(ONE - a for a in beta[: len(beta) - log_instr]) - fp_byte = ( - bytecode_value * correction - + mle_of_01234567_etc(byte_pt) * beta_eq[N_INSTRUCTION_COLUMNS] - + beta_eq[-1] * ds_byte + fingerprint_bytecode = ( + value_bytecode * correction + + mle_of_01234567_etc(bytecode_pt) * beta_eq[N_INSTRUCTION_COLUMNS] + + beta_eq[-1] * Fp(LOGUP_BYTECODE_DOMAINSEP) ) num -= pref * value_bytecode_acc - den += pref * (gamma - fp_byte) + pref_pad * mle_of_zeros_then_ones(1 << log_bytecode, point_gkr[-log_byte_pad:]) - offset += 1 << log_byte_pad + den += pref * (gamma - fingerprint_bytecode) + pref_padded * mle_of_zeros_then_ones(1 << log_bytecode, point_gkr[-log_bytecode_padded:]) + offset += 1 << log_bytecode_padded # Per-table section table_offsets: dict[str, int] = {} @@ -612,37 +608,33 @@ def pref_at(offset: int, log_height: int) -> EF: columns_values: dict[str, dict[int, EF]] = {} for table in tables: - name = table.name - log_n_rows = heights[name] - row_stride = 1 << log_n_rows - offset = table_offsets[name] - vals: dict[int, EF] = {} + offset = table_offsets[table.name] + column_values: dict[int, EF] = {} def read(cols: Sequence[int]) -> list[EF]: - """Evals of `cols`, batch-reading any not-yet-seen column from the transcript, in order.""" - missing = [c for c in cols if c not in vals] + missing = [c for c in cols if c not in column_values] for c, e in zip(missing, fiat_shamir.next_extension_scalars_vec(len(missing))): - vals[c] = e - return [vals[c] for c in cols] + column_values[c] = e + return [column_values[c] for c in cols] for bus in table.buses: if not bus.cols: - pref = pref_at(offset, log_n_rows) - bus_num_vals[name] = fiat_shamir.next_extension_scalar() - bus_den_vals[name] = fiat_shamir.next_extension_scalar() - num += pref * bus_num_vals[name] - den += pref * bus_den_vals[name] - offset += row_stride + pref = pref_at(offset, table_heights[table.name]) + bus_num_vals[table.name] = fiat_shamir.next_extension_scalar() + bus_den_vals[table.name] = fiat_shamir.next_extension_scalar() + num += pref * bus_num_vals[table.name] + den += pref * bus_den_vals[table.name] + offset += 1 << table_heights[table.name] continue sep, base = Fp(bus.domain_sep), [table.col(c) for c in bus.cols] # memory / bytecode for i in range(bus.n_terms): # term i: σ = (m[base[0]] + i, m[base[1:] + i]) - pref = pref_at(offset, log_n_rows) + pref = pref_at(offset, table_heights[table.name]) d = read([base[0], *(c + i for c in base[1:])]) num += pref den += pref * (gamma - finger_print(sep, [d[0] + i, *d[1:]], beta_eq)) - offset += row_stride + offset += 1 << table_heights[table.name] - columns_values[name] = vals + columns_values[table.name] = column_values den += mle_of_zeros_then_ones(final_offset, point_gkr) if num != claim_num: From cbe4ddfe3c98fa40cb154654594c32547ee75adc Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Wed, 3 Jun 2026 17:44:18 +0400 Subject: [PATCH 21/31] wip --- .../lean_prover/python-verifier/verifier.py | 88 +++++++++---------- 1 file changed, 44 insertions(+), 44 deletions(-) diff --git a/crates/lean_prover/python-verifier/verifier.py b/crates/lean_prover/python-verifier/verifier.py index b0142b68..bc0ed014 100644 --- a/crates/lean_prover/python-verifier/verifier.py +++ b/crates/lean_prover/python-verifier/verifier.py @@ -560,40 +560,40 @@ def verify_generic_logup( ) logup_n_vars = log2_ceil(total_active_len) - quotient, point_gkr, claim_num, claim_den = verify_gkr_quotient(fiat_shamir, logup_n_vars) + quotient, gkr_point, claim_num, claim_den = verify_gkr_quotient(fiat_shamir, logup_n_vars) if quotient != ZERO: raise ProofError("imbalanced logup bus") def pref_at(offset: int, log_height: int) -> EF: n_missing = logup_n_vars - log_height - return eq_at_index(point_gkr, offset >> log_height, n_missing) + return eq_at_index(gkr_point, offset >> log_height, n_missing) num = den = ZERO # Memory section - mem_pt = point_gkr[-log_memory:] + mem_pt = gkr_point[-log_memory:] pref = pref_at(0, log_memory) - value_memory_acc = fiat_shamir.next_extension_scalar() - value_memory = fiat_shamir.next_extension_scalar() - num -= pref * value_memory_acc - den += pref * (gamma - finger_print(Fp(LOGUP_MEMORY_DOMAINSEP), [mle_of_01234567_etc(mem_pt), value_memory], beta_eq)) + memory_acc_eval = fiat_shamir.next_extension_scalar() + memory_eval = fiat_shamir.next_extension_scalar() + num -= pref * memory_acc_eval + den += pref * (gamma - finger_print(Fp(LOGUP_MEMORY_DOMAINSEP), [mle_of_01234567_etc(mem_pt), memory_eval], beta_eq)) offset = 1 << log_memory # Bytecode section (padded to the tallest table) log_bytecode_padded = max(log_bytecode, tallest_h) - bytecode_pt = point_gkr[-log_bytecode:] + bytecode_point = gkr_point[-log_bytecode:] pref = pref_at(offset, log_bytecode) pref_padded = pref_at(offset, log_bytecode_padded) value_bytecode_acc = fiat_shamir.next_extension_scalar() - value_bytecode = eval_multilinear_by_evals([Fp(v) for v in bytecode_multilinear], bytecode_pt + beta[-log_instr:]) + bytecode_eval = eval_multilinear_by_evals([Fp(v) for v in bytecode_multilinear], bytecode_point + beta[-log_instr:]) correction = math.prod(ONE - a for a in beta[: len(beta) - log_instr]) fingerprint_bytecode = ( - value_bytecode * correction - + mle_of_01234567_etc(bytecode_pt) * beta_eq[N_INSTRUCTION_COLUMNS] + bytecode_eval * correction + + mle_of_01234567_etc(bytecode_point) * beta_eq[N_INSTRUCTION_COLUMNS] + beta_eq[-1] * Fp(LOGUP_BYTECODE_DOMAINSEP) ) num -= pref * value_bytecode_acc - den += pref * (gamma - fingerprint_bytecode) + pref_padded * mle_of_zeros_then_ones(1 << log_bytecode, point_gkr[-log_bytecode_padded:]) + den += pref * (gamma - fingerprint_bytecode) + pref_padded * mle_of_zeros_then_ones(1 << log_bytecode, gkr_point[-log_bytecode_padded:]) offset += 1 << log_bytecode_padded # Per-table section @@ -603,49 +603,49 @@ def pref_at(offset: int, log_height: int) -> EF: offset += table.n_bus_interactions << log_n_rows final_offset = offset - bus_num_vals: dict[str, EF] = {} - bus_den_vals: dict[str, EF] = {} - columns_values: dict[str, dict[int, EF]] = {} + precompile_nums: dict[str, EF] = {} + precompile_dens: dict[str, EF] = {} + columns_evals: dict[str, dict[int, EF]] = {} for table in tables: offset = table_offsets[table.name] - column_values: dict[int, EF] = {} + columns_evals[table.name] = {} - def read(cols: Sequence[int]) -> list[EF]: - missing = [c for c in cols if c not in column_values] + def request_column_evals_dedup(cols: Sequence[int]) -> list[EF]: + missing = [c for c in cols if c not in columns_evals[table.name]] for c, e in zip(missing, fiat_shamir.next_extension_scalars_vec(len(missing))): - column_values[c] = e - return [column_values[c] for c in cols] + columns_evals[table.name][c] = e + return [columns_evals[table.name][c] for c in cols] for bus in table.buses: - if not bus.cols: + if bus.cols: + # memory / bytecode interraction + base = [table.col(c) for c in bus.cols] + for i in range(bus.n_terms): # term i: σ = (m[base[0]] + i, m[base[1:] + i]) + pref = pref_at(offset, table_heights[table.name]) + d = request_column_evals_dedup([base[0], *(c + i for c in base[1:])]) + num += pref # always multiplicity 1 + den += pref * (gamma - finger_print(Fp(bus.domain_sep), [d[0] + i, *d[1:]], beta_eq)) + offset += 1 << table_heights[table.name] + else: + # precompile interraction pref = pref_at(offset, table_heights[table.name]) - bus_num_vals[table.name] = fiat_shamir.next_extension_scalar() - bus_den_vals[table.name] = fiat_shamir.next_extension_scalar() - num += pref * bus_num_vals[table.name] - den += pref * bus_den_vals[table.name] + precompile_nums[table.name] = fiat_shamir.next_extension_scalar() + precompile_dens[table.name] = fiat_shamir.next_extension_scalar() + num += pref * precompile_nums[table.name] + den += pref * precompile_dens[table.name] offset += 1 << table_heights[table.name] - continue - sep, base = Fp(bus.domain_sep), [table.col(c) for c in bus.cols] # memory / bytecode - for i in range(bus.n_terms): # term i: σ = (m[base[0]] + i, m[base[1:] + i]) - pref = pref_at(offset, table_heights[table.name]) - d = read([base[0], *(c + i for c in base[1:])]) - num += pref - den += pref * (gamma - finger_print(sep, [d[0] + i, *d[1:]], beta_eq)) - offset += 1 << table_heights[table.name] - - columns_values[table.name] = column_values - - den += mle_of_zeros_then_ones(final_offset, point_gkr) + + den += mle_of_zeros_then_ones(final_offset, gkr_point) if num != claim_num: raise ProofError("logup: numerators value mismatch") if den != claim_den: raise ProofError("logup: denominators value mismatch") return { - "value_memory": value_memory, "value_memory_acc": value_memory_acc, - "value_bytecode_acc": value_bytecode_acc, "bus_num": bus_num_vals, "bus_den": bus_den_vals, - "gkr_point": point_gkr, "columns_values": columns_values, + "memory_eval": memory_eval, "memory_acc_eval": memory_acc_eval, + "value_bytecode_acc": value_bytecode_acc, "precompile_nums": precompile_nums, "precompile_dens": precompile_dens, + "gkr_point": gkr_point, "columns_evals": columns_evals, } # fmt: skip @@ -1018,12 +1018,12 @@ def verify_execution( initial_sum, offset = ZERO, 0 for table in TABLES: - initial_sum += alpha_powers[offset] * (logup["bus_num"][table.name] * table.precompile_bus_interaction_sign) - initial_sum += alpha_powers[offset + 1] * (logup_gamma - logup["bus_den"][table.name]) + initial_sum += alpha_powers[offset] * (logup["precompile_nums"][table.name] * table.precompile_bus_interaction_sign) + initial_sum += alpha_powers[offset + 1] * (logup_gamma - logup["precompile_dens"][table.name]) offset += table.n_constraints sc_point, sc_value = verify_sumcheck(state, initial_sum, n_max, max(t.air_degree + 1 for t in TABLES)) - committed = {t.name: [(gkr_point[-log_heights[t.name] :], logup["columns_values"][t.name], {})] for t in TABLES} + committed = {t.name: [(gkr_point[-log_heights[t.name] :], logup["columns_evals"][t.name], {})] for t in TABLES} my_air_final, offset = ZERO, 0 for table in TABLES: log_n_rows = log_heights[table.name] @@ -1050,7 +1050,7 @@ def verify_execution( SparseStatements( stacked_n_vars, gkr_point[-log_memory:], - [(0, logup["value_memory"]), (1, logup["value_memory_acc"])], + [(0, logup["memory_eval"]), (1, logup["memory_acc_eval"])], ), SparseStatements(stacked_n_vars, pm_point, [(0, pm_eval)]), SparseStatements( From d508238087fee7d67de13b20f677bb5c400c7c02 Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Wed, 3 Jun 2026 18:35:43 +0400 Subject: [PATCH 22/31] wip --- .../lean_prover/python-verifier/verifier.py | 140 +++++++++--------- 1 file changed, 71 insertions(+), 69 deletions(-) diff --git a/crates/lean_prover/python-verifier/verifier.py b/crates/lean_prover/python-verifier/verifier.py index bc0ed014..3f897c00 100644 --- a/crates/lean_prover/python-verifier/verifier.py +++ b/crates/lean_prover/python-verifier/verifier.py @@ -24,7 +24,7 @@ WHIR_INITIAL_FOLDING_FACTOR, WHIR_SUBSEQUENT_FOLDING_FACTOR, WHIR_MAX_NUM_VARIABLES_TO_SEND_COEFFS = 7, 5, 8 MIN_WHIR_LOG_INV_RATE, MAX_WHIR_LOG_INV_RATE, RS_DOMAIN_INITIAL_REDUCTION_FACTOR = 1, 4, 5 -_WHIR_CONFIGS = WHIR_CONFIGS = [(1,7,1,10,220,16,[]), (1,8,1,11,220,16,[]), (1,9,1,12,220,16,[]), (1,10,1,13,220,16,[]), (1,11,1,14,220,16,[]), (1,12,1,15,220,16,[]), (1,13,1,16,220,16,[]), (1,14,1,15,221,16,[]), (1,15,1,16,221,16,[]), (1,16,1,11,73,16,[(222,1,16,16), ]), (1,17,1,12,73,16,[(223,1,16,16), ]), (1,18,1,13,73,16,[(224,1,16,16), ]), (1,19,1,14,73,16,[(225,1,16,16), ]), (1,20,1,15,73,16,[(227,1,16,16), ]), (1,21,2,9,32,16,[(229,1,16,16), (73,1,16,16)]), (1,22,2,10,32,16,[(230,1,16,16), (74,1,16,12)]), (1,23,2,11,32,16,[(234,1,16,16), (74,1,16,13)]), (1,24,2,12,32,16,[(235,1,16,16), (74,1,16,14)]), (1,25,2,13,32,16,[(241,2,16,16), (74,2,16,15)]), (1,26,2,14,21,14,[(243,2,16,16), (74,2,16,16), (32,2,16,14)]), (1,27,2,15,21,14,[(248,2,16,16), (75,2,16,15), (32,2,16,15)]), (1,28,2,16,21,14,[(256,2,16,16), (75,2,16,16), (32,2,16,16)]), (1,29,2,17,21,14,[(262,2,16,16), (76,2,16,15), (33,2,16,12)]), (1,30,2,18,21,14,[(270,2,16,16), (76,2,16,16), (33,2,16,13)]), (2,7,1,13,109,16,[]), (2,8,1,14,109,16,[]), (2,9,1,15,109,16,[]), (2,10,1,16,109,16,[]), (2,11,1,12,110,16,[]), (2,12,1,13,110,16,[]), (2,13,1,14,110,16,[]), (2,14,1,15,110,16,[]), (2,15,1,16,110,16,[]), (2,16,1,10,55,16,[(111,1,16,14), ]), (2,17,1,11,55,16,[(111,1,16,15), ]), (2,18,1,12,55,16,[(111,1,16,16), ]), (2,19,1,13,55,16,[(112,1,16,15), ]), (2,20,2,14,55,16,[(112,1,16,16), ]), (2,21,2,10,28,16,[(113,1,16,16), (55,1,16,15)]), (2,22,2,11,28,16,[(114,1,16,15), (55,1,16,16)]), (2,23,2,12,28,16,[(114,1,16,16), (56,1,16,13)]), (2,24,2,13,28,16,[(115,1,16,16), (56,2,16,14)]), (2,25,2,14,28,16,[(118,2,16,15), (56,2,16,15)]), (2,26,2,17,19,15,[(118,2,16,16), (56,2,16,16), (28,2,16,15)]), (2,27,2,18,19,15,[(119,2,16,16), (57,2,16,13), (28,2,16,16)]), (2,28,2,19,19,15,[(120,2,16,16), (57,2,16,14), (29,2,15,14)]), (2,29,2,20,19,15,[(123,2,16,16), (57,2,16,15), (29,2,15,15)]), (3,7,1,9,73,16,[]), (3,8,1,10,73,16,[]), (3,9,1,11,73,16,[]), (3,10,1,12,73,16,[]), (3,11,1,13,73,16,[]), (3,12,1,14,73,16,[]), (3,13,1,15,73,16,[]), (3,14,1,16,73,16,[]), (3,15,1,12,74,16,[]), (3,16,1,11,44,16,[(74,1,16,13), ]), (3,17,1,12,44,16,[(74,1,16,14), ]), (3,18,2,13,44,16,[(74,1,16,15), ]), (3,19,2,14,44,16,[(74,1,16,16), ]), (3,20,2,15,44,16,[(75,1,16,15), ]), (3,21,2,11,25,16,[(75,1,16,16), (44,1,16,16)]), (3,22,2,12,25,16,[(76,1,16,15), (45,1,16,11)]), (3,23,2,13,25,16,[(76,1,16,16), (45,2,16,12)]), (3,24,2,14,25,16,[(77,2,16,16), (45,2,16,13)]), (3,25,2,15,25,16,[(78,2,15,16), (45,2,16,14)]), (3,26,2,19,18,12,[(79,2,15,16), (45,2,16,15), (25,2,16,16)]), (3,27,2,20,18,12,[(80,2,16,16), (45,2,16,16), (26,2,13,15)]), (3,28,2,21,18,12,[(82,2,15,15), (46,2,16,15), (26,2,13,16)]), (4,7,1,8,55,16,[]), (4,8,1,9,55,16,[]), (4,9,1,10,55,16,[]), (4,10,1,11,55,16,[]), (4,11,1,12,55,16,[]), (4,12,1,13,55,16,[]), (4,13,1,14,55,16,[]), (4,14,1,15,55,16,[]), (4,15,1,16,55,16,[]), (4,16,1,9,37,16,[(56,1,16,13), ]), (4,17,1,10,37,16,[(56,1,16,14), ]), (4,18,2,11,37,16,[(56,1,16,15), ]), (4,19,2,12,37,16,[(56,1,16,16), ]), (4,20,2,13,37,16,[(57,1,16,13), ]), (4,21,2,12,23,15,[(57,2,16,14), (37,2,16,14)]), (4,22,2,13,23,15,[(57,2,16,15), (37,2,16,15)]), (4,23,2,14,23,15,[(57,2,16,16), (37,2,16,16)]), (4,24,2,15,23,15,[(58,2,16,15), (38,2,16,13)]), (4,25,2,16,23,15,[(58,2,16,16), (38,2,16,14)]), (4,26,2,22,16,16,[(60,2,15,16), (38,2,16,15), (23,2,15,17)]), (4,27,2,23,16,16,[(61,2,16,15), (38,2,16,16), (23,2,15,18)])] # fmt: skip +_WHIR_CONFIGS = WHIR_CONFIGS = [(1,7,1,10,220,16,[]), (1,8,1,11,220,16,[]), (1,9,1,12,220,16,[]), (1,10,1,13,220,16,[]), (1,11,1,14,220,16,[]), (1,12,1,15,220,16,[]), (1,13,1,16,220,16,[]), (1,14,1,15,221,16,[]), (1,15,1,16,221,16,[]), (1,16,1,11,73,16,[(222,1,16,16), ]), (1,17,1,12,73,16,[(223,1,16,16), ]), (1,18,1,13,73,16,[(224,1,16,16), ]), (1,19,1,14,73,16,[(225,1,16,16), ]), (1,20,1,15,73,16,[(227,1,16,16), ]), (1,21,2,9,32,16,[(229,1,16,16), (73,1,16,16)]), (1,22,2,10,32,16,[(230,1,16,16), (74,1,16,12)]), (1,23,2,11,32,16,[(234,1,16,16), (74,1,16,13)]), (1,24,2,12,32,16,[(235,1,16,16), (74,1,16,14)]), (1,25,2,13,32,16,[(241,2,16,16), (74,2,16,15)]), (1,26,2,14,21,14,[(243,2,16,16), (74,2,16,16), (32,2,16,14)]), (1,27,2,15,21,14,[(248,2,16,16), (75,2,16,15), (32,2,16,15)]), (1,28,2,16,21,14,[(256,2,16,16), (75,2,16,16), (32,2,16,16)]), (1,29,2,17,21,14,[(262,2,16,16), (76,2,16,15), (33,2,16,12)]), (1,30,2,18,21,14,[(270,2,16,16), (76,2,16,16), (33,2,16,13)]), (2,7,1,13,109,16,[]), (2,8,1,14,109,16,[]), (2,9,1,15,109,16,[]), (2,10,1,16,109,16,[]), (2,11,1,12,110,16,[]), (2,12,1,13,110,16,[]), (2,13,1,14,110,16,[]), (2,14,1,15,110,16,[]), (2,15,1,16,110,16,[]), (2,16,1,10,55,16,[(111,1,16,14), ]), (2,17,1,11,55,16,[(111,1,16,15), ]), (2,18,1,12,55,16,[(111,1,16,16), ]), (2,19,1,13,55,16,[(112,1,16,15), ]), (2,20,2,14,55,16,[(112,1,16,16), ]), (2,21,2,10,28,16,[(113,1,16,16), (55,1,16,15)]), (2,22,2,11,28,16,[(114,1,16,15), (55,1,16,16)]), (2,23,2,12,28,16,[(114,1,16,16), (56,1,16,13)]), (2,24,2,13,28,16,[(115,1,16,16), (56,2,16,14)]), (2,25,2,14,28,16,[(118,2,16,15), (56,2,16,15)]), (2,26,2,17,19,15,[(118,2,16,16), (56,2,16,16), (28,2,16,15)]), (2,27,2,18,19,15,[(119,2,16,16), (57,2,16,13), (28,2,16,16)]), (2,28,2,19,19,15,[(120,2,16,16), (57,2,16,14), (29,2,15,14)]), (2,29,2,20,19,15,[(123,2,16,16), (57,2,16,15), (29,2,15,15)]), (3,7,1,9,73,16,[]), (3,8,1,10,73,16,[]), (3,9,1,11,73,16,[]), (3,10,1,12,73,16,[]), (3,11,1,13,73,16,[]), (3,12,1,14,73,16,[]), (3,13,1,15,73,16,[]), (3,14,1,16,73,16,[]), (3,15,1,12,74,16,[]), (3,16,1,11,44,16,[(74,1,16,13), ]), (3,17,1,12,44,16,[(74,1,16,14), ]), (3,18,2,13,44,16,[(74,1,16,15), ]), (3,19,2,14,44,16,[(74,1,16,16), ]), (3,20,2,15,44,16,[(75,1,16,15), ]), (3,21,2,11,25,16,[(75,1,16,16), (44,1,16,16)]), (3,22,2,12,25,16,[(76,1,16,15), (45,1,16,11)]), (3,23,2,13,25,16,[(76,1,16,16), (45,2,16,12)]), (3,24,2,14,25,16,[(77,2,16,16), (45,2,16,13)]), (3,25,2,15,25,16,[(78,2,15,16), (45,2,16,14)]), (3,26,2,19,18,12,[(79,2,15,16), (45,2,16,15), (25,2,16,16)]), (3,27,2,20,18,12,[(80,2,16,16), (45,2,16,16), (26,2,13,15)]), (3,28,2,21,18,12,[(82,2,15,15), (46,2,16,15), (26,2,13,16)]), (4,7,1,8,55,16,[]), (4,8,1,9,55,16,[]), (4,9,1,10,55,16,[]), (4,10,1,11,55,16,[]), (4,11,1,12,55,16,[]), (4,12,1,13,55,16,[]), (4,13,1,14,55,16,[]), (4,14,1,15,55,16,[]), (4,15,1,16,55,16,[]), (4,16,1,9,37,16,[(56,1,16,13), ]), (4,17,1,10,37,16,[(56,1,16,14), ]), (4,18,2,11,37,16,[(56,1,16,15), ]), (4,19,2,12,37,16,[(56,1,16,16), ]), (4,20,2,13,37,16,[(57,1,16,13), ]), (4,21,2,12,23,15,[(57,2,16,14), (37,2,16,14)]), (4,22,2,13,23,15,[(57,2,16,15), (37,2,16,15)]), (4,23,2,14,23,15,[(57,2,16,16), (37,2,16,16)]), (4,24,2,15,23,15,[(58,2,16,15), (38,2,16,13)]), (4,25,2,16,23,15,[(58,2,16,16), (38,2,16,14)]), (4,26,2,22,16,16,[(60,2,15,16), (38,2,16,15), (23,2,15,17)]), (4,27,2,23,16,16,[(61,2,16,15), (38,2,16,16), (23,2,15,18)])] # fmt: skip WHIR_CONFIGS = { (c[0], c[1]): { "log_inv_rate": c[0], @@ -431,7 +431,7 @@ def verify_whir( leaf_eval = eval_multilinear_by_evals(leaf, folding_challenges[-folding_factor:]) point = expand_from_univariate(EF(pow(int(gen.value), idx, P)), current_vars) stir_constraints.append(SparseStatements(current_vars, point, [(0, leaf_eval)])) - + if is_final: final_stir_constraints = stir_constraints break @@ -452,11 +452,11 @@ def verify_whir( folding_challenges = folding_challenges[whir_folding_factor_at_round(round - 1) :] gamma_power = ONE for smt in smts: - point_suffix = folding_challenges[len(folding_challenges) - len(smt.point) :] # dense part of the point + point_suffix = folding_challenges[len(folding_challenges) - len(smt.point) :] # dense part of the point eval_suffix = next_mle(smt.point, point_suffix) if smt.is_next else eq_poly(smt.point, point_suffix) sel_n = smt.selector_num_variables for v in smt.values: - eval_prefix = eq_at_index(folding_challenges, v[0], sel_n) # sparse part of the point + eval_prefix = eq_at_index(folding_challenges, v[0], sel_n) # sparse part of the point eval_weights += eval_prefix * eval_suffix * gamma_power gamma_power *= gamma final_value = eval_multilinear_by_coeffs(final_coeffs, list(reversed(final_sc_point))) @@ -576,7 +576,9 @@ def pref_at(offset: int, log_height: int) -> EF: memory_acc_eval = fiat_shamir.next_extension_scalar() memory_eval = fiat_shamir.next_extension_scalar() num -= pref * memory_acc_eval - den += pref * (gamma - finger_print(Fp(LOGUP_MEMORY_DOMAINSEP), [mle_of_01234567_etc(mem_pt), memory_eval], beta_eq)) + den += pref * ( + gamma - finger_print(Fp(LOGUP_MEMORY_DOMAINSEP), [mle_of_01234567_etc(mem_pt), memory_eval], beta_eq) + ) offset = 1 << log_memory # Bytecode section (padded to the tallest table) @@ -593,7 +595,9 @@ def pref_at(offset: int, log_height: int) -> EF: + beta_eq[-1] * Fp(LOGUP_BYTECODE_DOMAINSEP) ) num -= pref * value_bytecode_acc - den += pref * (gamma - fingerprint_bytecode) + pref_padded * mle_of_zeros_then_ones(1 << log_bytecode, gkr_point[-log_bytecode_padded:]) + den += pref * (gamma - fingerprint_bytecode) + pref_padded * mle_of_zeros_then_ones( + 1 << log_bytecode, gkr_point[-log_bytecode_padded:] + ) offset += 1 << log_bytecode_padded # Per-table section @@ -620,11 +624,11 @@ def request_column_evals_dedup(cols: Sequence[int]) -> list[EF]: for bus in table.buses: if bus.cols: # memory / bytecode interraction - base = [table.col(c) for c in bus.cols] + base = [table.col(c) for c in bus.cols] for i in range(bus.n_terms): # term i: σ = (m[base[0]] + i, m[base[1:] + i]) pref = pref_at(offset, table_heights[table.name]) d = request_column_evals_dedup([base[0], *(c + i for c in base[1:])]) - num += pref # always multiplicity 1 + num += pref # always multiplicity 1 den += pref * (gamma - finger_print(Fp(bus.domain_sep), [d[0] + i, *d[1:]], beta_eq)) offset += 1 << table_heights[table.name] else: @@ -635,7 +639,7 @@ def request_column_evals_dedup(cols: Sequence[int]) -> list[EF]: num += pref * precompile_nums[table.name] den += pref * precompile_dens[table.name] offset += 1 << table_heights[table.name] - + den += mle_of_zeros_then_ones(final_offset, gkr_point) if num != claim_num: raise ProofError("logup: numerators value mismatch") @@ -658,9 +662,9 @@ class ConstraintEvaluator: def __init__( self, flat: Sequence[EF], shift: Sequence[EF], alpha_powers: Sequence[EF], columns: Sequence[str] ) -> None: - self.flat = list(flat) - self.shift = list(shift) - self.alpha_powers = list(alpha_powers) + self.flat = flat + self.shift = shift + self.alpha_powers = alpha_powers # Shift columns are always the first `n_shift` columns of the table. self.flat = Cols(zip(columns, self.flat)) self.next = Cols(zip(columns[: len(self.shift)], self.shift)) @@ -678,7 +682,7 @@ def assert_bool(self, x: EF) -> None: self.assert_zero(x * (ONE - x)) -def eval_precompile_bus_virtual_columns( +def eval_precompile_bus_in_air( evaluator: "ConstraintEvaluator", logup_beta_eq: list[EF], multiplicity: EF, @@ -689,11 +693,10 @@ def eval_precompile_bus_virtual_columns( evaluator.assert_zero(finger_print(domainsep, data, logup_beta_eq)) -def eval_air_execution(evaluator: ConstraintEvaluator, logup_beta_eq: list[EF]) -> None: - c, n = evaluator.flat, evaluator.next +def eval_air_execution_table(evaluator: ConstraintEvaluator, logup_beta_eq: list[EF]) -> None: (pc, fp, addr_a, addr_b, addr_c, value_a, value_b, value_c, operand_a, operand_b, operand_c, - flag_a, flag_b, flag_c, flag_c_fp, flag_ab_fp, flag_mul, flag_jump, aux_1, aux_2) = (c[k] for k in EXECUTION_COLUMNS) # fmt: skip - pc_shift, fp_shift = n["pc"], n["fp"] + flag_a, flag_b, flag_c, flag_c_fp, flag_ab_fp, flag_mul, flag_jump, aux_1, aux_2) = (evaluator.flat[k] for k in EXECUTION_COLUMNS) # fmt: skip + next_pc, next_fp = evaluator.next["pc"], evaluator.next["fp"] # nu_x = flag·operand + (1 − flag − flag_ab_fp)·value + flag_ab_fp·(fp + operand) nfa = ONE - flag_a - flag_ab_fp @@ -708,7 +711,7 @@ def eval_air_execution(evaluator: ConstraintEvaluator, logup_beta_eq: list[EF]) flag_deref = aux_1 * (aux_1 - ONE) * ((P + 1) // 2) # (P+1)/2 is the inverse of 2 mod P flag_precompile = ONE - flag_add - flag_mul - flag_deref - flag_jump - eval_precompile_bus_virtual_columns(evaluator, logup_beta_eq, flag_precompile, aux_2, [nu_a, nu_b, nu_c]) + eval_precompile_bus_in_air(evaluator, logup_beta_eq, flag_precompile, aux_2, [nu_a, nu_b, nu_c]) evaluator.assert_zero(nfa * (addr_a - (fp + operand_a))) evaluator.assert_zero(nfb * (addr_b - (fp + operand_b))) evaluator.assert_zero(nfc * (addr_c - (fp + operand_c))) @@ -716,25 +719,20 @@ def eval_air_execution(evaluator: ConstraintEvaluator, logup_beta_eq: list[EF]) evaluator.assert_zero(flag_mul * (nu_b - nu_a * nu_c)) evaluator.assert_zero(flag_deref * (addr_b - (value_a + operand_b))) evaluator.assert_zero(flag_deref * (value_b - nu_c)) - jc = flag_jump * nu_a - evaluator.assert_zero(jc * (nu_a - ONE)) - evaluator.assert_zero(jc * (pc_shift - nu_b)) - evaluator.assert_zero(jc * (fp_shift - nu_c)) - not_jc = ONE - jc - evaluator.assert_zero(not_jc * (pc_shift - (pc + ONE))) - evaluator.assert_zero(not_jc * (fp_shift - fp)) - - -def eval_air_extension(evaluator: ConstraintEvaluator, logup_beta_eq: list[EF]) -> None: - c, n = evaluator.flat, evaluator.next - flag_be, flag_start, len_col = c["flag_be"], c["flag_start"], c["len"] - flag_add, flag_dot_product, flag_eq = c["flag_add"], c["flag_dot_product"], c["flag_eq"] - idx_a, idx_b, idx_r = c["idx_a"], c["idx_b"], c["idx_r"] - acc, v_a, v_b, res = c.arr("acc", 5), c.arr("v_a", 5), c.arr("v_b", 5), c.arr("res", 5) - flag_be_sh, flag_start_sh, len_sh = n["flag_be"], n["flag_start"], n["len"] - flag_add_sh, flag_dot_product_sh, flag_eq_sh = n["flag_add"], n["flag_dot_product"], n["flag_eq"] - idx_a_sh, idx_b_sh = n["idx_a"], n["idx_b"] - acc_sh = n.arr("acc", 5) + jumping = flag_jump * nu_a + evaluator.assert_zero(jumping * (nu_a - ONE)) # nu_a (condition) should be boolean in case of JUMP instruction + evaluator.assert_zero(jumping * (next_pc - nu_b)) + evaluator.assert_zero(jumping * (next_fp - nu_c)) + not_jumping = ONE - jumping + evaluator.assert_zero(not_jumping * (next_pc - (pc + ONE))) + evaluator.assert_zero(not_jumping * (next_fp - fp)) + + +def eval_air_extension_table(evaluator: ConstraintEvaluator, logup_beta_eq: list[EF]) -> None: + (flag_be, flag_start, len_col, flag_add, flag_dot_product, flag_eq, idx_a, idx_b) = (evaluator.flat[k] for k in EXTENSION_COLUMNS[:8]) # fmt: skip + idx_r, acc, v_a, v_b, res = evaluator.flat["idx_r"], evaluator.flat.arr("acc", EF.DIMENSION), evaluator.flat.arr("v_a", EF.DIMENSION), evaluator.flat.arr("v_b", EF.DIMENSION), evaluator.flat.arr("res", EF.DIMENSION) # fmt: skip + flag_be_next, flag_start_next, len_next, flag_add_next, flag_dot_product_next, flag_eq_next, idx_a_next, idx_b_next = (evaluator.next[k] for k in EXTENSION_COLUMNS[:8]) # fmt: skip + acc_next = evaluator.next.arr("acc", EF.DIMENSION) aux_2 = ( flag_be * EXT_OP_FLAG_BE @@ -743,44 +741,46 @@ def eval_air_extension(evaluator: ConstraintEvaluator, logup_beta_eq: list[EF]) + flag_eq * EXT_OP_FLAG_EQ + len_col * EXT_OP_LEN_MULTIPLIER ) - eval_precompile_bus_virtual_columns( + eval_precompile_bus_in_air( evaluator, logup_beta_eq, flag_start * (flag_add + flag_dot_product + flag_eq), aux_2, [idx_a, idx_b, idx_r] ) for x in (flag_be, flag_start, flag_add, flag_dot_product, flag_eq): evaluator.assert_bool(x) - is_ee, not_start_sh = ONE - flag_be, ONE - flag_start_sh - v_a_tilde = [v_a[0]] + [v_a[k] * is_ee for k in range(1, 5)] - acc_tail = [acc_sh[k] * not_start_sh for k in range(5)] + is_ee, not_start_next = ONE - flag_be, ONE - flag_start_next + v_a_tilde = [v_a[0]] + [v_a[k] * is_ee for k in range(1, EF.DIMENSION)] + acc_tail = [acc_next[k] * not_start_next for k in range(EF.DIMENSION)] v_a_v_b = quintic_mul(v_a_tilde, v_b, ZERO) - for k in range(5): + for k in range(EF.DIMENSION): evaluator.assert_zero((acc[k] - (v_a_tilde[k] + v_b[k] + acc_tail[k])) * flag_add) - for k in range(5): + for k in range(EF.DIMENSION): evaluator.assert_zero((acc[k] - (v_a_v_b[k] + acc_tail[k])) * flag_dot_product) # eq: acc ← (2·v_a·v_b − v_a − v_b + 1) · (acc_tail or 1 at group end). - e_eq = [2 * v_a_v_b[k] - v_a_tilde[k] - v_b[k] + (ONE if k == 0 else ZERO) for k in range(5)] - acc_tail_or_one = [acc_sh[0] * not_start_sh + flag_start_sh] + [acc_sh[k] * not_start_sh for k in range(1, 5)] + e_eq = [2 * v_a_v_b[k] - v_a_tilde[k] - v_b[k] + (ONE if k == 0 else ZERO) for k in range(EF.DIMENSION)] + acc_tail_or_one = [acc_next[0] * not_start_next + flag_start_next] + [ + acc_next[k] * not_start_next for k in range(1, EF.DIMENSION) + ] eq_result = quintic_mul(e_eq, acc_tail_or_one, ZERO) - for k in range(5): + for k in range(EF.DIMENSION): evaluator.assert_zero((acc[k] - eq_result[k]) * flag_eq) - for k in range(5): + for k in range(EF.DIMENSION): evaluator.assert_zero((acc[k] - res[k]) * flag_start) for x, y in [ - (len_col, len_sh + ONE), - (flag_be, flag_be_sh), - (flag_add, flag_add_sh), - (flag_dot_product, flag_dot_product_sh), - (flag_eq, flag_eq_sh), + (len_col, len_next + ONE), + (flag_be, flag_be_next), + (flag_add, flag_add_next), + (flag_dot_product, flag_dot_product_next), + (flag_eq, flag_eq_next), ]: - evaluator.assert_zero(not_start_sh * (x - y)) + evaluator.assert_zero(not_start_next * (x - y)) - evaluator.assert_zero(not_start_sh * (idx_a_sh - idx_a - (flag_be + is_ee * 5))) - evaluator.assert_zero(not_start_sh * (idx_b_sh - idx_b - 5)) - evaluator.assert_zero(flag_start_sh * (len_col - ONE)) + evaluator.assert_zero(not_start_next * (idx_a_next - idx_a - (flag_be + is_ee * EF.DIMENSION))) + evaluator.assert_zero(not_start_next * (idx_b_next - idx_b - EF.DIMENSION)) + evaluator.assert_zero(flag_start_next * (len_col - ONE)) def _full_round(state: list[EF], rc1: list[Fp], rc2: list[Fp]) -> list[EF]: @@ -791,7 +791,7 @@ def _full_round(state: list[EF], rc1: list[Fp], rc2: list[Fp]) -> list[EF]: return state -def eval_air_poseidon16(evaluator: ConstraintEvaluator, logup_beta_eq: list[EF]) -> None: +def eval_air_poseidon16_table(evaluator: ConstraintEvaluator, logup_beta_eq: list[EF]) -> None: c = evaluator.flat half_pairs = POSEIDON_HALF_FULL_ROUNDS // 2 @@ -818,7 +818,7 @@ def eval_air_poseidon16(evaluator: ConstraintEvaluator, logup_beta_eq: list[EF]) not_flag_left = ONE - flag_left nu_a = addr_left_hi - not_flag_left * (DIGEST_ELEMS // 2) - eval_precompile_bus_virtual_columns(evaluator, logup_beta_eq, multiplicity, domainsep, [nu_a, nu_b, nu_c]) + eval_precompile_bus_in_air(evaluator, logup_beta_eq, multiplicity, domainsep, [nu_a, nu_b, nu_c]) for f in (multiplicity, flag_out4, flag_out8, flag_left, flag_permute): evaluator.assert_bool(f) evaluator.assert_zero(flag_permute * flag_out4) @@ -884,11 +884,11 @@ def eval_air_poseidon16(evaluator: ConstraintEvaluator, logup_beta_eq: list[EF]) EXTENSION_COLUMNS = ( "flag_be", "flag_start", "len", "flag_add", "flag_dot_product", "flag_eq", "idx_a", "idx_b", - *(f"acc_{i}" for i in range(5)), + *(f"acc_{i}" for i in range(EF.DIMENSION)), "idx_r", - *(f"v_a_{i}" for i in range(5)), - *(f"v_b_{i}" for i in range(5)), - *(f"res_{i}" for i in range(5)), + *(f"v_a_{i}" for i in range(EF.DIMENSION)), + *(f"v_b_{i}" for i in range(EF.DIMENSION)), + *(f"res_{i}" for i in range(EF.DIMENSION)), ) # fmt: skip POSEIDON_COLUMNS = ( @@ -916,22 +916,22 @@ def eval_air_poseidon16(evaluator: ConstraintEvaluator, logup_beta_eq: list[EF]) n_constraints=14, n_shift=2, max_log_height=24, - air_constraints_fn=eval_air_execution, + air_constraints_fn=eval_air_execution_table, ), Table( name="extension", columns=EXTENSION_COLUMNS, buses=( BusInteraction(BusDirection.PULL), - BusInteraction(BusDirection.PULL, LOGUP_MEMORY_DOMAINSEP, ("idx_a", "v_a_0"), 5), - BusInteraction(BusDirection.PULL, LOGUP_MEMORY_DOMAINSEP, ("idx_b", "v_b_0"), 5), - BusInteraction(BusDirection.PULL, LOGUP_MEMORY_DOMAINSEP, ("idx_r", "res_0"), 5), + BusInteraction(BusDirection.PULL, LOGUP_MEMORY_DOMAINSEP, ("idx_a", "v_a_0"), EF.DIMENSION), + BusInteraction(BusDirection.PULL, LOGUP_MEMORY_DOMAINSEP, ("idx_b", "v_b_0"), EF.DIMENSION), + BusInteraction(BusDirection.PULL, LOGUP_MEMORY_DOMAINSEP, ("idx_r", "res_0"), EF.DIMENSION), ), air_degree=6, n_constraints=35, n_shift=13, max_log_height=21, - air_constraints_fn=eval_air_extension, + air_constraints_fn=eval_air_extension_table, ), Table( name="poseidon", @@ -947,7 +947,7 @@ def eval_air_poseidon16(evaluator: ConstraintEvaluator, logup_beta_eq: list[EF]) n_constraints=101, n_shift=0, max_log_height=21, - air_constraints_fn=eval_air_poseidon16, + air_constraints_fn=eval_air_poseidon16_table, ), ] @@ -1018,7 +1018,9 @@ def verify_execution( initial_sum, offset = ZERO, 0 for table in TABLES: - initial_sum += alpha_powers[offset] * (logup["precompile_nums"][table.name] * table.precompile_bus_interaction_sign) + initial_sum += alpha_powers[offset] * ( + logup["precompile_nums"][table.name] * table.precompile_bus_interaction_sign + ) initial_sum += alpha_powers[offset + 1] * (logup_gamma - logup["precompile_dens"][table.name]) offset += table.n_constraints sc_point, sc_value = verify_sumcheck(state, initial_sum, n_max, max(t.air_degree + 1 for t in TABLES)) From b448d04250980dc22e678c2ebd95d80095206ad6 Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Wed, 3 Jun 2026 18:54:50 +0400 Subject: [PATCH 23/31] wip --- .../lean_prover/python-verifier/verifier.py | 21 ++++--------------- 1 file changed, 4 insertions(+), 17 deletions(-) diff --git a/crates/lean_prover/python-verifier/verifier.py b/crates/lean_prover/python-verifier/verifier.py index 3f897c00..fe3ee6b7 100644 --- a/crates/lean_prover/python-verifier/verifier.py +++ b/crates/lean_prover/python-verifier/verifier.py @@ -757,7 +757,6 @@ def eval_air_extension_table(evaluator: ConstraintEvaluator, logup_beta_eq: list evaluator.assert_zero((acc[k] - (v_a_tilde[k] + v_b[k] + acc_tail[k])) * flag_add) for k in range(EF.DIMENSION): evaluator.assert_zero((acc[k] - (v_a_v_b[k] + acc_tail[k])) * flag_dot_product) - # eq: acc ← (2·v_a·v_b − v_a − v_b + 1) · (acc_tail or 1 at group end). e_eq = [2 * v_a_v_b[k] - v_a_tilde[k] - v_b[k] + (ONE if k == 0 else ZERO) for k in range(EF.DIMENSION)] acc_tail_or_one = [acc_next[0] * not_start_next + flag_start_next] + [ @@ -766,16 +765,11 @@ def eval_air_extension_table(evaluator: ConstraintEvaluator, logup_beta_eq: list eq_result = quintic_mul(e_eq, acc_tail_or_one, ZERO) for k in range(EF.DIMENSION): evaluator.assert_zero((acc[k] - eq_result[k]) * flag_eq) + for k in range(EF.DIMENSION): evaluator.assert_zero((acc[k] - res[k]) * flag_start) - for x, y in [ - (len_col, len_next + ONE), - (flag_be, flag_be_next), - (flag_add, flag_add_next), - (flag_dot_product, flag_dot_product_next), - (flag_eq, flag_eq_next), - ]: + for x, y in [(len_col, len_next + ONE), (flag_be, flag_be_next), (flag_add, flag_add_next), (flag_dot_product, flag_dot_product_next), (flag_eq, flag_eq_next)]: # fmt: skip evaluator.assert_zero(not_start_next * (x - y)) evaluator.assert_zero(not_start_next * (idx_a_next - idx_a - (flag_be + is_ee * EF.DIMENSION))) @@ -826,22 +820,16 @@ def eval_air_poseidon16_table(evaluator: ConstraintEvaluator, logup_beta_eq: lis evaluator.assert_zero((ONE - flag_permute) * (ONE - flag_out8) * (ONE - flag_out4)) evaluator.assert_zero(flag_left * (offset_left - addr_left_lo)) evaluator.assert_zero(not_flag_left * (nu_a - addr_left_lo)) - - # --- Poseidon1-16 permutation AIR: each committed `post` row pins the intermediate - # state then re-binds it, capping polynomial degree across the long round sequence. state = list(inputs) - - # Beginning full rounds, paired up. + # 2-by2 initial full rounds for r in range(half_pairs): state = _full_round(state, POSEIDON_AIR_INITIAL_CONSTANTS[2 * r], POSEIDON_AIR_INITIAL_CONSTANTS[2 * r + 1]) for i, post in enumerate(beginning_full_rounds[r]): evaluator.assert_eq(state[i], post) state[i] = post - # Transition into sparse partial-round form. state = [s + rc for s, rc in zip(state, POSEIDON_AIR_SPARSE_FIRST_RC)] state = [dot_product(state, row) for row in POSEIDON_AIR_SPARSE_M_I] - # Partial rounds: one sbox on lane 0, then sparse mat-vec. for r in range(POSEIDON_PARTIAL_ROUNDS): evaluator.assert_eq(state[0].cube(), partial_cols[r]) @@ -852,8 +840,7 @@ def eval_air_poseidon16_table(evaluator: ConstraintEvaluator, logup_beta_eq: lis state[0] = dot_product(state, POSEIDON_AIR_SPARSE_FIRST_ROW[r]) for i in range(1, POSEIDON_WIDTH): state[i] += old_s0 * POSEIDON_AIR_SPARSE_V[r][i - 1] - - # Ending full rounds (all but the last pair) commit intermediate state. + # 2-by2 final full rounds for r in range(half_pairs - 1): state = _full_round(state, POSEIDON_AIR_FINAL_CONSTANTS[2 * r], POSEIDON_AIR_FINAL_CONSTANTS[2 * r + 1]) for i, post in enumerate(ending_full_rounds[r]): From 44b84e40a78c611e74a33d90dc9c65235e47bc5b Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Wed, 3 Jun 2026 19:16:44 +0400 Subject: [PATCH 24/31] wip --- .../lean_prover/python-verifier/primitives.py | 3 +- .../lean_prover/python-verifier/verifier.py | 94 ++++++------------- 2 files changed, 31 insertions(+), 66 deletions(-) diff --git a/crates/lean_prover/python-verifier/primitives.py b/crates/lean_prover/python-verifier/primitives.py index ec293371..66fc15da 100644 --- a/crates/lean_prover/python-verifier/primitives.py +++ b/crates/lean_prover/python-verifier/primitives.py @@ -301,7 +301,8 @@ def div_ceil(n: int, k: int) -> int: POSEIDON_FULL_ROUNDS = 8 POSEIDON_WIDTH = 16 POSEIDON_PARTIAL_ROUNDS = 20 -POSEIDON_HALF_FULL_ROUNDS = POSEIDON_FULL_ROUNDS // 2 # = 4 full rounds per side +POSEIDON_HALF_FULL_ROUNDS = POSEIDON_FULL_ROUNDS // 2 # = 4 full rounds per side (initial / final) +POSEIDON_QUARTER_FULL_ROUNDS = POSEIDON_HALF_FULL_ROUNDS // 2 def _mat_mul(a: list[list[int]], b: list[list[int]], n: int) -> list[list[int]]: diff --git a/crates/lean_prover/python-verifier/verifier.py b/crates/lean_prover/python-verifier/verifier.py index fe3ee6b7..87f74e6b 100644 --- a/crates/lean_prover/python-verifier/verifier.py +++ b/crates/lean_prover/python-verifier/verifier.py @@ -734,16 +734,8 @@ def eval_air_extension_table(evaluator: ConstraintEvaluator, logup_beta_eq: list flag_be_next, flag_start_next, len_next, flag_add_next, flag_dot_product_next, flag_eq_next, idx_a_next, idx_b_next = (evaluator.next[k] for k in EXTENSION_COLUMNS[:8]) # fmt: skip acc_next = evaluator.next.arr("acc", EF.DIMENSION) - aux_2 = ( - flag_be * EXT_OP_FLAG_BE - + flag_add * EXT_OP_FLAG_ADD - + flag_dot_product * EXT_OP_FLAG_DOT_PRODUCT - + flag_eq * EXT_OP_FLAG_EQ - + len_col * EXT_OP_LEN_MULTIPLIER - ) - eval_precompile_bus_in_air( - evaluator, logup_beta_eq, flag_start * (flag_add + flag_dot_product + flag_eq), aux_2, [idx_a, idx_b, idx_r] - ) + aux_2 = flag_be * EXT_OP_FLAG_BE + flag_add * EXT_OP_FLAG_ADD + flag_dot_product * EXT_OP_FLAG_DOT_PRODUCT + flag_eq * EXT_OP_FLAG_EQ + len_col * EXT_OP_LEN_MULTIPLIER # fmt: skip + eval_precompile_bus_in_air(evaluator, logup_beta_eq, flag_start * (flag_add + flag_dot_product + flag_eq), aux_2, [idx_a, idx_b, idx_r]) # fmt: skip for x in (flag_be, flag_start, flag_add, flag_dot_product, flag_eq): evaluator.assert_bool(x) @@ -777,8 +769,7 @@ def eval_air_extension_table(evaluator: ConstraintEvaluator, logup_beta_eq: list evaluator.assert_zero(flag_start_next * (len_col - ONE)) -def _full_round(state: list[EF], rc1: list[Fp], rc2: list[Fp]) -> list[EF]: - """Two consecutive Poseidon full rounds, fused as one AIR step.""" +def do_2_full_round(state: list[EF], rc1: list[Fp], rc2: list[Fp]) -> list[EF]: for rc in (rc1, rc2): sbox = [(s + c).cube() for s, c in zip(state, rc)] state = [dot_product(sbox, row) for row in POSEIDON_AIR_MDS_DENSE] @@ -786,32 +777,16 @@ def _full_round(state: list[EF], rc1: list[Fp], rc2: list[Fp]) -> list[EF]: def eval_air_poseidon16_table(evaluator: ConstraintEvaluator, logup_beta_eq: list[EF]) -> None: - c = evaluator.flat - half_pairs = POSEIDON_HALF_FULL_ROUNDS // 2 - - multiplicity = c["multiplicity"] - nu_b, nu_c = c["nu_b"], c["nu_c"] - flag_out4, flag_out8, flag_left = c["flag_out4"], c["flag_out8"], c["flag_left"] - offset_left = c["offset_left"] - addr_left_lo, addr_left_hi = c["addr_left_lo"], c["addr_left_hi"] - flag_permute = c["flag_permute"] - inputs = c.arr("input", POSEIDON_WIDTH) - beginning_full_rounds = [c.arr(f"begin_r{r}", POSEIDON_WIDTH) for r in range(half_pairs)] - partial_cols = c.arr("partial", POSEIDON_PARTIAL_ROUNDS) - ending_full_rounds = [c.arr(f"end_r{r}", POSEIDON_WIDTH) for r in range(half_pairs - 1)] - out_lo = c.arr("out_lo", POSEIDON_WIDTH // 2) - out_hi = c.arr("out_hi", POSEIDON_WIDTH // 2) - - domainsep = ( - POSEIDON_DOMAINSEP_BASE - + flag_permute * POSEIDON_FLAG_PERMUTE_SHIFT - + flag_out8 * POSEIDON_FLAG_OUT8_SHIFT - + flag_left * POSEIDON_FLAG_LEFT_SHIFT - + flag_left * offset_left * POSEIDON_OFFSET_LEFT_SHIFT - ) + multiplicity, nu_b, nu_c , flag_out4, flag_out8, flag_left, offset_left, addr_left_lo, addr_left_hi, flag_permute = (evaluator.flat[k] for k in POSEIDON_COLUMNS[:10]) # fmt: skip + inputs = evaluator.flat.arr("input", POSEIDON_WIDTH) + beginning_full_rounds = [evaluator.flat.arr(f"begin_r{r}", POSEIDON_WIDTH) for r in range(POSEIDON_QUARTER_FULL_ROUNDS)] # fmt: skip + partial_cols = evaluator.flat.arr("partial", POSEIDON_PARTIAL_ROUNDS) + ending_full_rounds = [evaluator.flat.arr(f"end_r{r}", POSEIDON_WIDTH) for r in range(POSEIDON_QUARTER_FULL_ROUNDS - 1)] # fmt: skip + out_lo, out_hi = evaluator.flat.arr("out_lo", POSEIDON_WIDTH // 2), evaluator.flat.arr("out_hi", POSEIDON_WIDTH // 2) # fmt: skip + + domainsep = POSEIDON_DOMAINSEP_BASE + flag_permute * POSEIDON_FLAG_PERMUTE_SHIFT + flag_out8 * POSEIDON_FLAG_OUT8_SHIFT + flag_left * POSEIDON_FLAG_LEFT_SHIFT + flag_left * offset_left * POSEIDON_OFFSET_LEFT_SHIFT # fmt: skip not_flag_left = ONE - flag_left nu_a = addr_left_hi - not_flag_left * (DIGEST_ELEMS // 2) - eval_precompile_bus_in_air(evaluator, logup_beta_eq, multiplicity, domainsep, [nu_a, nu_b, nu_c]) for f in (multiplicity, flag_out4, flag_out8, flag_left, flag_permute): evaluator.assert_bool(f) @@ -821,16 +796,15 @@ def eval_air_poseidon16_table(evaluator: ConstraintEvaluator, logup_beta_eq: lis evaluator.assert_zero(flag_left * (offset_left - addr_left_lo)) evaluator.assert_zero(not_flag_left * (nu_a - addr_left_lo)) state = list(inputs) - # 2-by2 initial full rounds - for r in range(half_pairs): - state = _full_round(state, POSEIDON_AIR_INITIAL_CONSTANTS[2 * r], POSEIDON_AIR_INITIAL_CONSTANTS[2 * r + 1]) + # 2-by-2 initial full rounds + for r in range(POSEIDON_QUARTER_FULL_ROUNDS): + state = do_2_full_round(state, POSEIDON_AIR_INITIAL_CONSTANTS[2 * r], POSEIDON_AIR_INITIAL_CONSTANTS[2 * r + 1]) for i, post in enumerate(beginning_full_rounds[r]): evaluator.assert_eq(state[i], post) state[i] = post - # Transition into sparse partial-round form. + # partial-rounds (using the sparse decomposition, see Appendix of [Poseidon1](https://eprint.iacr.org/2019/458)) state = [s + rc for s, rc in zip(state, POSEIDON_AIR_SPARSE_FIRST_RC)] state = [dot_product(state, row) for row in POSEIDON_AIR_SPARSE_M_I] - # Partial rounds: one sbox on lane 0, then sparse mat-vec. for r in range(POSEIDON_PARTIAL_ROUNDS): evaluator.assert_eq(state[0].cube(), partial_cols[r]) state[0] = partial_cols[r] @@ -840,52 +814,42 @@ def eval_air_poseidon16_table(evaluator: ConstraintEvaluator, logup_beta_eq: lis state[0] = dot_product(state, POSEIDON_AIR_SPARSE_FIRST_ROW[r]) for i in range(1, POSEIDON_WIDTH): state[i] += old_s0 * POSEIDON_AIR_SPARSE_V[r][i - 1] - # 2-by2 final full rounds - for r in range(half_pairs - 1): - state = _full_round(state, POSEIDON_AIR_FINAL_CONSTANTS[2 * r], POSEIDON_AIR_FINAL_CONSTANTS[2 * r + 1]) + # 2-by-2 final full rounds + for r in range(POSEIDON_QUARTER_FULL_ROUNDS - 1): + state = do_2_full_round(state, POSEIDON_AIR_FINAL_CONSTANTS[2 * r], POSEIDON_AIR_FINAL_CONSTANTS[2 * r + 1]) for i, post in enumerate(ending_full_rounds[r]): evaluator.assert_eq(state[i], post) state[i] = post - - # Last full round: compression feeds `inputs` forward into out_lo (permute does not). - # out_lo[4..8] is real unless the output is 4 elements (out4); out_hi (capacity) is only - # written by the full 16-element permutation (out16 = neither out8 nor out4). - last = 2 * (half_pairs - 1) - state = _full_round(state, POSEIDON_AIR_FINAL_CONSTANTS[last], POSEIDON_AIR_FINAL_CONSTANTS[last + 1]) - not_permute = ONE - flag_permute - gate_lo_8 = ONE - flag_out4 - gate_hi = ONE - flag_out8 - flag_out4 + # Last full round + state = do_2_full_round(state, POSEIDON_AIR_FINAL_CONSTANTS[-2], POSEIDON_AIR_FINAL_CONSTANTS[-1]) + not_permute, gate_out_4_to_8, gate_hi = ONE - flag_permute, ONE - flag_out4, ONE - flag_out8 - flag_out4 for i in range(POSEIDON_WIDTH // 2): - value = state[i] + not_permute * inputs[i] + value = state[i] + not_permute * inputs[i] # when it's not permutation -> it's a compression (feedforward) if i < (DIGEST_ELEMS // 2): evaluator.assert_zero(value - out_lo[i]) else: - evaluator.assert_zero(gate_lo_8 * (value - out_lo[i])) + evaluator.assert_zero(gate_out_4_to_8 * (value - out_lo[i])) evaluator.assert_zero(gate_hi * (state[i + POSEIDON_WIDTH // 2] - out_hi[i])) EXECUTION_COLUMNS = ( - "pc", "fp", "addr_a", "addr_b", "addr_c", "value_a", "value_b", "value_c", # 8 runtime cols + "pc", "fp", # 'next' columns (the rest are 'flat') + "addr_a", "addr_b", "addr_c", "value_a", "value_b", "value_c", # 8 runtime cols "operand_a", "operand_b", "operand_c", "flag_a", "flag_b", "flag_c", "flag_c_fp", "flag_ab_fp", "flag_mul", "flag_jump", "aux_1", "aux_2", # 12 instruction cols. ) # fmt: skip EXTENSION_COLUMNS = ( - "flag_be", "flag_start", "len", "flag_add", "flag_dot_product", "flag_eq", "idx_a", "idx_b", - *(f"acc_{i}" for i in range(EF.DIMENSION)), - "idx_r", - *(f"v_a_{i}" for i in range(EF.DIMENSION)), - *(f"v_b_{i}" for i in range(EF.DIMENSION)), - *(f"res_{i}" for i in range(EF.DIMENSION)), + "flag_be", "flag_start", "len", "flag_add", "flag_dot_product", "flag_eq", "idx_a", "idx_b", *(f"acc_{i}" for i in range(EF.DIMENSION)), # 'next' columns + "idx_r", *(f"v_a_{i}" for i in range(EF.DIMENSION)), *(f"v_b_{i}" for i in range(EF.DIMENSION)), *(f"res_{i}" for i in range(EF.DIMENSION)), # # 'flat' columns ) # fmt: skip -POSEIDON_COLUMNS = ( +POSEIDON_COLUMNS = ( # all 'flat' columns "multiplicity", "nu_b", "nu_c", "flag_out4", "flag_out8", "flag_left", "offset_left", "addr_left_lo", "addr_left_hi", "flag_permute", *(f"input_{i}" for i in range(POSEIDON_WIDTH)), *(f"begin_r{r}_{i}" for r in range(POSEIDON_HALF_FULL_ROUNDS // 2) for i in range(POSEIDON_WIDTH)), *(f"partial_{i}" for i in range(POSEIDON_PARTIAL_ROUNDS)), *(f"end_r{r}_{i}" for r in range(POSEIDON_HALF_FULL_ROUNDS // 2 - 1) for i in range(POSEIDON_WIDTH)), - *(f"out_lo_{i}" for i in range(POSEIDON_WIDTH // 2)), - *(f"out_hi_{i}" for i in range(POSEIDON_WIDTH // 2)), + *(f"out_lo_{i}" for i in range(POSEIDON_WIDTH // 2)), *(f"out_hi_{i}" for i in range(POSEIDON_WIDTH // 2)), # lo: [0:8], hi: [8:16] ) # fmt: skip TABLES = [ From 8a2e826bc0c43026f813e0becfa617ee18ec0126 Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Wed, 3 Jun 2026 19:25:11 +0400 Subject: [PATCH 25/31] wip --- .../lean_prover/python-verifier/verifier.py | 45 +++++++++---------- 1 file changed, 22 insertions(+), 23 deletions(-) diff --git a/crates/lean_prover/python-verifier/verifier.py b/crates/lean_prover/python-verifier/verifier.py index 87f74e6b..f2073225 100644 --- a/crates/lean_prover/python-verifier/verifier.py +++ b/crates/lean_prover/python-verifier/verifier.py @@ -205,11 +205,11 @@ def _read_padded(self, n: int) -> list[Fp]: def observe_scalars(self, scalars: Sequence[Fp]) -> None: self.observe_many(list(scalars)) - def next_base_scalars_vec(self, n: int) -> list[Fp]: + def next_base_scalars(self, n: int) -> list[Fp]: return self._read_padded(n)[:n] def next_extension_scalars_vec(self, n: int) -> list[EF]: - flat = self.next_base_scalars_vec(n * EF.DIMENSION) + flat = self.next_base_scalars(n * EF.DIMENSION) return embed_ef(flat) def next_extension_scalar(self) -> EF: @@ -354,7 +354,7 @@ class WhirCommitment: def read(cls, fs: "FiatShamir", num_variables: int, n_ood: int) -> "WhirCommitment": return cls( num_variables, - fs.next_base_scalars_vec(DIGEST_ELEMS), + fs.next_base_scalars(DIGEST_ELEMS), fs.sample_many_ef(n_ood), fs.next_extension_scalars_vec(n_ood), ) @@ -904,9 +904,9 @@ def eval_air_poseidon16_table(evaluator: ConstraintEvaluator, logup_beta_eq: lis def verify_execution( + bytecode_multilinear: list[int], # trusted-source (and thus contains only valid instructions) public_input: Sequence[Fp], proof: Proof, - bytecode_multilinear: list[int], ) -> None: bytecode_log_size = log2_strict(len(bytecode_multilinear)) - log2_ceil(N_INSTRUCTION_COLUMNS) ending_pc = (1 << bytecode_log_size) - 1 @@ -914,10 +914,9 @@ def verify_execution( if len(public_input) != PUBLIC_INPUT_SIZE: raise ProofError("InvalidProof: public_input length mismatch") - state = FiatShamir(proof, poseidon16_compress(bytecode_hash, SNARK_DOMAIN_SEP)) # domain separator across bytecodes - state.observe_scalars(public_input) - dims = [int(x.value) for x in state.next_base_scalars_vec(2 + len(TABLES))] - log_inv_rate, log_memory, *table_log_n_rows = dims + fiat_shamir = FiatShamir(proof, poseidon16_compress(bytecode_hash, SNARK_DOMAIN_SEP)) # domain separator across bytecodes + fiat_shamir.observe_scalars(public_input) + log_inv_rate, log_memory, *table_log_n_rows = [int(x.value) for x in fiat_shamir.next_base_scalars(2 + len(TABLES))] if not MIN_WHIR_LOG_INV_RATE <= log_inv_rate <= MAX_WHIR_LOG_INV_RATE: raise ProofError("InvalidRate") if not MIN_LOG_MEMORY_SIZE <= log_memory <= MAX_LOG_MEMORY_SIZE: @@ -946,14 +945,14 @@ def verify_execution( raise ProofError("InvalidProof: stacked_n_vars exceeds WHIR domain bound") cfg = WHIR_CONFIGS[(log_inv_rate, stacked_n_vars)] nood = cfg["commitment_ood_samples"] - parsed_commitment = WhirCommitment.read(state, stacked_n_vars, nood) + parsed_commitment = WhirCommitment.read(fiat_shamir, stacked_n_vars, nood) - logup_gamma = state.sample_ef() # the quotient denominator - state.duplex() - logup_beta = state.sample_many_ef(log2_ceil(N_INSTRUCTION_COLUMNS + 2)) # the bus-tuple hashing seeds + logup_gamma = fiat_shamir.sample_ef() # the quotient denominator + fiat_shamir.duplex() + logup_beta = fiat_shamir.sample_many_ef(log2_ceil(N_INSTRUCTION_COLUMNS + 2)) # the bus-tuple hashing seeds logup_beta_eq = eval_eq(logup_beta) logup = verify_generic_logup( - state, + fiat_shamir, logup_gamma, logup_beta, logup_beta_eq, @@ -964,7 +963,7 @@ def verify_execution( ) gkr_point = logup["gkr_point"] - air_alpha = state.sample_ef() + air_alpha = fiat_shamir.sample_ef() alpha_powers = ef_powers(air_alpha, sum(t.n_constraints for t in TABLES)) initial_sum, offset = ZERO, 0 @@ -974,13 +973,13 @@ def verify_execution( ) initial_sum += alpha_powers[offset + 1] * (logup_gamma - logup["precompile_dens"][table.name]) offset += table.n_constraints - sc_point, sc_value = verify_sumcheck(state, initial_sum, n_max, max(t.air_degree + 1 for t in TABLES)) + sc_point, sc_value = verify_sumcheck(fiat_shamir, initial_sum, n_max, max(t.air_degree + 1 for t in TABLES)) committed = {t.name: [(gkr_point[-log_heights[t.name] :], logup["columns_evals"][t.name], {})] for t in TABLES} my_air_final, offset = ZERO, 0 for table in TABLES: log_n_rows = log_heights[table.name] - col_evals = state.next_extension_scalars_vec(table.n_columns + table.n_shift) + col_evals = fiat_shamir.next_extension_scalars_vec(table.n_columns + table.n_shift) alphas = alpha_powers[offset : offset + table.n_constraints] offset += table.n_constraints constraint_eval = table.eval_air(col_evals, alphas, logup_beta_eq) @@ -995,7 +994,7 @@ def verify_execution( if my_air_final != sc_value: raise ProofError("AIR sumcheck: claimed value mismatch") - pm_point = state.sample_many_ef(log2_strict(PUBLIC_INPUT_SIZE)) + pm_point = fiat_shamir.sample_many_ef(log2_strict(PUBLIC_INPUT_SIZE)) pm_eval = eval_multilinear_by_evals(public_input, pm_point) bytecode_acc_idx = (2 << log_memory) >> bytecode_log_size @@ -1020,14 +1019,14 @@ def verify_execution( committed, ending_pc, ) - verify_whir(state, cfg, parsed_commitment, global_statements) + verify_whir(fiat_shamir, cfg, parsed_commitment, global_statements) - if state.offset != len(state.transcript): + if fiat_shamir.offset != len(fiat_shamir.transcript): raise ProofError( - f"InvalidProof: transcript not fully consumed ({state.offset}/{len(state.transcript)} scalars read)" + f"InvalidProof: transcript not fully consumed ({fiat_shamir.offset}/{len(fiat_shamir.transcript)} scalars read)" ) - if state.openings: - raise ProofError(f"InvalidProof: {len(state.openings)} Merkle openings unused") + if fiat_shamir.openings: + raise ProofError(f"InvalidProof: {len(fiat_shamir.openings)} Merkle openings unused") def main() -> int: @@ -1057,7 +1056,7 @@ def main() -> int: ) try: - verify_execution(public_input, proof, bytecode_multilinear) + verify_execution(bytecode_multilinear, public_input, proof) except ProofError as e: print(f"FAIL: {e}") return 1 From dcb2e427548a69d70bd48d7661d34f772af357d1 Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Wed, 3 Jun 2026 19:36:57 +0400 Subject: [PATCH 26/31] use assert instead of custom errors --- .../lean_prover/python-verifier/verifier.py | 100 +++++------------- 1 file changed, 27 insertions(+), 73 deletions(-) diff --git a/crates/lean_prover/python-verifier/verifier.py b/crates/lean_prover/python-verifier/verifier.py index f2073225..08242f8d 100644 --- a/crates/lean_prover/python-verifier/verifier.py +++ b/crates/lean_prover/python-verifier/verifier.py @@ -53,10 +53,6 @@ STARTING_PC = 0 # every program starts at PC = 0, and ends at PC = len(bytecode) - 1 -class ProofError(Exception): - pass - - class BusDirection(IntEnum): PUSH = 1 PULL = -1 @@ -147,7 +143,7 @@ def duplex(self) -> None: self.observe([Fp(0)] * SPONGE_RATE) def _sample_rate(self) -> list[Fp]: - assert self.rate_fresh, "stale rate — insert duplex() before sampling" + assert self.rate_fresh, "stale rate: insert duplex() before sampling" # unreachable self.rate_fresh = False return self.state[SPONGE_CAPACITY:] @@ -193,12 +189,10 @@ def __init__(self, proof: Proof, initial_capacity: Sequence[Fp]) -> None: def _read_padded(self, n: int) -> list[Fp]: n_pad = next_multiple_of(n, SPONGE_RATE) - if self.offset + n_pad > len(self.transcript): - raise ProofError("ExceededTranscript") + assert self.offset + n_pad <= len(self.transcript), "Exceeded Transcript" chunk = self.transcript[self.offset : self.offset + n_pad] self.offset += n_pad - if any(int(chunk[i].value) for i in range(n, n_pad)): - raise ProofError("InvalidTranscript: non-zero padding") + assert all(int(chunk[i].value) == 0 for i in range(n, n_pad)), "InvalidTranscript: non-zero padding" self.observe_many(chunk) return chunk @@ -216,16 +210,14 @@ def next_extension_scalar(self) -> EF: return self.next_extension_scalars_vec(1)[0] def next_merkle_opening(self) -> MerkleOpening: - if not self.openings: - raise ProofError("ExceededTranscript: no more Merkle openings") + assert self.openings, "Exceeded Transcript: no more Merkle openings" return self.openings.pop() def check_pow_grinding(self, bits: int) -> None: if bits == 0: return self._read_padded(SPONGE_RATE) - if int(self.state[SPONGE_CAPACITY].value) & ((1 << bits) - 1) != 0: - raise ProofError("InvalidGrindingWitness") + assert int(self.state[SPONGE_CAPACITY].value) & ((1 << bits) - 1) == 0, "Invalid Grinding Witness" def merkle_verify_path( @@ -235,15 +227,13 @@ def merkle_verify_path( opened_values: Sequence[Fp], opening_proof: Sequence[list[Fp]], ) -> None: - if len(opening_proof) != log_height: - raise ProofError("Merkle verification failed: opening proof has wrong length") + assert len(opening_proof) == log_height, "Merkle verification failed: opening proof has wrong length" chunks = [list(opened_values[i : i + SPONGE_RATE]) for i in range(0, len(opened_values), SPONGE_RATE)] current = sponge_hash([x for c in reversed(chunks) for x in c]) for sibling in opening_proof: current = poseidon16_compress(current, sibling) if index & 1 == 0 else poseidon16_compress(sibling, current) index >>= 1 - if root != current: - raise ProofError("Merkle verification failed: root mismatch") + assert root == current, "Merkle verification failed: root mismatch" def expand_from_univariate(x: EF, num_variables: int) -> list[EF]: @@ -373,8 +363,7 @@ def verify_sumcheck( for _ in range(n_rounds): coeffs = fiat_shamir.next_extension_scalars_vec(degree + 1) s = coeffs[0] + sum(coeffs) # s = h(0) + h(1) - if s != target: - raise ProofError("Sumcheck identity failed: h(0) + h(1) != target") + assert s == target, "Sumcheck identity failed: h(0) + h(1) != target" fiat_shamir.check_pow_grinding(pow_bits) challenge = fiat_shamir.sample_ef() point.append(challenge) @@ -440,8 +429,7 @@ def verify_whir( commitment = new_commitment for smt in final_stir_constraints: univ_eval = eval_univariate_polynomial(final_coeffs, smt.point[0]) - if any(univ_eval != v[1] for v in smt.values): - raise ProofError("Final STIR constraint mismatch") + assert all(univ_eval == v[1] for v in smt.values), "Final STIR constraint mismatch" final_sc_point, final_sc_value = verify_sumcheck(fiat_shamir, target, current_vars, 2) folding_challenges += final_sc_point @@ -460,8 +448,7 @@ def verify_whir( eval_weights += eval_prefix * eval_suffix * gamma_power gamma_power *= gamma final_value = eval_multilinear_by_coeffs(final_coeffs, list(reversed(final_sc_point))) - if final_sc_value != eval_weights * final_value: - raise ProofError("WHIR final sumcheck check failed") + assert final_sc_value == eval_weights * final_value, "WHIR final sumcheck check failed" def stacked_pcs_global_statements( @@ -516,8 +503,7 @@ def verify_gkr_quotient(fiat_shamir: FiatShamir, n_vars: int) -> tuple[EF, list[ sc_point, sc_value = verify_sumcheck(fiat_shamir, claim_num + alpha * claim_den, layer_n_vars, 3) sc_point = list(reversed(sc_point)) nl, nr, dl, dr = fiat_shamir.next_extension_scalars_vec(4) - if sc_value != eq_poly(point, sc_point) * (alpha * dl * dr + nl * dr + nr * dl): - raise ProofError("GKR step: postponed value mismatch") + assert sc_value == eq_poly(point, sc_point) * (alpha * dl * dr + nl * dr + nr * dl), "GKR step: postponed value mismatch" beta = fiat_shamir.sample_ef() one_minus = ONE - beta claim_num = one_minus * nl + beta * nr @@ -561,8 +547,7 @@ def verify_generic_logup( logup_n_vars = log2_ceil(total_active_len) quotient, gkr_point, claim_num, claim_den = verify_gkr_quotient(fiat_shamir, logup_n_vars) - if quotient != ZERO: - raise ProofError("imbalanced logup bus") + assert quotient == ZERO, "imbalanced logup bus" def pref_at(offset: int, log_height: int) -> EF: n_missing = logup_n_vars - log_height @@ -641,10 +626,8 @@ def request_column_evals_dedup(cols: Sequence[int]) -> list[EF]: offset += 1 << table_heights[table.name] den += mle_of_zeros_then_ones(final_offset, gkr_point) - if num != claim_num: - raise ProofError("logup: numerators value mismatch") - if den != claim_den: - raise ProofError("logup: denominators value mismatch") + assert num == claim_num, "logup: numerators value mismatch" + assert den == claim_den, "logup: denominators value mismatch" return { "memory_eval": memory_eval, "memory_acc_eval": memory_acc_eval, @@ -911,25 +894,17 @@ def verify_execution( bytecode_log_size = log2_strict(len(bytecode_multilinear)) - log2_ceil(N_INSTRUCTION_COLUMNS) ending_pc = (1 << bytecode_log_size) - 1 bytecode_hash = sponge_hash([Fp(v) for v in bytecode_multilinear]) - if len(public_input) != PUBLIC_INPUT_SIZE: - raise ProofError("InvalidProof: public_input length mismatch") + assert len(public_input) == PUBLIC_INPUT_SIZE, "InvalidProof: public_input length mismatch" fiat_shamir = FiatShamir(proof, poseidon16_compress(bytecode_hash, SNARK_DOMAIN_SEP)) # domain separator across bytecodes fiat_shamir.observe_scalars(public_input) log_inv_rate, log_memory, *table_log_n_rows = [int(x.value) for x in fiat_shamir.next_base_scalars(2 + len(TABLES))] - if not MIN_WHIR_LOG_INV_RATE <= log_inv_rate <= MAX_WHIR_LOG_INV_RATE: - raise ProofError("InvalidRate") - if not MIN_LOG_MEMORY_SIZE <= log_memory <= MAX_LOG_MEMORY_SIZE: - raise ProofError("InvalidProof: log_memory out of range") - if not MIN_BYTECODE_LOG_SIZE <= bytecode_log_size <= MAX_BYTECODE_LOG_SIZE: - raise ProofError("InvalidProof: bytecode log_size out of range") - if log_memory < max(max(table_log_n_rows, default=0), bytecode_log_size): - raise ProofError("InvalidProof: memory smaller than tables/bytecode") + assert MIN_WHIR_LOG_INV_RATE <= log_inv_rate <= MAX_WHIR_LOG_INV_RATE, "InvalidRate" + assert MIN_LOG_MEMORY_SIZE <= log_memory <= MAX_LOG_MEMORY_SIZE, "InvalidProof: log_memory out of range" + assert MIN_BYTECODE_LOG_SIZE <= bytecode_log_size <= MAX_BYTECODE_LOG_SIZE, "InvalidProof: bytecode log_size out of range" + assert log_memory >= max(max(table_log_n_rows, default=0), bytecode_log_size), "InvalidProof: memory smaller than tables/bytecode" for table, log_height in zip(TABLES, table_log_n_rows): - if not MIN_LOG_N_ROWS_PER_TABLE <= log_height <= table.max_log_height: - raise ProofError( - f"InvalidProof: table {table.name} log_n_rows={log_height} not in [{MIN_LOG_N_ROWS_PER_TABLE}, {table.max_log_height}]" - ) + assert MIN_LOG_N_ROWS_PER_TABLE <= log_height <= table.max_log_height, f"InvalidProof: table {table.name} log_n_rows={log_height} not in [{MIN_LOG_N_ROWS_PER_TABLE}, {table.max_log_height}]" log_heights = {t.name: h for t, h in zip(TABLES, table_log_n_rows)} n_max = sort_tables_by_height(TABLES, log_heights)[0][1] @@ -941,8 +916,7 @@ def verify_execution( ) stacked_n_vars = log2_ceil(total_stacked) - if stacked_n_vars > TWO_ADICITY + WHIR_INITIAL_FOLDING_FACTOR - log_inv_rate: - raise ProofError("InvalidProof: stacked_n_vars exceeds WHIR domain bound") + assert stacked_n_vars <= TWO_ADICITY + WHIR_INITIAL_FOLDING_FACTOR - log_inv_rate, "InvalidProof: stacked_n_vars exceeds WHIR domain bound" cfg = WHIR_CONFIGS[(log_inv_rate, stacked_n_vars)] nood = cfg["commitment_ood_samples"] parsed_commitment = WhirCommitment.read(fiat_shamir, stacked_n_vars, nood) @@ -991,8 +965,7 @@ def verify_execution( eq_vals = {i: col_evals[i] for i in range(table.n_columns)} next_vals = {j: col_evals[table.n_columns + j] for j in range(table.n_shift)} committed[table.name].append((natural_pt, eq_vals, next_vals)) - if my_air_final != sc_value: - raise ProofError("AIR sumcheck: claimed value mismatch") + assert my_air_final == sc_value, "AIR sumcheck: claimed value mismatch" pm_point = fiat_shamir.sample_many_ef(log2_strict(PUBLIC_INPUT_SIZE)) pm_eval = eval_multilinear_by_evals(public_input, pm_point) @@ -1021,21 +994,12 @@ def verify_execution( ) verify_whir(fiat_shamir, cfg, parsed_commitment, global_statements) - if fiat_shamir.offset != len(fiat_shamir.transcript): - raise ProofError( - f"InvalidProof: transcript not fully consumed ({fiat_shamir.offset}/{len(fiat_shamir.transcript)} scalars read)" - ) - if fiat_shamir.openings: - raise ProofError(f"InvalidProof: {len(fiat_shamir.openings)} Merkle openings unused") + assert fiat_shamir.offset == len(fiat_shamir.transcript), f"InvalidProof: transcript not fully consumed ({fiat_shamir.offset}/{len(fiat_shamir.transcript)} scalars read)" + assert not fiat_shamir.openings, f"InvalidProof: {len(fiat_shamir.openings)} Merkle openings unused" - -def main() -> int: +if __name__ == "__main__": vector_path = Path(__file__).resolve().parents[3] / "target" / "zkvm_test_vectors" / "proof.json" - if not vector_path.exists(): - print( - f"Test vector not found at {vector_path}. Please follow the instructions at the beginning of verifier.py file." - ) - return 1 + assert vector_path.exists(), f"Test vector not found at {vector_path}. Please follow the instructions at the beginning of verifier.py file." print(f"Loading {vector_path.name}...") raw = json.loads(vector_path.read_text()) @@ -1055,15 +1019,5 @@ def main() -> int: ], ) - try: - verify_execution(bytecode_multilinear, public_input, proof) - except ProofError as e: - print(f"FAIL: {e}") - return 1 - + verify_execution(bytecode_multilinear, public_input, proof) print("Proof successfully verified") - return 0 - - -if __name__ == "__main__": - sys.exit(main()) From 0f89768e50225088f2a1480f1b41d1827276bea4 Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Wed, 3 Jun 2026 19:42:23 +0400 Subject: [PATCH 27/31] --line-length 150 --- .../lean_prover/python-verifier/primitives.py | 4 +- .../lean_prover/python-verifier/verifier.py | 86 +++++++------------ 2 files changed, 30 insertions(+), 60 deletions(-) diff --git a/crates/lean_prover/python-verifier/primitives.py b/crates/lean_prover/python-verifier/primitives.py index 66fc15da..d1c52a5a 100644 --- a/crates/lean_prover/python-verifier/primitives.py +++ b/crates/lean_prover/python-verifier/primitives.py @@ -416,9 +416,7 @@ def _compute_sparse_constants() -> dict: # External full-round constants: first / last POSEIDON_HALF_FULL_ROUNDS slices of round_constants. POSEIDON_AIR_INITIAL_CONSTANTS: list[list[Fp]] = [[Fp(v) for v in _RCS[i * _W : (i + 1) * _W]] for i in range(_HF)] _TAIL = (_HF + POSEIDON_PARTIAL_ROUNDS) * _W -POSEIDON_AIR_FINAL_CONSTANTS: list[list[Fp]] = [ - [Fp(v) for v in _RCS[_TAIL + i * _W : _TAIL + (i + 1) * _W]] for i in range(_HF) -] +POSEIDON_AIR_FINAL_CONSTANTS: list[list[Fp]] = [[Fp(v) for v in _RCS[_TAIL + i * _W : _TAIL + (i + 1) * _W]] for i in range(_HF)] # Sparse partial-round constants (Fp-wrapped). POSEIDON_AIR_SPARSE_M_I: list[list[Fp]] = [[Fp(v) for v in row] for row in _SPARSE["sparse_m_i"]] diff --git a/crates/lean_prover/python-verifier/verifier.py b/crates/lean_prover/python-verifier/verifier.py index 08242f8d..79194735 100644 --- a/crates/lean_prover/python-verifier/verifier.py +++ b/crates/lean_prover/python-verifier/verifier.py @@ -4,7 +4,7 @@ Run: python3 crates/lean_prover/python-verifier/verifier.py Format: - ruff format --line-length 120 crates/lean_prover/python-verifier + ruff format --line-length 150 crates/lean_prover/python-verifier """ from __future__ import annotations @@ -30,9 +30,7 @@ "log_inv_rate": c[0], "num_variables": c[1], "commitment_ood_samples": c[2], - "rounds": [ - {"num_queries": r[0], "ood_samples": r[1], "query_pow_bits": r[2], "folding_pow_bits": r[3]} for r in c[6] - ] + "rounds": [{"num_queries": r[0], "ood_samples": r[1], "query_pow_bits": r[2], "folding_pow_bits": r[3]} for r in c[6]] + [{"num_queries": c[4], "query_pow_bits": c[5], "folding_pow_bits": c[3]}], } for c in _WHIR_CONFIGS @@ -93,15 +91,11 @@ def col(self, name: str) -> int: return self.columns.index(name) def eval_air(self, col_evals: Sequence[EF], alpha_powers: Sequence[EF], logup_beta_eq: list[EF]) -> EF: - constraint_evaluator = ConstraintEvaluator( - col_evals[: self.n_columns], col_evals[self.n_columns :], alpha_powers, self.columns - ) + constraint_evaluator = ConstraintEvaluator(col_evals[: self.n_columns], col_evals[self.n_columns :], alpha_powers, self.columns) self.air_constraints_fn(constraint_evaluator, logup_beta_eq) return constraint_evaluator.accumulator - def boundary_statements( - self, stacked_n_vars: int, offset: int, log_n_rows: int, ending_pc: int - ) -> list["SparseStatements"]: + def boundary_statements(self, stacked_n_vars: int, offset: int, log_n_rows: int, ending_pc: int) -> list["SparseStatements"]: if self.name != "execution": return [] pc_col_offset = offset + (self.col("pc") << log_n_rows) @@ -143,7 +137,7 @@ def duplex(self) -> None: self.observe([Fp(0)] * SPONGE_RATE) def _sample_rate(self) -> list[Fp]: - assert self.rate_fresh, "stale rate: insert duplex() before sampling" # unreachable + assert self.rate_fresh, "stale rate: insert duplex() before sampling" # unreachable self.rate_fresh = False return self.state[SPONGE_CAPACITY:] @@ -356,9 +350,7 @@ def oods_constraints(self) -> list[SparseStatements]: ] -def verify_sumcheck( - fiat_shamir: FiatShamir, target: EF, n_rounds: int, degree: int, pow_bits: int = 0 -) -> tuple[list[EF], EF]: +def verify_sumcheck(fiat_shamir: FiatShamir, target: EF, n_rounds: int, degree: int, pow_bits: int = 0) -> tuple[list[EF], EF]: point: list[EF] = [] for _ in range(n_rounds): coeffs = fiat_shamir.next_extension_scalars_vec(degree + 1) @@ -503,7 +495,7 @@ def verify_gkr_quotient(fiat_shamir: FiatShamir, n_vars: int) -> tuple[EF, list[ sc_point, sc_value = verify_sumcheck(fiat_shamir, claim_num + alpha * claim_den, layer_n_vars, 3) sc_point = list(reversed(sc_point)) nl, nr, dl, dr = fiat_shamir.next_extension_scalars_vec(4) - assert sc_value == eq_poly(point, sc_point) * (alpha * dl * dr + nl * dr + nr * dl), "GKR step: postponed value mismatch" + assert sc_value == eq_poly(point, sc_point) * (alpha * dl * dr + nl * dr + nr * dl), "GKR step: postponed value mismatch" # fmt: skip beta = fiat_shamir.sample_ef() one_minus = ONE - beta claim_num = one_minus * nl + beta * nr @@ -539,11 +531,7 @@ def verify_generic_logup( tables_sorted = sort_tables_by_height(tables, table_heights) tallest_h = tables_sorted[0][1] - total_active_len = ( - (1 << log_memory) - + max(1 << log_bytecode, 1 << tallest_h) - + sum(t.n_bus_interactions << h for t, h in tables_sorted) - ) + total_active_len = (1 << log_memory) + max(1 << log_bytecode, 1 << tallest_h) + sum(t.n_bus_interactions << h for t, h in tables_sorted) logup_n_vars = log2_ceil(total_active_len) quotient, gkr_point, claim_num, claim_den = verify_gkr_quotient(fiat_shamir, logup_n_vars) @@ -561,9 +549,7 @@ def pref_at(offset: int, log_height: int) -> EF: memory_acc_eval = fiat_shamir.next_extension_scalar() memory_eval = fiat_shamir.next_extension_scalar() num -= pref * memory_acc_eval - den += pref * ( - gamma - finger_print(Fp(LOGUP_MEMORY_DOMAINSEP), [mle_of_01234567_etc(mem_pt), memory_eval], beta_eq) - ) + den += pref * (gamma - finger_print(Fp(LOGUP_MEMORY_DOMAINSEP), [mle_of_01234567_etc(mem_pt), memory_eval], beta_eq)) offset = 1 << log_memory # Bytecode section (padded to the tallest table) @@ -575,14 +561,10 @@ def pref_at(offset: int, log_height: int) -> EF: bytecode_eval = eval_multilinear_by_evals([Fp(v) for v in bytecode_multilinear], bytecode_point + beta[-log_instr:]) correction = math.prod(ONE - a for a in beta[: len(beta) - log_instr]) fingerprint_bytecode = ( - bytecode_eval * correction - + mle_of_01234567_etc(bytecode_point) * beta_eq[N_INSTRUCTION_COLUMNS] - + beta_eq[-1] * Fp(LOGUP_BYTECODE_DOMAINSEP) + bytecode_eval * correction + mle_of_01234567_etc(bytecode_point) * beta_eq[N_INSTRUCTION_COLUMNS] + beta_eq[-1] * Fp(LOGUP_BYTECODE_DOMAINSEP) ) num -= pref * value_bytecode_acc - den += pref * (gamma - fingerprint_bytecode) + pref_padded * mle_of_zeros_then_ones( - 1 << log_bytecode, gkr_point[-log_bytecode_padded:] - ) + den += pref * (gamma - fingerprint_bytecode) + pref_padded * mle_of_zeros_then_ones(1 << log_bytecode, gkr_point[-log_bytecode_padded:]) offset += 1 << log_bytecode_padded # Per-table section @@ -642,9 +624,7 @@ def arr(self, prefix: str, n: int) -> list: class ConstraintEvaluator: - def __init__( - self, flat: Sequence[EF], shift: Sequence[EF], alpha_powers: Sequence[EF], columns: Sequence[str] - ) -> None: + def __init__(self, flat: Sequence[EF], shift: Sequence[EF], alpha_powers: Sequence[EF], columns: Sequence[str]) -> None: self.flat = flat self.shift = shift self.alpha_powers = alpha_powers @@ -734,9 +714,7 @@ def eval_air_extension_table(evaluator: ConstraintEvaluator, logup_beta_eq: list evaluator.assert_zero((acc[k] - (v_a_v_b[k] + acc_tail[k])) * flag_dot_product) # eq: acc ← (2·v_a·v_b − v_a − v_b + 1) · (acc_tail or 1 at group end). e_eq = [2 * v_a_v_b[k] - v_a_tilde[k] - v_b[k] + (ONE if k == 0 else ZERO) for k in range(EF.DIMENSION)] - acc_tail_or_one = [acc_next[0] * not_start_next + flag_start_next] + [ - acc_next[k] * not_start_next for k in range(1, EF.DIMENSION) - ] + acc_tail_or_one = [acc_next[0] * not_start_next + flag_start_next] + [acc_next[k] * not_start_next for k in range(1, EF.DIMENSION)] eq_result = quintic_mul(e_eq, acc_tail_or_one, ZERO) for k in range(EF.DIMENSION): evaluator.assert_zero((acc[k] - eq_result[k]) * flag_eq) @@ -887,36 +865,34 @@ def eval_air_poseidon16_table(evaluator: ConstraintEvaluator, logup_beta_eq: lis def verify_execution( - bytecode_multilinear: list[int], # trusted-source (and thus contains only valid instructions) + bytecode_multilinear: list[int], # trusted-source (and thus contains only valid instructions) public_input: Sequence[Fp], proof: Proof, ) -> None: bytecode_log_size = log2_strict(len(bytecode_multilinear)) - log2_ceil(N_INSTRUCTION_COLUMNS) ending_pc = (1 << bytecode_log_size) - 1 bytecode_hash = sponge_hash([Fp(v) for v in bytecode_multilinear]) - assert len(public_input) == PUBLIC_INPUT_SIZE, "InvalidProof: public_input length mismatch" + assert len(public_input) == PUBLIC_INPUT_SIZE, "public_input length mismatch" fiat_shamir = FiatShamir(proof, poseidon16_compress(bytecode_hash, SNARK_DOMAIN_SEP)) # domain separator across bytecodes fiat_shamir.observe_scalars(public_input) log_inv_rate, log_memory, *table_log_n_rows = [int(x.value) for x in fiat_shamir.next_base_scalars(2 + len(TABLES))] assert MIN_WHIR_LOG_INV_RATE <= log_inv_rate <= MAX_WHIR_LOG_INV_RATE, "InvalidRate" - assert MIN_LOG_MEMORY_SIZE <= log_memory <= MAX_LOG_MEMORY_SIZE, "InvalidProof: log_memory out of range" - assert MIN_BYTECODE_LOG_SIZE <= bytecode_log_size <= MAX_BYTECODE_LOG_SIZE, "InvalidProof: bytecode log_size out of range" - assert log_memory >= max(max(table_log_n_rows, default=0), bytecode_log_size), "InvalidProof: memory smaller than tables/bytecode" + assert MIN_LOG_MEMORY_SIZE <= log_memory <= MAX_LOG_MEMORY_SIZE, "log_memory out of range" + assert MIN_BYTECODE_LOG_SIZE <= bytecode_log_size <= MAX_BYTECODE_LOG_SIZE, "bytecode log_size out of range" + assert log_memory >= max(max(table_log_n_rows, default=0), bytecode_log_size), "memory smaller than tables/bytecode" for table, log_height in zip(TABLES, table_log_n_rows): - assert MIN_LOG_N_ROWS_PER_TABLE <= log_height <= table.max_log_height, f"InvalidProof: table {table.name} log_n_rows={log_height} not in [{MIN_LOG_N_ROWS_PER_TABLE}, {table.max_log_height}]" + assert MIN_LOG_N_ROWS_PER_TABLE <= log_height <= table.max_log_height, ( + f"table {table.name} log_n_rows={log_height} not in [{MIN_LOG_N_ROWS_PER_TABLE}, {table.max_log_height}]" + ) log_heights = {t.name: h for t, h in zip(TABLES, table_log_n_rows)} n_max = sort_tables_by_height(TABLES, log_heights)[0][1] - total_stacked = ( - (2 << log_memory) - + (1 << max(bytecode_log_size, n_max)) - + sum(t.n_columns << log_heights[t.name] for t in TABLES) - ) + total_stacked = (2 << log_memory) + (1 << max(bytecode_log_size, n_max)) + sum(t.n_columns << log_heights[t.name] for t in TABLES) stacked_n_vars = log2_ceil(total_stacked) - assert stacked_n_vars <= TWO_ADICITY + WHIR_INITIAL_FOLDING_FACTOR - log_inv_rate, "InvalidProof: stacked_n_vars exceeds WHIR domain bound" + assert stacked_n_vars <= TWO_ADICITY + WHIR_INITIAL_FOLDING_FACTOR - log_inv_rate, "tacked_n_vars exceeds WHIR domain bound" cfg = WHIR_CONFIGS[(log_inv_rate, stacked_n_vars)] nood = cfg["commitment_ood_samples"] parsed_commitment = WhirCommitment.read(fiat_shamir, stacked_n_vars, nood) @@ -942,9 +918,7 @@ def verify_execution( initial_sum, offset = ZERO, 0 for table in TABLES: - initial_sum += alpha_powers[offset] * ( - logup["precompile_nums"][table.name] * table.precompile_bus_interaction_sign - ) + initial_sum += alpha_powers[offset] * (logup["precompile_nums"][table.name] * table.precompile_bus_interaction_sign) initial_sum += alpha_powers[offset + 1] * (logup_gamma - logup["precompile_dens"][table.name]) offset += table.n_constraints sc_point, sc_value = verify_sumcheck(fiat_shamir, initial_sum, n_max, max(t.air_degree + 1 for t in TABLES)) @@ -978,9 +952,7 @@ def verify_execution( [(0, logup["memory_eval"]), (1, logup["memory_acc_eval"])], ), SparseStatements(stacked_n_vars, pm_point, [(0, pm_eval)]), - SparseStatements( - stacked_n_vars, gkr_point[-bytecode_log_size:], [(bytecode_acc_idx, logup["value_bytecode_acc"])] - ), + SparseStatements(stacked_n_vars, gkr_point[-bytecode_log_size:], [(bytecode_acc_idx, logup["value_bytecode_acc"])]), ] global_statements = stacked_pcs_global_statements( stacked_n_vars, @@ -994,8 +966,9 @@ def verify_execution( ) verify_whir(fiat_shamir, cfg, parsed_commitment, global_statements) - assert fiat_shamir.offset == len(fiat_shamir.transcript), f"InvalidProof: transcript not fully consumed ({fiat_shamir.offset}/{len(fiat_shamir.transcript)} scalars read)" - assert not fiat_shamir.openings, f"InvalidProof: {len(fiat_shamir.openings)} Merkle openings unused" + assert fiat_shamir.offset == len(fiat_shamir.transcript), f"transcript not fully consumed ({fiat_shamir.offset}/{len(fiat_shamir.transcript)})" + assert not fiat_shamir.openings, f"{len(fiat_shamir.openings)} Merkle openings unused" + if __name__ == "__main__": vector_path = Path(__file__).resolve().parents[3] / "target" / "zkvm_test_vectors" / "proof.json" @@ -1014,8 +987,7 @@ def verify_execution( proof = Proof( transcript=fp_list(raw["proof"]["transcript"]), merkle_openings=[ - MerkleOpening(leaf_data=fp_list(o["leaf_data"]), path=[fp_list(d) for d in o["path"]]) - for o in raw["proof"]["merkle_openings"] + MerkleOpening(leaf_data=fp_list(o["leaf_data"]), path=[fp_list(d) for d in o["path"]]) for o in raw["proof"]["merkle_openings"] ], ) From 994bd1ab81c34a1af9c485f5384141d1667f251f Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Wed, 3 Jun 2026 19:53:23 +0400 Subject: [PATCH 28/31] wip --- .../lean_prover/python-verifier/verifier.py | 32 +++++++++---------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/crates/lean_prover/python-verifier/verifier.py b/crates/lean_prover/python-verifier/verifier.py index 79194735..6d8e0c17 100644 --- a/crates/lean_prover/python-verifier/verifier.py +++ b/crates/lean_prover/python-verifier/verifier.py @@ -37,7 +37,7 @@ } MIN_LOG_MEMORY_SIZE, MAX_LOG_MEMORY_SIZE = 16, 26 -MIN_LOG_N_ROWS_PER_TABLE, MIN_BYTECODE_LOG_SIZE, MAX_BYTECODE_LOG_SIZE = 8, 8, 22 +MIN_LOG_HEIGHT_PER_TABLE, MIN_BYTECODE_LOG_SIZE, MAX_BYTECODE_LOG_SIZE = 8, 8, 22 N_VARS_TO_SEND_GKR_COEFFS = 5 N_RUNTIME_COLUMNS, N_INSTRUCTION_COLUMNS = 8, 12 @@ -95,13 +95,13 @@ def eval_air(self, col_evals: Sequence[EF], alpha_powers: Sequence[EF], logup_be self.air_constraints_fn(constraint_evaluator, logup_beta_eq) return constraint_evaluator.accumulator - def boundary_statements(self, stacked_n_vars: int, offset: int, log_n_rows: int, ending_pc: int) -> list["SparseStatements"]: + def boundary_statements(self, stacked_n_vars: int, offset: int, log_height: int, ending_pc: int) -> list["SparseStatements"]: if self.name != "execution": return [] - pc_col_offset = offset + (self.col("pc") << log_n_rows) + pc_col_offset = offset + (self.col("pc") << log_height) return [ SparseStatements(stacked_n_vars, [], [(pc_col_offset + idx, EF(pc))]) - for idx, pc in [(0, STARTING_PC), ((1 << log_n_rows) - 1, ending_pc)] + for idx, pc in [(0, STARTING_PC), ((1 << log_height) - 1, ending_pc)] ] @@ -569,9 +569,9 @@ def pref_at(offset: int, log_height: int) -> EF: # Per-table section table_offsets: dict[str, int] = {} - for table, log_n_rows in tables_sorted: + for table, log_height in tables_sorted: table_offsets[table.name] = offset - offset += table.n_bus_interactions << log_n_rows + offset += table.n_bus_interactions << log_height final_offset = offset precompile_nums: dict[str, EF] = {} @@ -876,17 +876,17 @@ def verify_execution( fiat_shamir = FiatShamir(proof, poseidon16_compress(bytecode_hash, SNARK_DOMAIN_SEP)) # domain separator across bytecodes fiat_shamir.observe_scalars(public_input) - log_inv_rate, log_memory, *table_log_n_rows = [int(x.value) for x in fiat_shamir.next_base_scalars(2 + len(TABLES))] + log_inv_rate, log_memory, *table_log_heights = [int(x.value) for x in fiat_shamir.next_base_scalars(2 + len(TABLES))] assert MIN_WHIR_LOG_INV_RATE <= log_inv_rate <= MAX_WHIR_LOG_INV_RATE, "InvalidRate" assert MIN_LOG_MEMORY_SIZE <= log_memory <= MAX_LOG_MEMORY_SIZE, "log_memory out of range" assert MIN_BYTECODE_LOG_SIZE <= bytecode_log_size <= MAX_BYTECODE_LOG_SIZE, "bytecode log_size out of range" - assert log_memory >= max(max(table_log_n_rows, default=0), bytecode_log_size), "memory smaller than tables/bytecode" - for table, log_height in zip(TABLES, table_log_n_rows): - assert MIN_LOG_N_ROWS_PER_TABLE <= log_height <= table.max_log_height, ( - f"table {table.name} log_n_rows={log_height} not in [{MIN_LOG_N_ROWS_PER_TABLE}, {table.max_log_height}]" + assert log_memory >= max(max(table_log_heights, default=0), bytecode_log_size), "memory smaller than tables/bytecode" + for table, log_height in zip(TABLES, table_log_heights): + assert MIN_LOG_HEIGHT_PER_TABLE <= log_height <= table.max_log_height, ( + f"table {table.name} log_heights={log_height} not in [{MIN_LOG_HEIGHT_PER_TABLE}, {table.max_log_height}]" ) - log_heights = {t.name: h for t, h in zip(TABLES, table_log_n_rows)} + log_heights = {t.name: h for t, h in zip(TABLES, table_log_heights)} n_max = sort_tables_by_height(TABLES, log_heights)[0][1] total_stacked = (2 << log_memory) + (1 << max(bytecode_log_size, n_max)) + sum(t.n_columns << log_heights[t.name] for t in TABLES) @@ -926,15 +926,15 @@ def verify_execution( committed = {t.name: [(gkr_point[-log_heights[t.name] :], logup["columns_evals"][t.name], {})] for t in TABLES} my_air_final, offset = ZERO, 0 for table in TABLES: - log_n_rows = log_heights[table.name] + log_height = log_heights[table.name] col_evals = fiat_shamir.next_extension_scalars_vec(table.n_columns + table.n_shift) alphas = alpha_powers[offset : offset + table.n_constraints] offset += table.n_constraints constraint_eval = table.eval_air(col_evals, alphas, logup_beta_eq) - natural_pt = list(reversed(sc_point[-log_n_rows:])) if log_n_rows else [] - k_t = math.prod(sc_point[: n_max - log_n_rows]) - my_air_final += k_t * eq_poly(gkr_point[-log_n_rows:], natural_pt) * constraint_eval + natural_pt = list(reversed(sc_point[-log_height:])) + k_t = math.prod(sc_point[: n_max - log_height]) + my_air_final += k_t * eq_poly(gkr_point[-log_height:], natural_pt) * constraint_eval eq_vals = {i: col_evals[i] for i in range(table.n_columns)} next_vals = {j: col_evals[table.n_columns + j] for j in range(table.n_shift)} From be82b8d1294faf5e5fc0079f671e6b5c3dc70797 Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Wed, 3 Jun 2026 20:09:53 +0400 Subject: [PATCH 29/31] wip --- .../lean_prover/python-verifier/verifier.py | 35 ++++++++----------- 1 file changed, 14 insertions(+), 21 deletions(-) diff --git a/crates/lean_prover/python-verifier/verifier.py b/crates/lean_prover/python-verifier/verifier.py index 6d8e0c17..8ece3da9 100644 --- a/crates/lean_prover/python-verifier/verifier.py +++ b/crates/lean_prover/python-verifier/verifier.py @@ -335,7 +335,7 @@ class WhirCommitment: ood_answers: list[EF] @classmethod - def read(cls, fs: "FiatShamir", num_variables: int, n_ood: int) -> "WhirCommitment": + def parse(cls, fs: "FiatShamir", num_variables: int, n_ood: int) -> "WhirCommitment": return cls( num_variables, fs.next_base_scalars(DIGEST_ELEMS), @@ -397,7 +397,7 @@ def verify_whir( if is_final: final_coeffs = fiat_shamir.next_extension_scalars_vec(1 << current_vars) else: - new_commitment = WhirCommitment.read(fiat_shamir, current_vars, round_params["ood_samples"]) + new_commitment = WhirCommitment.parse(fiat_shamir, current_vars, round_params["ood_samples"]) log_height = log_domain - folding_factor gen = Fp(KB_TWO_ADIC_GENERATORS[log_height]) @@ -611,11 +611,7 @@ def request_column_evals_dedup(cols: Sequence[int]) -> list[EF]: assert num == claim_num, "logup: numerators value mismatch" assert den == claim_den, "logup: denominators value mismatch" - return { - "memory_eval": memory_eval, "memory_acc_eval": memory_acc_eval, - "value_bytecode_acc": value_bytecode_acc, "precompile_nums": precompile_nums, "precompile_dens": precompile_dens, - "gkr_point": gkr_point, "columns_evals": columns_evals, - } # fmt: skip + return memory_eval, memory_acc_eval, value_bytecode_acc, precompile_nums, precompile_dens, gkr_point, columns_evals class Cols(dict): @@ -882,26 +878,24 @@ def verify_execution( assert MIN_BYTECODE_LOG_SIZE <= bytecode_log_size <= MAX_BYTECODE_LOG_SIZE, "bytecode log_size out of range" assert log_memory >= max(max(table_log_heights, default=0), bytecode_log_size), "memory smaller than tables/bytecode" for table, log_height in zip(TABLES, table_log_heights): - assert MIN_LOG_HEIGHT_PER_TABLE <= log_height <= table.max_log_height, ( - f"table {table.name} log_heights={log_height} not in [{MIN_LOG_HEIGHT_PER_TABLE}, {table.max_log_height}]" - ) + assert MIN_LOG_HEIGHT_PER_TABLE <= log_height <= table.max_log_height, f"table {table.name}: invalid height" log_heights = {t.name: h for t, h in zip(TABLES, table_log_heights)} n_max = sort_tables_by_height(TABLES, log_heights)[0][1] - total_stacked = (2 << log_memory) + (1 << max(bytecode_log_size, n_max)) + sum(t.n_columns << log_heights[t.name] for t in TABLES) - + total_stacked = ( + (2 << log_memory) + (1 << max(bytecode_log_size, n_max)) + sum(t.n_columns << log_heights[t.name] for t in TABLES) + ) # memory + memory_acc + bytecode_acc + biggest_table + second_biggest_table + etc + smallest_table stacked_n_vars = log2_ceil(total_stacked) assert stacked_n_vars <= TWO_ADICITY + WHIR_INITIAL_FOLDING_FACTOR - log_inv_rate, "tacked_n_vars exceeds WHIR domain bound" cfg = WHIR_CONFIGS[(log_inv_rate, stacked_n_vars)] - nood = cfg["commitment_ood_samples"] - parsed_commitment = WhirCommitment.read(fiat_shamir, stacked_n_vars, nood) + parsed_commitment = WhirCommitment.parse(fiat_shamir, stacked_n_vars, cfg["commitment_ood_samples"]) logup_gamma = fiat_shamir.sample_ef() # the quotient denominator fiat_shamir.duplex() logup_beta = fiat_shamir.sample_many_ef(log2_ceil(N_INSTRUCTION_COLUMNS + 2)) # the bus-tuple hashing seeds logup_beta_eq = eval_eq(logup_beta) - logup = verify_generic_logup( + memory_eval, memory_acc_eval, value_bytecode_acc, precompile_nums, precompile_dens, gkr_point, columns_evals = verify_generic_logup( fiat_shamir, logup_gamma, logup_beta, @@ -911,19 +905,18 @@ def verify_execution( TABLES, log_heights, ) - gkr_point = logup["gkr_point"] air_alpha = fiat_shamir.sample_ef() alpha_powers = ef_powers(air_alpha, sum(t.n_constraints for t in TABLES)) initial_sum, offset = ZERO, 0 for table in TABLES: - initial_sum += alpha_powers[offset] * (logup["precompile_nums"][table.name] * table.precompile_bus_interaction_sign) - initial_sum += alpha_powers[offset + 1] * (logup_gamma - logup["precompile_dens"][table.name]) + initial_sum += alpha_powers[offset] * (precompile_nums[table.name] * table.precompile_bus_interaction_sign) + initial_sum += alpha_powers[offset + 1] * (logup_gamma - precompile_dens[table.name]) offset += table.n_constraints sc_point, sc_value = verify_sumcheck(fiat_shamir, initial_sum, n_max, max(t.air_degree + 1 for t in TABLES)) - committed = {t.name: [(gkr_point[-log_heights[t.name] :], logup["columns_evals"][t.name], {})] for t in TABLES} + committed = {t.name: [(gkr_point[-log_heights[t.name] :], columns_evals[t.name], {})] for t in TABLES} my_air_final, offset = ZERO, 0 for table in TABLES: log_height = log_heights[table.name] @@ -949,10 +942,10 @@ def verify_execution( SparseStatements( stacked_n_vars, gkr_point[-log_memory:], - [(0, logup["memory_eval"]), (1, logup["memory_acc_eval"])], + [(0, memory_eval), (1, memory_acc_eval)], ), SparseStatements(stacked_n_vars, pm_point, [(0, pm_eval)]), - SparseStatements(stacked_n_vars, gkr_point[-bytecode_log_size:], [(bytecode_acc_idx, logup["value_bytecode_acc"])]), + SparseStatements(stacked_n_vars, gkr_point[-bytecode_log_size:], [(bytecode_acc_idx, value_bytecode_acc)]), ] global_statements = stacked_pcs_global_statements( stacked_n_vars, From e6db3428712c5f3f0b2a23f9320c3dc7ee5ae677 Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Wed, 3 Jun 2026 20:26:49 +0400 Subject: [PATCH 30/31] wip --- .../lean_prover/python-verifier/verifier.py | 135 ++++++++---------- 1 file changed, 57 insertions(+), 78 deletions(-) diff --git a/crates/lean_prover/python-verifier/verifier.py b/crates/lean_prover/python-verifier/verifier.py index 8ece3da9..19a0a026 100644 --- a/crates/lean_prover/python-verifier/verifier.py +++ b/crates/lean_prover/python-verifier/verifier.py @@ -443,41 +443,6 @@ def verify_whir( assert final_sc_value == eval_weights * final_value, "WHIR final sumcheck check failed" -def stacked_pcs_global_statements( - stacked_n_vars: int, - memory_n_vars: int, - bytecode_n_vars: int, - previous_statements: list[SparseStatements], - tables: Sequence[Table], - heights: dict[str, int], - committed_statements: dict[str, list[tuple[list[EF], dict[int, EF], dict[int, EF]]]], - ending_pc: int, -) -> list[SparseStatements]: - tables_sorted = sort_tables_by_height(tables, heights) - table_offsets: dict[str, int] = {} - layout_offset = (2 << memory_n_vars) + (1 << max(bytecode_n_vars, tables_sorted[0][1])) - for table, n_vars in tables_sorted: - table_offsets[table.name] = layout_offset - layout_offset += table.n_columns << n_vars - - res = list(previous_statements) - - def values_at(d: dict[int, EF], col_base: int) -> list[tuple[int, EF]]: - return [(col_base + i, v) for i, v in sorted(d.items())] - - for table in tables: - n_vars = heights[table.name] - offset = table_offsets[table.name] - col_base = offset >> n_vars - res.extend(table.boundary_statements(stacked_n_vars, offset, n_vars, ending_pc)) - for point, eq_values, next_values in committed_statements[table.name]: - if next_values: - res.append(SparseStatements(stacked_n_vars, point, values_at(next_values, col_base), True)) - res.append(SparseStatements(stacked_n_vars, point, values_at(eq_values, col_base))) - - return res - - def verify_gkr_quotient(fiat_shamir: FiatShamir, n_vars: int) -> tuple[EF, list[EF], EF, EF]: assert n_vars > N_VARS_TO_SEND_GKR_COEFFS @@ -515,7 +480,7 @@ def sort_tables_by_height(tables: Sequence[Table], heights: dict[str, int]) -> l return sorted([(t, heights[t.name]) for t in tables], key=lambda x: (-x[1], x[0].name)) -def verify_generic_logup( +def verify_logup( fiat_shamir: FiatShamir, gamma: EF, # quotient denominator challenge beta: list[EF], # bus-tuple hashing seed @@ -726,13 +691,6 @@ def eval_air_extension_table(evaluator: ConstraintEvaluator, logup_beta_eq: list evaluator.assert_zero(flag_start_next * (len_col - ONE)) -def do_2_full_round(state: list[EF], rc1: list[Fp], rc2: list[Fp]) -> list[EF]: - for rc in (rc1, rc2): - sbox = [(s + c).cube() for s, c in zip(state, rc)] - state = [dot_product(sbox, row) for row in POSEIDON_AIR_MDS_DENSE] - return state - - def eval_air_poseidon16_table(evaluator: ConstraintEvaluator, logup_beta_eq: list[EF]) -> None: multiplicity, nu_b, nu_c , flag_out4, flag_out8, flag_left, offset_left, addr_left_lo, addr_left_hi, flag_permute = (evaluator.flat[k] for k in POSEIDON_COLUMNS[:10]) # fmt: skip inputs = evaluator.flat.arr("input", POSEIDON_WIDTH) @@ -753,6 +711,13 @@ def eval_air_poseidon16_table(evaluator: ConstraintEvaluator, logup_beta_eq: lis evaluator.assert_zero(flag_left * (offset_left - addr_left_lo)) evaluator.assert_zero(not_flag_left * (nu_a - addr_left_lo)) state = list(inputs) + + def do_2_full_round(state: list[EF], rc1: list[Fp], rc2: list[Fp]) -> list[EF]: + for rc in (rc1, rc2): + sbox = [(s + c).cube() for s, c in zip(state, rc)] + state = [dot_product(sbox, row) for row in POSEIDON_AIR_MDS_DENSE] + return state + # 2-by-2 initial full rounds for r in range(POSEIDON_QUARTER_FULL_ROUNDS): state = do_2_full_round(state, POSEIDON_AIR_INITIAL_CONSTANTS[2 * r], POSEIDON_AIR_INITIAL_CONSTANTS[2 * r + 1]) @@ -880,22 +845,27 @@ def verify_execution( for table, log_height in zip(TABLES, table_log_heights): assert MIN_LOG_HEIGHT_PER_TABLE <= log_height <= table.max_log_height, f"table {table.name}: invalid height" - log_heights = {t.name: h for t, h in zip(TABLES, table_log_heights)} - n_max = sort_tables_by_height(TABLES, log_heights)[0][1] + table_log_heights = {t.name: h for t, h in zip(TABLES, table_log_heights)} + tables_sorted = sort_tables_by_height(TABLES, table_log_heights) + n_max = tables_sorted[0][1] total_stacked = ( - (2 << log_memory) + (1 << max(bytecode_log_size, n_max)) + sum(t.n_columns << log_heights[t.name] for t in TABLES) + (2 << log_memory) + (1 << max(bytecode_log_size, n_max)) + sum(t.n_columns << table_log_heights[t.name] for t in TABLES) ) # memory + memory_acc + bytecode_acc + biggest_table + second_biggest_table + etc + smallest_table stacked_n_vars = log2_ceil(total_stacked) assert stacked_n_vars <= TWO_ADICITY + WHIR_INITIAL_FOLDING_FACTOR - log_inv_rate, "tacked_n_vars exceeds WHIR domain bound" cfg = WHIR_CONFIGS[(log_inv_rate, stacked_n_vars)] + + # 1] Parse WHIR commitment parsed_commitment = WhirCommitment.parse(fiat_shamir, stacked_n_vars, cfg["commitment_ood_samples"]) logup_gamma = fiat_shamir.sample_ef() # the quotient denominator fiat_shamir.duplex() logup_beta = fiat_shamir.sample_many_ef(log2_ceil(N_INSTRUCTION_COLUMNS + 2)) # the bus-tuple hashing seeds logup_beta_eq = eval_eq(logup_beta) - memory_eval, memory_acc_eval, value_bytecode_acc, precompile_nums, precompile_dens, gkr_point, columns_evals = verify_generic_logup( + + # 2] Verify logup bus interractions + memory_eval, memory_acc_eval, value_bytecode_acc, precompile_nums, precompile_dens, gkr_point, columns_evals = verify_logup( fiat_shamir, logup_gamma, logup_beta, @@ -903,61 +873,70 @@ def verify_execution( log_memory, bytecode_multilinear, TABLES, - log_heights, + table_log_heights, ) - air_alpha = fiat_shamir.sample_ef() - alpha_powers = ef_powers(air_alpha, sum(t.n_constraints for t in TABLES)) + alpha = fiat_shamir.sample_ef() + alpha_powers = ef_powers(alpha, sum(t.n_constraints for t in TABLES)) initial_sum, offset = ZERO, 0 for table in TABLES: initial_sum += alpha_powers[offset] * (precompile_nums[table.name] * table.precompile_bus_interaction_sign) initial_sum += alpha_powers[offset + 1] * (logup_gamma - precompile_dens[table.name]) offset += table.n_constraints + + # 3] verify batched AIR sumcheck sc_point, sc_value = verify_sumcheck(fiat_shamir, initial_sum, n_max, max(t.air_degree + 1 for t in TABLES)) - committed = {t.name: [(gkr_point[-log_heights[t.name] :], columns_evals[t.name], {})] for t in TABLES} - my_air_final, offset = ZERO, 0 + committed_column_evals = {t.name: [(gkr_point[-table_log_heights[t.name] :], columns_evals[t.name], {})] for t in TABLES} + air_final_value, offset = ZERO, 0 for table in TABLES: - log_height = log_heights[table.name] - col_evals = fiat_shamir.next_extension_scalars_vec(table.n_columns + table.n_shift) + log_height = table_log_heights[table.name] + col_evals = fiat_shamir.next_extension_scalars_vec(table.n_shift + table.n_columns) alphas = alpha_powers[offset : offset + table.n_constraints] offset += table.n_constraints constraint_eval = table.eval_air(col_evals, alphas, logup_beta_eq) - - natural_pt = list(reversed(sc_point[-log_height:])) - k_t = math.prod(sc_point[: n_max - log_height]) - my_air_final += k_t * eq_poly(gkr_point[-log_height:], natural_pt) * constraint_eval - + natural_point = list(reversed(sc_point[-log_height:])) + air_final_value += math.prod(sc_point[:-log_height]) * eq_poly(gkr_point[-log_height:], natural_point) * constraint_eval eq_vals = {i: col_evals[i] for i in range(table.n_columns)} next_vals = {j: col_evals[table.n_columns + j] for j in range(table.n_shift)} - committed[table.name].append((natural_pt, eq_vals, next_vals)) - assert my_air_final == sc_value, "AIR sumcheck: claimed value mismatch" + committed_column_evals[table.name].append((natural_point, eq_vals, next_vals)) + assert air_final_value == sc_value, "AIR sumcheck: claimed value mismatch" - pm_point = fiat_shamir.sample_many_ef(log2_strict(PUBLIC_INPUT_SIZE)) - pm_eval = eval_multilinear_by_evals(public_input, pm_point) + public_memory_point = fiat_shamir.sample_many_ef(log2_strict(PUBLIC_INPUT_SIZE)) + public_memory_eval = eval_multilinear_by_evals(public_input, public_memory_point) - bytecode_acc_idx = (2 << log_memory) >> bytecode_log_size - previous_statements = [ + bytecode_acc_offset = (2 << log_memory) >> bytecode_log_size # offset within the stacked polynomial + pcs_statements = [ SparseStatements( stacked_n_vars, gkr_point[-log_memory:], [(0, memory_eval), (1, memory_acc_eval)], ), - SparseStatements(stacked_n_vars, pm_point, [(0, pm_eval)]), - SparseStatements(stacked_n_vars, gkr_point[-bytecode_log_size:], [(bytecode_acc_idx, value_bytecode_acc)]), + SparseStatements(stacked_n_vars, public_memory_point, [(0, public_memory_eval)]), + SparseStatements(stacked_n_vars, gkr_point[-bytecode_log_size:], [(bytecode_acc_offset, value_bytecode_acc)]), ] - global_statements = stacked_pcs_global_statements( - stacked_n_vars, - log_memory, - bytecode_log_size, - previous_statements, - TABLES, - log_heights, - committed, - ending_pc, - ) - verify_whir(fiat_shamir, cfg, parsed_commitment, global_statements) + table_offsets: dict[str, int] = {} + layout_offset = (2 << log_memory) + (1 << max(bytecode_log_size, tables_sorted[0][1])) + for table, log_height in tables_sorted: + table_offsets[table.name] = layout_offset + layout_offset += table.n_columns << log_height + + def values_at(d: dict[int, EF], col_base: int) -> list[tuple[int, EF]]: + return [(col_base + i, v) for i, v in sorted(d.items())] + + for table in TABLES: + log_height = table_log_heights[table.name] + offset = table_offsets[table.name] + col_base = offset >> log_height + pcs_statements.extend(table.boundary_statements(stacked_n_vars, offset, log_height, ending_pc)) + for point, eq_values, next_values in committed_column_evals[table.name]: + if next_values: + pcs_statements.append(SparseStatements(stacked_n_vars, point, values_at(next_values, col_base), True)) + pcs_statements.append(SparseStatements(stacked_n_vars, point, values_at(eq_values, col_base))) + + # 4] Open the PCS + verify_whir(fiat_shamir, cfg, parsed_commitment, pcs_statements) assert fiat_shamir.offset == len(fiat_shamir.transcript), f"transcript not fully consumed ({fiat_shamir.offset}/{len(fiat_shamir.transcript)})" assert not fiat_shamir.openings, f"{len(fiat_shamir.openings)} Merkle openings unused" From e42daef3a5b1805c1d3a9d4ae8fa980cddd7390e Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Thu, 4 Jun 2026 00:55:31 +0400 Subject: [PATCH 31/31] wip --- misc/minimal_zkVM.tex | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/misc/minimal_zkVM.tex b/misc/minimal_zkVM.tex index 33dbd25c..7ffc14a4 100644 --- a/misc/minimal_zkVM.tex +++ b/misc/minimal_zkVM.tex @@ -96,7 +96,7 @@ \subsection{Notations} \begin{itemize} \item $[a, b] = \{a, a + 1, \ldots, b\}$ \item $[a, b) = \{a, a + 1, \ldots, b-1\}$ - \item $\hat{eq}(x, y) = x y + (1 - x)(1 - y)$ + \item $\hat{eq}\big((x_1, \dots, x_n), (y_1, \dots, y_n)\big) = \prod_{i = 1}^{n}{\big(x_i y_i + (1 - x_i)(1 - y_i)\big)}$ \item \label{sec:preliminaries-embed} $\mathsf{embed}(x_1, x_2, x_3, x_4, x_5)$ is the extension field element built from $\extdeg$ base field elements: $x_1 + x_2 X + x_3 X^2 + x_4 X^3 + x_5 X^4$, using $X^5 + X^2 - 1$ as the irreducible polynomial defining the extension (see \Cref{sec:field}). \item Given a number of bits $n > 0$, and an $n$-bit integer $x$, we denote by $[x]_\text{bits}^n$ its big-endian bit-decomposition in $n$ boolean values. For example, $[14]_\text{bits}^4 = (1, 1, 1, 0)$ \item $h_\textsc{memory}$ is the log-size of the (read-only) memory; the memory has $2^{h_\textsc{memory}}$ words (always a power of 2). $h_\textsc{bytecode}$ is the bytecode log-size (the bytecode has $2^{h_\textsc{bytecode}}$ instructions), and $h_\textsc{exec}$, $h_\textsc{poseidon}$, $h_\textsc{extension}$ are the log-heights of the three tables (see \Cref{sec:table-sizes}). @@ -143,15 +143,15 @@ \subsubsection{Shifted polynomial}\label{sec:shifts} \subsubsection{The next multilinear}\label{sec:next-mle} -For $n \geq 0$, identify each $x \in \{0,1\}^n$ with the integer it encodes in big-endian. We define $\nextmle$ as the multilinear polynomial in $2n$ variables whose values on the boolean hypercube are +For $n \geq 0$, we define $\nextmle$ as the multilinear polynomial in $2n$ variables whose values on the boolean hypercube are \[ -\nextmle(x, y) = +\nextmle([x]_\text{bits}^n \,\, || \,\, [y]_\text{bits}^n) = \begin{cases} 1 & \text{if } y = x + 1, \\ 1 & \text{if } x = y = 2^n - 1, \\ 0 & \text{otherwise}, \end{cases} -\qquad x, y \in \{0,1\}^n . +\qquad x, y \in [0, 2^n) \] $\nextmle$ can be efficiently evaluated: @@ -160,7 +160,7 @@ \subsubsection{The next multilinear}\label{sec:next-mle} \] The leading product is the wrap-around case $x = y = 2^n - 1$. The $k$-th term of the sum is the case $y = x + 1$ with the carry stopping at bit $k$: the high bits ($j < k$) are unchanged, bit $k$ flips $0 \to 1$, and the low bits ($j > k$) flip $1 \to 0$. -The polynomial $\nextmle$ enables sumcheck-based evaluation of shifted MLEs (\cref{sec:shifts}): for every $v \in \Fp^{2^n}$, +This enables sumcheck-based evaluation of shifted MLEs (\cref{sec:shifts}): for every $v \in \Fp^{2^n}$, \[ \overset{\scriptscriptstyle\wedge}{\textsf{shift}(v)}(x) \;=\; \sum_{y \in \{0,1\}^n} \nextmle(x, y)\; \hat{v}(y) , \] @@ -170,7 +170,7 @@ \subsection{Lemmas} \subsubsection{Schwartz-Zippel} \begin{lemma}[Schwartz-Zippel]\label{lem:sz} -Let $P \in \Fq[X_1, \ldots, X_n]$ be a non-zero polynomial of total degree $d$. Then +Let $P \in \Fq[X_1, \ldots, X_n]$ be a non-zero polynomial of total degree at most $d$. Then \[ \Pr_{\vec\beta \,\overset{\$}{\leftarrow}\, \Fq^n}\big[P(\vec\beta) = 0\big] \;\leq\; \frac{d}{q}. \]