From 2321d582fae11901c762ff27c1f4b7ec7891631e Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Mon, 1 Jun 2026 22:18:45 +0400 Subject: [PATCH 1/2] zkdsl compiler: remove the possibility for a `range` loop to mutate variables defined in external scope. Use manual buffers instead. --- .../lean_compiler/src/a_simplify_lang/mod.rs | 243 ++---------------- .../tests/test_data/error_100.py | 11 + .../tests/test_data/error_101.py | 13 + .../tests/test_data/error_102.py | 12 + .../lean_compiler/tests/test_data/error_99.py | 12 + crates/lean_compiler/zkDSL.md | 171 ++++++------ .../rec_aggregation/zkdsl_implem/hashing.py | 14 +- crates/rec_aggregation/zkdsl_implem/main.py | 39 ++- .../rec_aggregation/zkdsl_implem/recursion.py | 7 +- crates/rec_aggregation/zkdsl_implem/utils.py | 18 +- crates/rec_aggregation/zkdsl_implem/whir.py | 74 ++++-- 11 files changed, 273 insertions(+), 341 deletions(-) create mode 100644 crates/lean_compiler/tests/test_data/error_100.py create mode 100644 crates/lean_compiler/tests/test_data/error_101.py create mode 100644 crates/lean_compiler/tests/test_data/error_102.py create mode 100644 crates/lean_compiler/tests/test_data/error_99.py diff --git a/crates/lean_compiler/src/a_simplify_lang/mod.rs b/crates/lean_compiler/src/a_simplify_lang/mod.rs index 62f1edb9..cc062866 100644 --- a/crates/lean_compiler/src/a_simplify_lang/mod.rs +++ b/crates/lean_compiler/src/a_simplify_lang/mod.rs @@ -321,8 +321,7 @@ pub fn simplify_program(mut program: Program) -> Result { program.functions.remove(&name); } - let mut mutable_loop_counter = Counter::new(); - transform_mutable_in_loops_in_program(&mut program, &mut mutable_loop_counter)?; + check_no_loop_carried_mutables(&program)?; let mut new_functions = BTreeMap::new(); let mut counters = Counters::default(); @@ -956,28 +955,6 @@ fn substitute_const_vars_in_expr(expr: &mut Expression, const_var_exprs: &BTreeM changed } -// ============================================================================ -// TRANSFORMATION: Mutable variables in non-unrolled loops -// ============================================================================ -// -// This transformation handles mutable variables that are modified inside -// non-unrolled loops by using buffers to store intermediate values. -// -// For a loop like: -// for i in start..end { x += i; } -// -// We transform it to: -// size = end - start; -// x_buff = Array(size + 1); -// x_buff[0] = x; -// for i in start..end { -// buff_idx = i - start; -// mut x_body = x_buff[buff_idx]; -// x_body += i; -// x_buff[buff_idx + 1] = x_body; -// } -// x = x_buff[size]; - /// Finds mutable variables that are: /// 1. Defined OUTSIDE this block (external) /// 2. Re-assigned INSIDE this block @@ -1052,216 +1029,45 @@ fn find_assigned_external_vars_helper( } } -fn transform_mutable_in_loops_in_program(program: &mut Program, counter: &mut Counter) -> Result<(), String> { - for func in program.functions.values_mut() { - transform_mutable_in_loops_in_lines(&mut func.body, &program.const_arrays, counter, &BTreeSet::new())?; +/// Reject any `range` / `parallel_range` loop that reassigns a mutable variable +/// defined in an enclosing scope ("loop-carried mutable"). +fn check_no_loop_carried_mutables(program: &Program) -> Result<(), String> { + for func in program.functions.values() { + check_loop_carried_mutables_in_lines(&func.body, &program.const_arrays, &BTreeSet::new())?; } Ok(()) } -fn transform_mutable_in_loops_in_lines( - lines: &mut Vec, +fn check_loop_carried_mutables_in_lines( + lines: &[Line], const_arrays: &BTreeMap, - counter: &mut Counter, outer_mut_vars: &BTreeSet, ) -> Result<(), String> { let mut local_mut_vars = outer_mut_vars.clone(); - let mut i = 0; - while i < lines.len() { - match &mut lines[i] { - Line::ForLoop { body, loop_kind, .. } if loop_kind.is_unroll() => { - transform_mutable_in_loops_in_lines(body, const_arrays, counter, &local_mut_vars)?; - i += 1; - } + for line in lines { + match line { Line::ForLoop { - iterator, - start, - end, body, - loop_kind: loop_kind @ (LoopKind::Range | LoopKind::ParallelRange), + loop_kind: LoopKind::Range | LoopKind::ParallelRange, location, + .. } => { - let loop_kind = loop_kind.clone(); - transform_mutable_in_loops_in_lines(body, const_arrays, counter, &local_mut_vars)?; + check_loop_carried_mutables_in_lines(body, const_arrays, &local_mut_vars)?; let modified_vars = find_modified_external_vars(body, const_arrays, &local_mut_vars); - - if modified_vars.is_empty() { - // No mutable variables modified, no transformation needed - i += 1; - continue; - } - - if loop_kind.is_parallel() { + if !modified_vars.is_empty() { return Err(format!( - "parallel loop at {location} carries mutable variable(s) {modified_vars:?} across iterations; use a sequential `range` loop" + "loop at {location} reassigns enclosing-scope mutable(s) {modified_vars:?}; \ + loop-carried mutables are unsupported: use an explicit buffer (see zkDSL.md, \"For loops\")" )); } - - let suffix = counter.get_next(); - - // Generate the transformed code - let mut new_lines = Vec::new(); - - let location = *location; - - // Create size variable: @loop_size_{suffix} = end - start - let size_var = format!("@loop_size_{suffix}"); - - new_lines.push(Line::Statement { - targets: vec![AssignmentTarget::Var { - var: size_var.clone(), - is_mutable: false, - }], - value: Expression::MathExpr(MathOperation::Sub, vec![end.clone(), start.clone()]), - location, - }); - - let mut var_to_buff: BTreeMap = BTreeMap::new(); // var -> (buff_name, body_name) - - for var in &modified_vars { - let buff_name = format!("@loop_buff_{var}_{suffix}"); - let body_name = format!("@loop_body_{var}_{suffix}"); - - // buff = Array(size + 1) - new_lines.push(Line::Statement { - targets: vec![AssignmentTarget::Var { - var: buff_name.clone(), - is_mutable: false, - }], - value: Expression::FunctionCall { - function_name: "Array".to_string(), - args: vec![Expression::MathExpr( - // TODO opti in case there is only one mutated var - MathOperation::Add, - vec![Expression::var(size_var.clone()), Expression::one()], - )], - location, - }, - location, - }); - - // buff[0] = var (current value) - new_lines.push(Line::Statement { - targets: vec![AssignmentTarget::ArrayAccess { - array: buff_name.clone().into(), - index: Box::new(Expression::zero()), - }], - value: Expression::var(var.clone()), - location, - }); - - var_to_buff.insert(var.clone(), (buff_name, body_name)); - } - - // Transform the loop body - let iterator = iterator.clone(); - let mut new_body = Vec::new(); - - // buff_idx = i - start (or just i when start is zero) - let buff_idx_var = format!("@loop_buff_idx_{suffix}"); - - new_body.push(Line::Statement { - targets: vec![AssignmentTarget::Var { - var: buff_idx_var.clone(), - is_mutable: false, - }], - value: Expression::MathExpr( - MathOperation::Sub, - vec![Expression::var(iterator.clone()), start.clone()], - ), - location, - }); - - // For each modified variable: mut body_var = buff[buff_idx] - for (var, (buff_name, body_name)) in &var_to_buff { - new_body.push(Line::Statement { - targets: vec![AssignmentTarget::Var { - var: body_name.clone(), - is_mutable: true, - }], - value: Expression::ArrayAccess { - array: buff_name.clone().into(), - index: vec![Expression::Value( - VarOrConstMallocAccess::Var(buff_idx_var.clone()).into(), - )], - }, - location, - }); - - // Replace all references to var with body_name in the original body - transform_vars_in_lines(body, &|v: &Var| { - if v == var { - VarTransform::Rename(body_name.clone()) - } else { - VarTransform::Keep - } - }); - } - - // Add the original body (now modified to use body_vars) - new_body.append(body); - - // next_idx = buff_idx + 1 - let next_idx_var = format!("@loop_next_idx_{suffix}"); - new_body.push(Line::Statement { - targets: vec![AssignmentTarget::Var { - var: next_idx_var.clone(), - is_mutable: false, - }], - value: Expression::MathExpr( - MathOperation::Add, - vec![Expression::var(buff_idx_var.clone()), Expression::one()], - ), - location, - }); - - // For each modified variable: buff[next_idx] = body_var - for (buff_name, body_name) in var_to_buff.values() { - new_body.push(Line::Statement { - targets: vec![AssignmentTarget::ArrayAccess { - array: buff_name.clone().into(), - index: Expression::var(next_idx_var.clone()).into(), - }], - value: Expression::var(body_name.clone()), - location, - }); - } - - // Create the new loop - new_lines.push(Line::ForLoop { - iterator: iterator.clone(), - start: start.clone(), - end: end.clone(), - body: new_body, - loop_kind, - location, - }); - - // After the loop: var = buff[size] - for (var, (buff_name, _body_name)) in &var_to_buff { - new_lines.push(Line::Statement { - targets: vec![AssignmentTarget::Var { - var: var.clone(), - is_mutable: false, - }], - value: Expression::ArrayAccess { - array: buff_name.clone().into(), - index: vec![Expression::var(size_var.clone())], - }, - location, - }); - } - - // Replace the original loop with the new lines - let num_new = new_lines.len(); - lines.splice(i..=i, new_lines); - i += num_new; } - line @ (Line::IfCondition { .. } | Line::Match { .. }) => { - for block in line.nested_blocks_mut() { - transform_mutable_in_loops_in_lines(block, const_arrays, counter, &local_mut_vars)?; + Line::ForLoop { body, .. } => { + check_loop_carried_mutables_in_lines(body, const_arrays, &local_mut_vars)?; + } + Line::IfCondition { .. } | Line::Match { .. } => { + for block in line.nested_blocks() { + check_loop_carried_mutables_in_lines(block, const_arrays, &local_mut_vars)?; } - i += 1; } Line::Statement { targets, .. } => { for target in targets { @@ -1269,11 +1075,8 @@ fn transform_mutable_in_loops_in_lines( local_mut_vars.insert(var.clone()); } } - i += 1; - } - _ => { - i += 1; } + _ => {} } } Ok(()) diff --git a/crates/lean_compiler/tests/test_data/error_100.py b/crates/lean_compiler/tests/test_data/error_100.py new file mode 100644 index 00000000..f636dadb --- /dev/null +++ b/crates/lean_compiler/tests/test_data/error_100.py @@ -0,0 +1,11 @@ +from snark_lib import * + +# Error: a Mut carried across a `parallel_range` loop is rejected, same as `range`. + + +def main(): + acc: Mut = 0 + for i in parallel_range(0, 4): + acc = acc + i + assert acc == 6 + return diff --git a/crates/lean_compiler/tests/test_data/error_101.py b/crates/lean_compiler/tests/test_data/error_101.py new file mode 100644 index 00000000..ed56d757 --- /dev/null +++ b/crates/lean_compiler/tests/test_data/error_101.py @@ -0,0 +1,13 @@ +from snark_lib import * + +# Error: the enclosing Mut `c` is reassigned inside an `if` nested in a `range` +# loop — detection must look inside nested blocks, not just the loop's top level. + + +def main(): + c: Mut = 0 + for i in range(0, 5): + if i == 2: + c = c + 1 + assert c == 1 + return diff --git a/crates/lean_compiler/tests/test_data/error_102.py b/crates/lean_compiler/tests/test_data/error_102.py new file mode 100644 index 00000000..d7febace --- /dev/null +++ b/crates/lean_compiler/tests/test_data/error_102.py @@ -0,0 +1,12 @@ +from snark_lib import * + +# Error: `counter` (enclosing Mut) is reassigned inside a nested `range` loop. + + +def main(): + counter: Mut = 0 + for i in range(0, 3): + for j in range(0, 2): + counter = counter + 1 + assert counter == 6 + return diff --git a/crates/lean_compiler/tests/test_data/error_99.py b/crates/lean_compiler/tests/test_data/error_99.py new file mode 100644 index 00000000..4539647e --- /dev/null +++ b/crates/lean_compiler/tests/test_data/error_99.py @@ -0,0 +1,12 @@ +from snark_lib import * + +# Error: `total` (a Mut from the enclosing scope) is reassigned inside a `range` +# loop. Loop-carried mutables are not supported; use an explicit buffer instead. + + +def main(): + total: Mut = 0 + for i in range(0, 5): + total = total + i + assert total == 10 + return diff --git a/crates/lean_compiler/zkDSL.md b/crates/lean_compiler/zkDSL.md index 30bd4641..e3d1f11b 100644 --- a/crates/lean_compiler/zkDSL.md +++ b/crates/lean_compiler/zkDSL.md @@ -351,19 +351,46 @@ The general-purpose runtime loop. `a` and `b` may be runtime values. The compiler lowers the loop to a recursive function. ```python -sum: Mut = 0 -for i in range(1, 11): - sum += i -assert sum == 55 +for i in range(0, n): + out[i] = f(i) ``` -Mutable variables carried across iterations are supported transparently. +**Mutability:** A `range` range connot mutate variables defined outside its scope. -*Under the hood: the compiler inserts a buffer array, stores the per-iteration value into it, and reads the final value back after the loop.* +```python +total: Mut = 0 +for i in range(0, n): + total = total + a[i] # ERROR: loop-carried mutable +``` + +Mutable variables defined inside the loop are fine: -Restrictions: No `return` inside the body +````python +y = 9 +for i in range(0, n): + x: Mut = f(i) + x += 4 + x *= y + assert x < 455 +```` -*Under the hood: because the loop is lowered to a recursive function.* +*Under the hood: a `range` loop gets transformed into a recursive function, breaking compiler SSA renaming.* + +The solution, to mutate data beyond the loop's scope, is to use (read-only) buffers: + +```python +def sum(arr, n): + total_buf = Array(n + 1) + total_buf[0] = 0 + for i in range(0, n): + total: Mut = total_buf[i] # loop-LOCAL mutable: fine + total += arr[i] + total_buf[i + 1] = total + result = total_buf[n] # final value, after the loop + return result +``` + +Additional restrictions: no `return` inside the body (dos not concern `unroll` loops). #### `unroll(a, b)`: compile-time unrolling @@ -393,8 +420,7 @@ for i in parallel_range(0, n): Because there is no synchronization, the loop body must be iteration-independent: -- No `Mut` variables carried across iterations (each iteration writes only to - its own call frame and to addresses disjoint from every other iteration). +- Each iteration writes only to its own call frame and to addresses disjoint from every other iteration (no concurrent writes). - Identical memory footprint per iteration. - Identical hint consumption per iteration (witness hints, XMSS-specific decomposition hints, Merkle hints, etc.). @@ -780,49 +806,28 @@ def main(): y: Mut = 3 x += y y += x + x_buf = Array(3) # 2 iterations + 1 slot + y_buf = Array(3) + x_buf[0] = x + y_buf[0] = y for i in range(4, 6): - x += i - x += y - y = i - y += x - assert x == 35 - assert y == 40 - return -``` - -Step 1 — the compiler replaces mutable-across-loop variables with index buffers, since memory -is write-once: - -```python -def main(): - x: Mut = 0 - y: Mut = 3 - x += y - y += x - size = 6 - 4 - x_buff = Array(size + 1) - x_buff[0] = x - y_buff = Array(size + 1) - y_buff[0] = y - for i in range(4, 6): - buff_idx = i - 4 - x_body: Mut = x_buff[buff_idx] - y_body: Mut = y_buff[buff_idx] - x_body += i - x_body += y_body - y_body = i - y_body += x_body - next_idx = buff_idx + 1 - x_buff[next_idx] = x_body - y_buff[next_idx] = y_body - x = x_buff[size] - y = y_buff[size] + idx = i - 4 + x_cur: Mut = x_buf[idx] + y_cur: Mut = y_buf[idx] + x_cur += i + x_cur += y_cur + y_cur = i + y_cur += x_cur + x_buf[idx + 1] = x_cur + y_buf[idx + 1] = y_cur + x = x_buf[2] + y = y_buf[2] assert x == 35 assert y == 40 return ``` -Step 2 — SSA-rename all reassignments to fresh names: +Step 1 — SSA-rename all reassignments to fresh names: ```python def main(): @@ -830,30 +835,28 @@ def main(): y = 3 x2 = x + y y2 = y + x2 - size = 6 - 4 - x_buff = Array(size + 1) - x_buff[0] = x2 - y_buff = Array(size + 1) - y_buff[0] = y2 + x_buf = Array(3) + y_buf = Array(3) + x_buf[0] = x2 + y_buf[0] = y2 for i in range(4, 6): - buff_idx = i - 4 - x_body1 = x_buff[buff_idx] - y_body1 = y_buff[buff_idx] - x_body2 = x_body1 + i - x_body3 = x_body2 + y_body1 - y_body2 = i - y_body3 = y_body2 + x_body3 - next_idx = buff_idx + 1 - x_buff[next_idx] = x_body3 - y_buff[next_idx] = y_body3 - x3 = x_buff[size] - y3 = y_buff[size] + idx = i - 4 + x_cur1 = x_buf[idx] + y_cur1 = y_buf[idx] + x_cur2 = x_cur1 + i + x_cur3 = x_cur2 + y_cur1 + y_cur2 = i + y_cur3 = y_cur2 + x_cur3 + x_buf[idx + 1] = x_cur3 + y_buf[idx + 1] = y_cur3 + x3 = x_buf[2] + y3 = y_buf[2] assert x3 == 35 assert y3 == 40 return ``` -Step 3 — lower the runtime loop to a recursive function: +Step 2 — lower the runtime loop to a recursive function: ```python def main(): @@ -861,33 +864,31 @@ def main(): y = 3 x2 = x + y y2 = y + x2 - size = 6 - 4 - x_buff = Array(size + 1) - x_buff[0] = x2 - y_buff = Array(size + 1) - y_buff[0] = y2 - loop_helper(4, x_buff, y_buff) - x3 = x_buff[size] - y3 = y_buff[size] + x_buf = Array(3) + y_buf = Array(3) + x_buf[0] = x2 + y_buf[0] = y2 + loop_helper(4, x_buf, y_buf) + x3 = x_buf[2] + y3 = y_buf[2] assert x3 == 35 assert y3 == 40 return -def loop_helper(i, x_buff, y_buff): +def loop_helper(i, x_buf, y_buf): if i == 6: return else: - buff_idx = i - 4 - x_body1 = x_buff[buff_idx] - y_body1 = y_buff[buff_idx] - x_body2 = x_body1 + i - x_body3 = x_body2 + y_body1 - y_body2 = i - y_body3 = y_body2 + x_body3 - next_idx = buff_idx + 1 - x_buff[next_idx] = x_body3 - y_buff[next_idx] = y_body3 - loop_helper(i + 1, x_buff, y_buff) + idx = i - 4 + x_cur1 = x_buf[idx] + y_cur1 = y_buf[idx] + x_cur2 = x_cur1 + i + x_cur3 = x_cur2 + y_cur1 + y_cur2 = i + y_cur3 = y_cur2 + x_cur3 + x_buf[idx + 1] = x_cur3 + y_buf[idx + 1] = y_cur3 + loop_helper(i + 1, x_buf, y_buf) return ``` diff --git a/crates/rec_aggregation/zkdsl_implem/hashing.py b/crates/rec_aggregation/zkdsl_implem/hashing.py index fb6e7ebf..dffc45fa 100644 --- a/crates/rec_aggregation/zkdsl_implem/hashing.py +++ b/crates/rec_aggregation/zkdsl_implem/hashing.py @@ -165,16 +165,24 @@ def slice_hash_runtime(data, num_chunks): states = Array((num_chunks - 1) * DIGEST_LEN) poseidon16_permute_half(iv, data, states) n_iters = num_chunks - 2 - state_ptr: Mut = states - data_ptr: Mut = data + DIGEST_LEN n_chunks_outer, remainder = euclidian_div_runtime(n_iters, PARTIAL_UNROLL_BATCH) - for _ in range(0, n_chunks_outer): + carry = Array((n_chunks_outer + 1) * 2) + carry[0] = states + carry[1] = data + DIGEST_LEN + for c in range(0, n_chunks_outer): + base = c * 2 + state_ptr: Mut = carry[base] + data_ptr: Mut = carry[base + 1] for _ in unroll(0, PARTIAL_UNROLL_BATCH): new_state = state_ptr + DIGEST_LEN poseidon16_permute_half(state_ptr, data_ptr, new_state) state_ptr = new_state data_ptr += DIGEST_LEN + carry[base + 2] = state_ptr + carry[base + 3] = data_ptr + state_ptr = carry[n_chunks_outer * 2] + data_ptr = carry[n_chunks_outer * 2 + 1] final_state_ptr = match_range( remainder, diff --git a/crates/rec_aggregation/zkdsl_implem/main.py b/crates/rec_aggregation/zkdsl_implem/main.py index 2fa65d17..f4b0f72b 100644 --- a/crates/rec_aggregation/zkdsl_implem/main.py +++ b/crates/rec_aggregation/zkdsl_implem/main.py @@ -175,12 +175,13 @@ def main(): pk = all_pubkeys + idx * PUB_KEY_SIZE xmss_verify(pk, message, merkle_chunks_for_slot) - counter: Mut = n_raw_xmss - n_bytecode_claims = n_recursions * 2 bytecode_claims = Array(n_bytecode_claims) + counter_outer_buf = Array(n_recursions + 1) + counter_outer_buf[0] = n_raw_xmss for rec_idx in range(0, n_recursions): + counter: Mut = counter_outer_buf[rec_idx] n_sub = aggregate_sizes[rec_idx] assert n_sub != 0 assert n_sub <= MAX_N_SIGS @@ -190,19 +191,33 @@ def main(): running_hash: Mut = build_iv(n_sub * PUB_KEY_SIZE) n_first = n_sub - 1 n_chunks, remainder = euclidian_div_runtime(n_first, PARTIAL_UNROLL_BATCH) - j: Mut = 0 - for _ in range(0, n_chunks): + pubkey_idx: Mut = 0 + inner_carry = Array((n_chunks + 1) * 3) + inner_carry[0] = counter + inner_carry[1] = running_hash + inner_carry[2] = pubkey_idx + for c in range(0, n_chunks): + base = c * 3 + cur_counter: Mut = inner_carry[base] + cur_running_hash: Mut = inner_carry[base + 1] + cur_pubkey_idx: Mut = inner_carry[base + 2] for u in unroll(0, PARTIAL_UNROLL_BATCH): - counter, running_hash = absorb_recursive_pubkey( - j + u, sub_indices_arr, n_total, all_pubkeys, buffer, counter, running_hash + cur_counter, cur_running_hash = absorb_recursive_pubkey( + cur_pubkey_idx + u, sub_indices_arr, n_total, all_pubkeys, buffer, cur_counter, cur_running_hash ) - j += PARTIAL_UNROLL_BATCH + cur_pubkey_idx += PARTIAL_UNROLL_BATCH + inner_carry[base + 3] = cur_counter + inner_carry[base + 4] = cur_running_hash + inner_carry[base + 5] = cur_pubkey_idx + counter = inner_carry[n_chunks * 3] + running_hash = inner_carry[n_chunks * 3 + 1] + pubkey_idx = inner_carry[n_chunks * 3 + 2] # Tail iterations tail_counter, tail_running_hash = match_range( remainder, range(0, PARTIAL_UNROLL_BATCH), lambda r: absorb_n_pubkeys_const( - r, j, sub_indices_arr, n_total, all_pubkeys, buffer, counter, running_hash + r, pubkey_idx, sub_indices_arr, n_total, all_pubkeys, buffer, counter, running_hash ), ) counter = tail_counter @@ -229,7 +244,9 @@ def main(): bytecode_claims[2 * rec_idx] = single_message_data_buf + BYTECODE_CLAIM_OFFSET bytecode_claims[2 * rec_idx + 1] = recursion(inner_pub_mem, initial_fiat_shamir_cap) + counter_outer_buf[rec_idx + 1] = counter + counter = counter_outer_buf[n_recursions] assert counter == n_total if n_recursions == 0: @@ -258,14 +275,18 @@ def reduce_bytecode_claims(bytecode_claims, n_bytecode_claims, bytecode_claim_ou count_block[0] = n_bytecode_claims for k in unroll(1, DIGEST_LEN): count_block[k] = 0 - running_capacity: Mut = slice_hash_continue(reduction_capacity, count_block, 1) + rc_buf = Array(n_bytecode_claims) + rc_buf[0] = slice_hash_continue(reduction_capacity, count_block, 1) for i in range(0, n_bytecode_claims - 1): + running_capacity: Mut = rc_buf[i] claim_ptr = bytecode_claims[i] for k in unroll(BYTECODE_CLAIM_SIZE, BYTECODE_CLAIM_SIZE_PADDED): assert claim_ptr[k] == 0 running_capacity = slice_hash_continue(running_capacity, claim_ptr, BYTECODE_CLAIM_NUM_CHUNKS) + rc_buf[i + 1] = running_capacity + running_capacity: Mut = rc_buf[n_bytecode_claims - 1] last_claim = bytecode_claims[n_bytecode_claims - 1] for k in unroll(BYTECODE_CLAIM_SIZE, BYTECODE_CLAIM_SIZE_PADDED): assert last_claim[k] == 0 diff --git a/crates/rec_aggregation/zkdsl_implem/recursion.py b/crates/rec_aggregation/zkdsl_implem/recursion.py index 2578f8a8..adec39c2 100644 --- a/crates/rec_aggregation/zkdsl_implem/recursion.py +++ b/crates/rec_aggregation/zkdsl_implem/recursion.py @@ -630,10 +630,13 @@ def verify_gkr_quotient(prev_fs, n_vars): claims_num[LOGUP_GKR_N_VARS_TO_SEND_COEFFS - 1] = first_claim_num claims_den[LOGUP_GKR_N_VARS_TO_SEND_COEFFS - 1] = first_claim_den + fs_buf = Array(n_vars - LOGUP_GKR_N_VARS_TO_SEND_COEFFS + 1) + fs_buf[0] = fs for i in range(LOGUP_GKR_N_VARS_TO_SEND_COEFFS, n_vars): - fs, points[i], claims_num[i], claims_den[i] = verify_gkr_quotient_step( - fs, i, points[i - 1], claims_num[i - 1], claims_den[i - 1] + fs_buf[i - LOGUP_GKR_N_VARS_TO_SEND_COEFFS + 1], points[i], claims_num[i], claims_den[i] = verify_gkr_quotient_step( + fs_buf[i - LOGUP_GKR_N_VARS_TO_SEND_COEFFS], i, points[i - 1], claims_num[i - 1], claims_den[i - 1] ) + fs = fs_buf[n_vars - LOGUP_GKR_N_VARS_TO_SEND_COEFFS] return ( fs, diff --git a/crates/rec_aggregation/zkdsl_implem/utils.py b/crates/rec_aggregation/zkdsl_implem/utils.py index 5a2db8e8..c3d9d46f 100644 --- a/crates/rec_aggregation/zkdsl_implem/utils.py +++ b/crates/rec_aggregation/zkdsl_implem/utils.py @@ -708,10 +708,13 @@ def mle_of_zeros_then_ones(point, n_zeros, n_vars): bits, _ = checked_decompose_bits(n_zeros) - res: Mut = Array(DIM) - set_to_one(res) + res_0 = Array(DIM) + set_to_one(res_0) + res_buf = Array(n_vars + 1) + res_buf[0] = res_0 for i in range(0, n_vars): + res: Mut = res_buf[i] p = point + (n_vars - 1 - i) * DIM if bits[F_BITS - 1 - i] == 0: one_minus_p = one_minus_self_extension_ret(p) @@ -719,7 +722,8 @@ def mle_of_zeros_then_ones(point, n_zeros, n_vars): res = add_extension_ret(tmp, p) else: res = mul_extension_ret(p, res) - return res + res_buf[i + 1] = res + return res_buf[n_vars] def mle_of_zeros_then_ones_pow2(point, log_n_zeros: Const, n_vars): @@ -727,11 +731,11 @@ def mle_of_zeros_then_ones_pow2(point, log_n_zeros: Const, n_vars): if log_n_zeros == n_vars: return ZERO_VEC_PTR n_factors = n_vars - log_n_zeros - prod: Mut = one_minus_self_extension_ret(point) + prod_buf = Array(n_factors) + prod_buf[0] = one_minus_self_extension_ret(point) for i in range(1, n_factors): - new_prod = mul_extension_ret(prod, one_minus_self_extension_ret(point + i * DIM)) - prod = new_prod - return sub_base_extension_ret(1, prod) + prod_buf[i] = mul_extension_ret(prod_buf[i - 1], one_minus_self_extension_ret(point + i * DIM)) + return sub_base_extension_ret(1, prod_buf[n_factors - 1]) @inline diff --git a/crates/rec_aggregation/zkdsl_implem/whir.py b/crates/rec_aggregation/zkdsl_implem/whir.py index d14a10ef..253b95a6 100644 --- a/crates/rec_aggregation/zkdsl_implem/whir.py +++ b/crates/rec_aggregation/zkdsl_implem/whir.py @@ -24,9 +24,6 @@ def whir_open( combination_randomness_powers_0, prev_claimed_sum, ): - fs: Mut = prev_fs - root: Mut = prev_root - claimed_sum: Mut = prev_claimed_sum n_rounds, n_final_vars, num_queries, num_oods, query_grinding_bits, folding_grinding = get_whir_params( n_vars, initial_log_inv_rate ) @@ -40,8 +37,17 @@ def whir_open( all_circle_values = Array(n_rounds + 1) all_combination_randomness_powers = Array(n_rounds) - domain_sz: Mut = n_vars + initial_log_inv_rate + carry = Array((n_rounds + 1) * 4) + carry[0] = prev_fs + carry[1] = prev_root + carry[2] = prev_claimed_sum + carry[3] = n_vars + initial_log_inv_rate for r in range(0, n_rounds): + base = r * 4 + fs: Mut = carry[base] + root: Mut = carry[base + 1] + claimed_sum: Mut = carry[base + 2] + domain_sz: Mut = carry[base + 3] is_first_round: Imm if r == 0: is_first_round = 1 @@ -72,6 +78,14 @@ def whir_open( domain_sz -= WHIR_FIRST_RS_REDUCTION_FACTOR else: domain_sz -= 1 + carry[base + 4] = fs + carry[base + 5] = root + carry[base + 6] = claimed_sum + carry[base + 7] = domain_sz + fs: Mut = carry[n_rounds * 4] + root = carry[n_rounds * 4 + 1] + claimed_sum: Mut = carry[n_rounds * 4 + 2] + domain_sz = carry[n_rounds * 4 + 3] fs, all_folding_randomness[n_rounds], claimed_sum = sumcheck_verify_with_grinding( fs, WHIR_SUBSEQUENT_FOLDING_FACTOR, claimed_sum, 2, folding_grinding[n_rounds] @@ -112,11 +126,15 @@ def whir_open( folding_randomness_global = Array(n_vars * DIM) - start: Mut = folding_randomness_global + start_buf = Array(n_rounds + 2) + start_buf[0] = folding_randomness_global for i in range(0, n_rounds + 1): + start: Mut = start_buf[i] for j in range(0, folding_factors[i]): copy_5(all_folding_randomness[i] + j * DIM, start + j * DIM) start += folding_factors[i] * DIM + start_buf[i + 1] = start + start = start_buf[n_rounds + 1] for j in range(0, n_final_vars): copy_5(all_folding_randomness[n_rounds + 1] + j * DIM, start + j * DIM) @@ -126,17 +144,23 @@ def whir_open( poly_eq_extension_dynamic_to( expanded_from_univariate, folding_randomness_global, all_ood_recovered_evals + i * DIM, n_vars ) - s: Mut = Array(DIM) + s_init = Array(DIM) dot_product_ee_dynamic( all_ood_recovered_evals, combination_randomness_powers_0, - s, + s_init, num_oods[0], ) - n_vars_remaining: Mut = n_vars - my_folding_randomness: Mut = folding_randomness_global + eval_carry = Array((n_rounds + 1) * 3) + eval_carry[0] = n_vars + eval_carry[1] = folding_randomness_global + eval_carry[2] = s_init for i in range(0, n_rounds): + base = i * 3 + n_vars_remaining: Mut = eval_carry[base] + my_folding_randomness: Mut = eval_carry[base + 1] + s: Mut = eval_carry[base + 2] n_vars_remaining -= folding_factors[i] my_ood_recovered_evals = Array(num_oods[i + 1] * DIM) combination_randomness_powers = all_combination_randomness_powers[i] @@ -168,6 +192,10 @@ def whir_open( ) s = add_extension_ret(s, s7) s = add_extension_ret(summed_ood, s) + eval_carry[base + 3] = n_vars_remaining + eval_carry[base + 4] = my_folding_randomness + eval_carry[base + 5] = s + s = eval_carry[n_rounds * 3 + 2] final_value = match_range( n_final_vars, range(MAX_NUM_VARIABLES_TO_SEND_COEFFS - WHIR_SUBSEQUENT_FOLDING_FACTOR, MAX_NUM_VARIABLES_TO_SEND_COEFFS + 1), @@ -185,16 +213,24 @@ def sumcheck_verify(fs, n_steps, claimed_sum, degree: Const): def sumcheck_verify_helper(prev_fs, n_steps, prev_claimed_sum, degree: Const, challenges): - fs: Mut = prev_fs - claimed_sum: Mut = prev_claimed_sum + carry = Array((n_steps + 1) * 2) + carry[0] = prev_fs + carry[1] = prev_claimed_sum for sc_round in range(0, n_steps): + base = sc_round * 2 + fs: Mut = carry[base] + claimed_sum: Mut = carry[base + 1] fs, poly = fs_receive_ef_inlined(fs, degree + 1) polynomial_sum_at_0_and_1(poly, degree, claimed_sum) fs, rand = fs_sample_ef(fs) claimed_sum = univariate_polynomial_eval(poly, rand, degree) copy_5(rand, challenges + sc_round * DIM) + carry[base + 2] = fs + carry[base + 3] = claimed_sum - return fs, claimed_sum + final_fs = carry[n_steps * 2] + final_claimed_sum = carry[n_steps * 2 + 1] + return final_fs, final_claimed_sum def sumcheck_verify_reversed(fs, n_steps, claimed_sum, degree: Const): @@ -227,18 +263,26 @@ def sumcheck_verify_reversed_helper_const(prev_fs, n_steps: Const, prev_claimed_ def sumcheck_verify_with_grinding(prev_fs, n_steps, prev_claimed_sum, degree: Const, folding_grinding_bits): - fs: Mut = prev_fs - claimed_sum: Mut = prev_claimed_sum challenges = Array(n_steps * DIM) + carry = Array((n_steps + 1) * 2) + carry[0] = prev_fs + carry[1] = prev_claimed_sum for sc_round in range(0, n_steps): + base = sc_round * 2 + fs: Mut = carry[base] + claimed_sum: Mut = carry[base + 1] fs, poly = fs_receive_ef_inlined(fs, degree + 1) polynomial_sum_at_0_and_1(poly, degree, claimed_sum) fs = fs_grinding(fs, folding_grinding_bits) fs, rand = fs_sample_ef(fs) claimed_sum = univariate_polynomial_eval(poly, rand, degree) copy_5(rand, challenges + sc_round * DIM) + carry[base + 2] = fs + carry[base + 3] = claimed_sum - return fs, challenges, claimed_sum + final_fs = carry[n_steps * 2] + final_claimed_sum = carry[n_steps * 2 + 1] + return final_fs, challenges, final_claimed_sum @inline From 0934727b38aeea94995f7b62dff3faccf2c04f18 Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Mon, 1 Jun 2026 23:58:05 +0400 Subject: [PATCH 2/2] make tests pass --- .../tests/test_data/program_130.py | 20 -- .../tests/test_data/program_131.py | 11 - .../tests/test_data/program_132.py | 11 - .../tests/test_data/program_133.py | 16 - .../tests/test_data/program_134.py | 23 -- .../tests/test_data/program_135.py | 22 -- .../tests/test_data/program_136.py | 20 -- .../tests/test_data/program_137.py | 22 -- .../tests/test_data/program_138.py | 17 - .../tests/test_data/program_139.py | 25 -- .../tests/test_data/program_140.py | 13 - .../tests/test_data/program_141.py | 27 -- .../tests/test_data/program_142.py | 296 ------------------ .../tests/test_data/program_143.py | 8 +- .../tests/test_data/program_167.py | 6 +- .../tests/test_data/program_171.py | 197 +++++++----- .../tests/test_data/program_183.py | 10 +- 17 files changed, 131 insertions(+), 613 deletions(-) delete mode 100644 crates/lean_compiler/tests/test_data/program_130.py delete mode 100644 crates/lean_compiler/tests/test_data/program_131.py delete mode 100644 crates/lean_compiler/tests/test_data/program_132.py delete mode 100644 crates/lean_compiler/tests/test_data/program_133.py delete mode 100644 crates/lean_compiler/tests/test_data/program_134.py delete mode 100644 crates/lean_compiler/tests/test_data/program_135.py delete mode 100644 crates/lean_compiler/tests/test_data/program_136.py delete mode 100644 crates/lean_compiler/tests/test_data/program_137.py delete mode 100644 crates/lean_compiler/tests/test_data/program_138.py delete mode 100644 crates/lean_compiler/tests/test_data/program_139.py delete mode 100644 crates/lean_compiler/tests/test_data/program_140.py delete mode 100644 crates/lean_compiler/tests/test_data/program_141.py delete mode 100644 crates/lean_compiler/tests/test_data/program_142.py diff --git a/crates/lean_compiler/tests/test_data/program_130.py b/crates/lean_compiler/tests/test_data/program_130.py deleted file mode 100644 index 5749d8f1..00000000 --- a/crates/lean_compiler/tests/test_data/program_130.py +++ /dev/null @@ -1,20 +0,0 @@ -from snark_lib import * -# Test: Mutable variables in non-unrolled loops -# This tests the automatic buffer transformation for mutable variables - - -def main(): - x: Mut = 0 - y: Mut = 3 - x += y - y += x - assert x == 3 - assert y == 6 - for i in range(4, 6): - x += i - x += y - y = i - y += x - assert x == 35 - assert y == 40 - return diff --git a/crates/lean_compiler/tests/test_data/program_131.py b/crates/lean_compiler/tests/test_data/program_131.py deleted file mode 100644 index c36f8fce..00000000 --- a/crates/lean_compiler/tests/test_data/program_131.py +++ /dev/null @@ -1,11 +0,0 @@ -from snark_lib import * -# Test: Simple mutable variable in non-unrolled loop -# Sum of 1 to 10 - - -def main(): - s: Mut = 0 - for i in range(1, 11): - s += i - assert s == 55 - return diff --git a/crates/lean_compiler/tests/test_data/program_132.py b/crates/lean_compiler/tests/test_data/program_132.py deleted file mode 100644 index 611ca57d..00000000 --- a/crates/lean_compiler/tests/test_data/program_132.py +++ /dev/null @@ -1,11 +0,0 @@ -from snark_lib import * -# Test: Mutable variables with different operations in non-unrolled loop - - -def main(): - product: Mut = 1 - for i in range(1, 6): - product *= i - # 1 * 2 * 3 * 4 * 5 = 120 - assert product == 120 - return diff --git a/crates/lean_compiler/tests/test_data/program_133.py b/crates/lean_compiler/tests/test_data/program_133.py deleted file mode 100644 index e25c7b34..00000000 --- a/crates/lean_compiler/tests/test_data/program_133.py +++ /dev/null @@ -1,16 +0,0 @@ -from snark_lib import * -# Test: Nested non-unrolled loops with mutable variables -# Computes sum of i*j for i in 0..3, j in 0..4 - - -def main(): - total: Mut = 0 - for i in range(0, 3): - for j in range(0, 4): - total += i * j - # i=0: 0*0 + 0*1 + 0*2 + 0*3 = 0 - # i=1: 1*0 + 1*1 + 1*2 + 1*3 = 6 - # i=2: 2*0 + 2*1 + 2*2 + 2*3 = 12 - # total = 0 + 6 + 12 = 18 - assert total == 18 - return diff --git a/crates/lean_compiler/tests/test_data/program_134.py b/crates/lean_compiler/tests/test_data/program_134.py deleted file mode 100644 index 1d2a4363..00000000 --- a/crates/lean_compiler/tests/test_data/program_134.py +++ /dev/null @@ -1,23 +0,0 @@ -from snark_lib import * -# Test: Conditionals inside non-unrolled loop with mutable variables -# Tests if/else branches that modify mutable variables differently - - -def main(): - a: Mut = 0 - b: Mut = 100 - for i in range(0, 5): - if i == 2: - a += 10 - b -= 50 - else: - a += 1 - b -= 1 - # i=0: a=1, b=99 - # i=1: a=2, b=98 - # i=2: a=12, b=48 - # i=3: a=13, b=47 - # i=4: a=14, b=46 - assert a == 14 - assert b == 46 - return diff --git a/crates/lean_compiler/tests/test_data/program_135.py b/crates/lean_compiler/tests/test_data/program_135.py deleted file mode 100644 index 1952905a..00000000 --- a/crates/lean_compiler/tests/test_data/program_135.py +++ /dev/null @@ -1,22 +0,0 @@ -from snark_lib import * -# Test: Match statement inside non-unrolled loop with mutable variables - - -def main(): - score: Mut = 0 - for i in range(0, 4): - match i: - case 0: - score += 100 - case 1: - score += 50 - case 2: - score += 25 - case 3: - score += 10 - # i=0: score=100 - # i=1: score=150 - # i=2: score=175 - # i=3: score=185 - assert score == 185 - return diff --git a/crates/lean_compiler/tests/test_data/program_136.py b/crates/lean_compiler/tests/test_data/program_136.py deleted file mode 100644 index 07a3b515..00000000 --- a/crates/lean_compiler/tests/test_data/program_136.py +++ /dev/null @@ -1,20 +0,0 @@ -from snark_lib import * -# Test: Complex nested loops with multiple mutable variables -# Outer loop updates one set of vars, inner loop updates another, -# and they interact with each other - - -def main(): - outer_sum: Mut = 0 - inner_count: Mut = 0 - for i in range(1, 4): - outer_sum += i * 10 - for j in range(0, i): - inner_count += 1 - outer_sum += j - # i=1: outer_sum=10, inner: j=0: inner_count=1, outer_sum=10 - # i=2: outer_sum=30, inner: j=0: inner_count=2, outer_sum=30; j=1: inner_count=3, outer_sum=31 - # i=3: outer_sum=61, inner: j=0: inner_count=4, outer_sum=61; j=1: inner_count=5, outer_sum=62; j=2: inner_count=6, outer_sum=64 - assert outer_sum == 64 - assert inner_count == 6 - return diff --git a/crates/lean_compiler/tests/test_data/program_137.py b/crates/lean_compiler/tests/test_data/program_137.py deleted file mode 100644 index 4cda126c..00000000 --- a/crates/lean_compiler/tests/test_data/program_137.py +++ /dev/null @@ -1,22 +0,0 @@ -from snark_lib import * -# Test: Deeply nested conditionals inside non-unrolled loop - - -def main(): - result: Mut = 0 - for i in range(0, 6): - if i == 0: - result += 1 - elif i == 1: - result += 2 - elif i == 2: - result += 4 - elif i == 3: - result += 8 - elif i == 4: - result += 16 - else: - result += 32 - # Powers of 2: 1 + 2 + 4 + 8 + 16 + 32 = 63 - assert result == 63 - return diff --git a/crates/lean_compiler/tests/test_data/program_138.py b/crates/lean_compiler/tests/test_data/program_138.py deleted file mode 100644 index f2977c49..00000000 --- a/crates/lean_compiler/tests/test_data/program_138.py +++ /dev/null @@ -1,17 +0,0 @@ -from snark_lib import * -# Test: Mix of unrolled outer loop and non-unrolled inner loop with mutable vars - - -def main(): - total: Mut = 0 - for i in unroll(0, 3): - # Inner loop is non-unrolled, uses mutable variable - inner_sum: Mut = 0 - for j in range(0, 4): - inner_sum += j + i - total += inner_sum - # i=0: inner_sum = 0+1+2+3 = 6, total = 6 - # i=1: inner_sum = 1+2+3+4 = 10, total = 16 - # i=2: inner_sum = 2+3+4+5 = 14, total = 30 - assert total == 30 - return diff --git a/crates/lean_compiler/tests/test_data/program_139.py b/crates/lean_compiler/tests/test_data/program_139.py deleted file mode 100644 index c49f1373..00000000 --- a/crates/lean_compiler/tests/test_data/program_139.py +++ /dev/null @@ -1,25 +0,0 @@ -from snark_lib import * -# Test: Mutable variable with array operations inside non-unrolled loop - - -def main(): - arr = Array(5) - arr[0] = 10 - arr[1] = 20 - arr[2] = 30 - arr[3] = 40 - arr[4] = 50 - - sum: Mut = 0 - prev: Mut = 0 - for i in range(0, 5): - val = arr[i] - sum += val - # Track running difference - diff = val - prev - prev = val - # sum = 10 + 20 + 30 + 40 + 50 = 150 - # prev = 50 (last value) - assert sum == 150 - assert prev == 50 - return diff --git a/crates/lean_compiler/tests/test_data/program_140.py b/crates/lean_compiler/tests/test_data/program_140.py deleted file mode 100644 index ca8c7129..00000000 --- a/crates/lean_compiler/tests/test_data/program_140.py +++ /dev/null @@ -1,13 +0,0 @@ -from snark_lib import * -# Test: Three levels of nested non-unrolled loops with mutable variable - - -def main(): - count: Mut = 0 - for i in range(0, 2): - for j in range(0, 3): - for k in range(0, 4): - count += 1 - # Total iterations: 2 * 3 * 4 = 24 - assert count == 24 - return diff --git a/crates/lean_compiler/tests/test_data/program_141.py b/crates/lean_compiler/tests/test_data/program_141.py deleted file mode 100644 index 1ef3b0a2..00000000 --- a/crates/lean_compiler/tests/test_data/program_141.py +++ /dev/null @@ -1,27 +0,0 @@ -from snark_lib import * -# Test: Match with conditions inside non-unrolled loop - - -def main(): - a: Mut = 0 - b: Mut = 0 - for i in range(0, 3): - match i: - case 0: - a += 1 - if a == 1: - b += 10 - case 1: - a += 2 - if a == 3: - b += 20 - case 2: - a += 4 - if a == 7: - b += 40 - # i=0: a=1, b=10 - # i=1: a=3, b=30 - # i=2: a=7, b=70 - assert a == 7 - assert b == 70 - return diff --git a/crates/lean_compiler/tests/test_data/program_142.py b/crates/lean_compiler/tests/test_data/program_142.py deleted file mode 100644 index 2bc3e0f2..00000000 --- a/crates/lean_compiler/tests/test_data/program_142.py +++ /dev/null @@ -1,296 +0,0 @@ -from snark_lib import * -# Comprehensive stress test for mutable variables in non-unrolled loops -# Tests: nested loops, conditionals, match, multiple mutable vars, edge cases - - -def main(): - # ========================================================================= - # TEST 1: Triple nested loops with multiple interacting mutable variables - # ========================================================================= - a: Mut = 0 - b: Mut = 1 - c: Mut = 100 - for i in range(0, 3): - for j in range(0, 4): - for k in range(0, 2): - a += 1 - b += a - c -= 1 - # a = 3*4*2 = 24 increments = 24 - # b = 1 + 1 + 3 + 6 + 10 + 15 + 21 + 28 + 36 + 45 + 55 + 66 + 78 + 91 + 105 + 120 + 136 + 153 + 171 + 190 + 210 + 231 + 253 + 276 + 300 = 301 - # c = 100 - 24 = 76 - assert a == 24 - assert b == 301 - assert c == 76 - - # ========================================================================= - # TEST 2: Mutable variable modified differently in if/else branches - # ========================================================================= - x: Mut = 0 - y: Mut = 0 - for i in range(0, 8): - if i == 0: - x += 100 - y += 1 - elif i == 1: - x += 50 - y += 2 - elif i == 2: - x += 25 - y += 4 - elif i == 3: - x -= 10 - y += 8 - else: - x += i - y *= 2 - # i=0: x=100, y=1 - # i=1: x=150, y=3 - # i=2: x=175, y=7 - # i=3: x=165, y=15 - # i=4: x=169, y=30 - # i=5: x=174, y=60 - # i=6: x=180, y=120 - # i=7: x=187, y=240 - assert x == 187 - assert y == 240 - - # ========================================================================= - # TEST 3: Match statements with mutable variables in nested loop - # ========================================================================= - score: Mut = 0 - multiplier: Mut = 1 - for round in range(0, 3): - for action in range(0, 4): - match action: - case 0: - score += 10 * multiplier - case 1: - score += 5 * multiplier - multiplier += 1 - case 2: - score -= 2 * multiplier - case 3: - multiplier *= 2 - score += multiplier - # Round 0: action 0: score=10, mult=1 - # action 1: score=15, mult=2 - # action 2: score=11, mult=2 - # action 3: mult=4, score=15 - # Round 1: action 0: score=55, mult=4 - # action 1: score=75, mult=5 - # action 2: score=65, mult=5 - # action 3: mult=10, score=75 - # Round 2: action 0: score=175, mult=10 - # action 1: score=225, mult=11 - # action 2: score=203, mult=11 - # action 3: mult=22, score=225 - assert score == 225 - assert multiplier == 22 - - # ========================================================================= - # TEST 4: Loop with non-zero start index - # ========================================================================= - sum_from_5: Mut = 0 - for i in range(5, 10): - sum_from_5 += i - # 5 + 6 + 7 + 8 + 9 = 35 - assert sum_from_5 == 35 - - # ========================================================================= - # TEST 5: Single iteration loop (edge case) - # ========================================================================= - single: Mut = 42 - for i in range(7, 8): - single += i - assert single == 49 - - # ========================================================================= - # TEST 6: Mutable variable reassigned multiple times per iteration - # ========================================================================= - multi: Mut = 0 - for i in range(1, 5): - multi += i - multi *= 2 - multi -= 1 - multi += i - # i=1: multi = 0+1=1, *2=2, -1=1, +1=2 - # i=2: multi = 2+2=4, *2=8, -1=7, +2=9 - # i=3: multi = 9+3=12, *2=24, -1=23, +3=26 - # i=4: multi = 26+4=30, *2=60, -1=59, +4=63 - assert multi == 63 - - # ========================================================================= - # TEST 7: Mutable variables with array operations - # ========================================================================= - arr = Array(6) - arr[0] = 1 - arr[1] = 2 - arr[2] = 4 - arr[3] = 8 - arr[4] = 16 - arr[5] = 32 - - arr_sum: Mut = 0 - arr_prod: Mut = 1 - last_val: Mut = 0 - for idx in range(0, 6): - val = arr[idx] - arr_sum += val - arr_prod *= val + 1 - last_val = val - # sum = 1+2+4+8+16+32 = 63 - # prod = 2*3*5*9*17*33 = 151470 - # last_val = 32 - assert arr_sum == 63 - assert arr_prod == 151470 - assert last_val == 32 - - # ========================================================================= - # TEST 8: Nested conditionals inside nested loops - # ========================================================================= - complex: Mut = 0 - for i in range(0, 3): - for j in range(0, 3): - if i == j: - if i == 0: - complex += 100 - elif i == 1: - complex += 200 - else: - complex += 300 - else: - if i == 0: - complex += 1 - else: - complex += 2 - # i=0,j=0: i==j, i==0: +100 -> 100 - # i=0,j=1: i!=j, i==0: +1 -> 101 - # i=0,j=2: i!=j, i==0: +1 -> 102 - # i=1,j=0: i!=j, i!=0: +2 -> 104 - # i=1,j=1: i==j, i==1: +200 -> 304 - # i=1,j=2: i!=j, i!=0: +2 -> 306 - # i=2,j=0: i!=j, i!=0: +2 -> 308 - # i=2,j=1: i!=j, i!=0: +2 -> 310 - # i=2,j=2: i==j, i==2: +300 -> 610 - assert complex == 610 - - # ========================================================================= - # TEST 9: Function calls with mutable variables - # ========================================================================= - func_result: Mut = 0 - for i in range(1, 6): - increment = compute_increment(i) - func_result += increment - # compute_increment(1) = 1 - # compute_increment(2) = 4 - # compute_increment(3) = 9 - # compute_increment(4) = 16 - # compute_increment(5) = 25 - # sum = 1 + 4 + 9 + 16 + 25 = 55 - assert func_result == 55 - - # ========================================================================= - # TEST 10: Outer mutable modified by inner loop result - # ========================================================================= - outer_acc: Mut = 0 - for i in range(1, 4): - inner_acc: Mut = 0 - for j in range(0, i): - inner_acc += j + 1 - outer_acc += inner_acc * i - # i=1: inner_acc = 1, outer_acc = 1*1 = 1 - # i=2: inner_acc = 1+2 = 3, outer_acc = 1 + 3*2 = 7 - # i=3: inner_acc = 1+2+3 = 6, outer_acc = 7 + 6*3 = 25 - assert outer_acc == 25 - - # ========================================================================= - # TEST 11: Large number of iterations - # ========================================================================= - big_sum: Mut = 0 - for i in range(0, 100): - big_sum += 1 - assert big_sum == 100 - - # ========================================================================= - # TEST 12: Mutable with division and subtraction - # ========================================================================= - countdown: Mut = 1000 - steps: Mut = 0 - for i in range(1, 11): - countdown -= i * 10 - steps += 1 - # countdown = 1000 - 10 - 20 - 30 - 40 - 50 - 60 - 70 - 80 - 90 - 100 - # = 1000 - 550 = 450 - assert countdown == 450 - assert steps == 10 - - # ========================================================================= - # TEST 13: Mix of unrolled inner and non-unrolled outer - # ========================================================================= - mixed: Mut = 0 - for i in range(0, 4): - for j in unroll(0, 3): - mixed += i * 3 + j - # i=0: 0+1+2 = 3 - # i=1: 3+4+5 = 12 - # i=2: 6+7+8 = 21 - # i=3: 9+10+11 = 30 - # total = 3+12+21+30 = 66 - assert mixed == 66 - - # ========================================================================= - # TEST 14: Multiple mutable variables, some modified some not per iteration - # ========================================================================= - always: Mut = 0 - sometimes: Mut = 100 - rarely: Mut = 1000 - for i in range(0, 10): - always += 1 - if i == 3: - sometimes += 50 - if i == 7: - sometimes -= 25 - rarely += 500 - if i == 9: - rarely *= 2 - assert always == 10 - assert sometimes == 125 - assert rarely == 3000 - - # ========================================================================= - # TEST 15: Chained mutable dependencies in same iteration - # ========================================================================= - chain_a: Mut = 1 - chain_b: Mut = 0 - chain_c: Mut = 0 - for i in range(0, 5): - chain_a *= 2 - chain_b = chain_a + i - chain_c += chain_b - # i=0: a=2, b=2+0=2, c=0+2=2 - # i=1: a=4, b=4+1=5, c=2+5=7 - # i=2: a=8, b=8+2=10, c=7+10=17 - # i=3: a=16, b=16+3=19, c=17+19=36 - # i=4: a=32, b=32+4=36, c=36+36=72 - assert chain_a == 32 - assert chain_b == 36 - assert chain_c == 72 - - # ========================================================================= - # TEST 16: Zero-iteration loop (edge case - empty range) - # No iterations should occur for 5..5 - # ========================================================================= - zero_iter: Mut = 999 - for i in range(5, 5): - zero_iter = 0 - assert zero_iter == 999 - - # ========================================================================= - # All tests passed! - # ========================================================================= - return - - -def compute_increment(n): - return n * n diff --git a/crates/lean_compiler/tests/test_data/program_143.py b/crates/lean_compiler/tests/test_data/program_143.py index 1e16c510..8a4913d2 100644 --- a/crates/lean_compiler/tests/test_data/program_143.py +++ b/crates/lean_compiler/tests/test_data/program_143.py @@ -193,16 +193,18 @@ def double(x): # Inline function: multiply by 3 @inline def triple(x): - y: Mut = x + y = x two: Imm match y - x + 1: case 0: assert False case 1: two = 2 + y_buf = Array(two + 1) + y_buf[0] = y for i in range(0, two): - y = y + x - return y + y_buf[i + 1] = y_buf[i] + x + return y_buf[two] # Inline function that calls another inline function diff --git a/crates/lean_compiler/tests/test_data/program_167.py b/crates/lean_compiler/tests/test_data/program_167.py index 17ac8968..2cc2d63d 100644 --- a/crates/lean_compiler/tests/test_data/program_167.py +++ b/crates/lean_compiler/tests/test_data/program_167.py @@ -5,8 +5,10 @@ def main(): x = (len(ARR) + ARR[2]) / ARR[3] - sum: Mut = 0 + sum_buf = Array(x + 1) + sum_buf[0] = 0 for i in range(0, x): - sum += 1 + sum_buf[i + 1] = sum_buf[i] + 1 + sum = sum_buf[x] assert sum == 2 return diff --git a/crates/lean_compiler/tests/test_data/program_171.py b/crates/lean_compiler/tests/test_data/program_171.py index 4af97391..84e142dd 100644 --- a/crates/lean_compiler/tests/test_data/program_171.py +++ b/crates/lean_compiler/tests/test_data/program_171.py @@ -11,30 +11,35 @@ @inline def count_up(n): """Count from 0 to n-1, return the sum""" - acc: Mut = 0 + acc_buf = Array(n + 1) + acc_buf[0] = 0 for i in range(0, n): - acc = acc + 1 - return acc + acc_buf[i + 1] = acc_buf[i] + 1 + return acc_buf[n] @inline def sum_range(start, end): """Sum integers from start to end-1""" - total: Mut = 0 + total_buf = Array(end - start + 1) + total_buf[0] = 0 for i in range(start, end): - total = total + i - return total + idx = i - start + total_buf[idx + 1] = total_buf[idx] + i + return total_buf[end - start] @inline def double_count(n): """Two mutable variables in same function""" - a: Mut = 0 - b: Mut = 100 + a_buf = Array(n + 1) + b_buf = Array(n + 1) + a_buf[0] = 0 + b_buf[0] = 100 for i in range(0, n): - a = a + 1 - b = b - 1 - return a + b + a_buf[i + 1] = a_buf[i] + 1 + b_buf[i + 1] = b_buf[i] - 1 + return a_buf[n] + b_buf[n] # ============================================================================ @@ -45,19 +50,21 @@ def double_count(n): @inline def inner_loop(k): """Inner inline function""" - x: Mut = 0 + x_buf = Array(k + 1) + x_buf[0] = 0 for j in range(0, k): - x = x + j - return x + x_buf[j + 1] = x_buf[j] + j + return x_buf[k] @inline def outer_with_inner(n): """Outer inline that calls inner inline""" - result: Mut = 0 + result_buf = Array(n + 1) + result_buf[0] = 0 for i in range(0, n): - result = result + inner_loop(i) - return result + result_buf[i + 1] = result_buf[i] + inner_loop(i) + return result_buf[n] @inline @@ -76,25 +83,29 @@ def deep_nested(a): @inline def complex_muts(n): """Multiple mutable variables with interdependencies""" - x: Mut = 0 - y: Mut = 1 - z: Mut = 2 + x_buf = Array(n + 1) + y_buf = Array(n + 1) + z_buf = Array(n + 1) + x_buf[0] = 0 + y_buf[0] = 1 + z_buf[0] = 2 for i in range(0, n): - temp = x + y - x = y - y = z - z = temp + z - return x + y + z + temp = x_buf[i] + y_buf[i] + x_buf[i + 1] = y_buf[i] + y_buf[i + 1] = z_buf[i] + z_buf[i + 1] = temp + z_buf[i] + return x_buf[n] + y_buf[n] + z_buf[n] @inline def with_immutable(n): """Mix of mutable and immutable inside inline""" - m: Mut = 0 + m_buf = Array(n + 1) + m_buf[0] = 0 for i in range(0, n): imm = i * 2 - m = m + imm - final_imm = m + 1000 + m_buf[i + 1] = m_buf[i] + imm + final_imm = m_buf[n] + 1000 return final_imm @@ -155,25 +166,30 @@ def inline_with_nested_branch(a, b): @inline def multi_return_inline(n): """Inline returning multiple values""" - a: Mut = 0 - b: Mut = 100 + a_buf = Array(n + 1) + b_buf = Array(n + 1) + a_buf[0] = 0 + b_buf[0] = 100 for i in range(0, n): - a = a + 1 - b = b + 2 - return a, b + a_buf[i + 1] = a_buf[i] + 1 + b_buf[i + 1] = b_buf[i] + 2 + return a_buf[n], b_buf[n] @inline def triple_return(x): """Inline returning three values with different computations""" - m1: Mut = x - m2: Mut = x * 2 - m3: Mut = x * 3 + m1_buf = Array(4) + m2_buf = Array(4) + m3_buf = Array(4) + m1_buf[0] = x + m2_buf[0] = x * 2 + m3_buf[0] = x * 3 for i in range(0, 3): - m1 = m1 + 1 - m2 = m2 + 2 - m3 = m3 + 3 - return m1, m2, m3 + m1_buf[i + 1] = m1_buf[i] + 1 + m2_buf[i + 1] = m2_buf[i] + 2 + m3_buf[i + 1] = m3_buf[i] + 3 + return m1_buf[3], m2_buf[3], m3_buf[3] # ============================================================================ @@ -184,40 +200,44 @@ def triple_return(x): @inline def level_d(x): """Deepest level""" - acc: Mut = x + acc_buf = Array(3) + acc_buf[0] = x for i in range(0, 2): - acc = acc + 1 - return acc + acc_buf[i + 1] = acc_buf[i] + 1 + return acc_buf[2] @inline def level_c(x): """Calls level_d""" tmp = level_d(x) - acc: Mut = tmp + acc_buf = Array(3) + acc_buf[0] = tmp for i in range(0, 2): - acc = acc + 10 - return acc + acc_buf[i + 1] = acc_buf[i] + 10 + return acc_buf[2] @inline def level_b(x): """Calls level_c""" tmp = level_c(x) - acc: Mut = tmp + acc_buf = Array(3) + acc_buf[0] = tmp for i in range(0, 2): - acc = acc + 100 - return acc + acc_buf[i + 1] = acc_buf[i] + 100 + return acc_buf[2] @inline def level_a(x): """Calls level_b - 4 levels deep""" tmp = level_b(x) - acc: Mut = tmp + acc_buf = Array(3) + acc_buf[0] = tmp for i in range(0, 2): - acc = acc + 1000 - return acc + acc_buf[i + 1] = acc_buf[i] + 1000 + return acc_buf[2] # ============================================================================ @@ -257,26 +277,29 @@ def inline_modify_array(base): @inline def chain_a(x): - m: Mut = x + m_buf = Array(3) + m_buf[0] = x for i in range(0, 2): - m = m + 1 - return m + m_buf[i + 1] = m_buf[i] + 1 + return m_buf[2] @inline def chain_b(x): - m: Mut = x + m_buf = Array(3) + m_buf[0] = x for i in range(0, 2): - m = m * 2 - return m + m_buf[i + 1] = m_buf[i] * 2 + return m_buf[2] @inline def chain_c(x): - m: Mut = x + m_buf = Array(3) + m_buf[0] = x for i in range(0, 2): - m = m + 10 - return m + m_buf[i + 1] = m_buf[i] + 10 + return m_buf[2] # ============================================================================ @@ -287,28 +310,38 @@ def chain_c(x): @inline def many_vars(seed): """Inline with 10 mutable variables""" - v0: Mut = seed - v1: Mut = seed + 1 - v2: Mut = seed + 2 - v3: Mut = seed + 3 - v4: Mut = seed + 4 - v5: Mut = seed + 5 - v6: Mut = seed + 6 - v7: Mut = seed + 7 - v8: Mut = seed + 8 - v9: Mut = seed + 9 + v0_buf = Array(4) + v1_buf = Array(4) + v2_buf = Array(4) + v3_buf = Array(4) + v4_buf = Array(4) + v5_buf = Array(4) + v6_buf = Array(4) + v7_buf = Array(4) + v8_buf = Array(4) + v9_buf = Array(4) + v0_buf[0] = seed + v1_buf[0] = seed + 1 + v2_buf[0] = seed + 2 + v3_buf[0] = seed + 3 + v4_buf[0] = seed + 4 + v5_buf[0] = seed + 5 + v6_buf[0] = seed + 6 + v7_buf[0] = seed + 7 + v8_buf[0] = seed + 8 + v9_buf[0] = seed + 9 for i in range(0, 3): - v0 = v0 + v1 - v1 = v1 + v2 - v2 = v2 + v3 - v3 = v3 + v4 - v4 = v4 + v5 - v5 = v5 + v6 - v6 = v6 + v7 - v7 = v7 + v8 - v8 = v8 + v9 - v9 = v9 + 1 - return v0 + v1 + v2 + v3 + v4 + v5 + v6 + v7 + v8 + v9 + v0_buf[i + 1] = v0_buf[i] + v1_buf[i] + v1_buf[i + 1] = v1_buf[i] + v2_buf[i] + v2_buf[i + 1] = v2_buf[i] + v3_buf[i] + v3_buf[i + 1] = v3_buf[i] + v4_buf[i] + v4_buf[i + 1] = v4_buf[i] + v5_buf[i] + v5_buf[i + 1] = v5_buf[i] + v6_buf[i] + v6_buf[i + 1] = v6_buf[i] + v7_buf[i] + v7_buf[i + 1] = v7_buf[i] + v8_buf[i] + v8_buf[i + 1] = v8_buf[i] + v9_buf[i] + v9_buf[i + 1] = v9_buf[i] + 1 + return v0_buf[3] + v1_buf[3] + v2_buf[3] + v3_buf[3] + v4_buf[3] + v5_buf[3] + v6_buf[3] + v7_buf[3] + v8_buf[3] + v9_buf[3] # ============================================================================ diff --git a/crates/lean_compiler/tests/test_data/program_183.py b/crates/lean_compiler/tests/test_data/program_183.py index 3cb3aea1..601b53ba 100644 --- a/crates/lean_compiler/tests/test_data/program_183.py +++ b/crates/lean_compiler/tests/test_data/program_183.py @@ -2,18 +2,22 @@ # declared in the (outer) loop scope and read after the branch. The `match_range` # expansion must reuse the outer cell, not shadow it — otherwise the read after the # branch sees uninitialized memory. `i` is a runtime loop variable, so neither the -# `if` nor the `match_range` folds at compile time. +# `if` nor the `match_range` folds at compile time. The accumulator is threaded +# through an explicit buffer (loop-carried mutables are unsupported). def sq(n: Const): return n * n def main(): - acc: Mut = 0 + acc_buf = Array(4) + acc_buf[0] = 0 for i in range(1, 4): + idx = i - 1 contrib: Imm if i != 0: contrib = match_range(i, range(1, 4), lambda k: sq(k)) else: contrib = 0 - acc = acc + contrib + acc_buf[idx + 1] = acc_buf[idx] + contrib + acc = acc_buf[3] assert acc == 14 # sq(1) + sq(2) + sq(3) = 1 + 4 + 9 return