Skip to content

Commit c15af5d

Browse files
committed
more compiler tests
1 parent 85a02fc commit c15af5d

7 files changed

Lines changed: 324 additions & 0 deletions

File tree

crates/lean_compiler/tests/test_compiler.rs

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -274,3 +274,37 @@ def main():
274274
"#;
275275
compile_and_run(&ProgramSource::Raw(program.to_string()), &[], false);
276276
}
277+
278+
#[test]
279+
#[rustfmt::skip]
280+
fn test_soundness_suite() {
281+
#[allow(clippy::type_complexity)]
282+
let cases: &[(&str, &[u32], &[(usize, u32)])] = &[
283+
("soundness_0", &[3, 6, 7, 10, 9, 20, 26, 1], &[(0, 4), (1, 7), (2, 8), (3, 11), (4, 10), (5, 21), (6, 27), (7, 0), (7, 2)]),
284+
("soundness_1", &[5, 10, 6, 7, 42, 9, 5, 4], &[(0, 6), (1, 11), (2, 7), (3, 8), (4, 43), (5, 10), (6, 6), (7, 5)]),
285+
("soundness_2", &[3, 4, 5, 29, 7, 1, 17, 46], &[(0, 2), (1, 5), (2, 6), (3, 30), (4, 8), (5, 0), (5, 2), (6, 18), (7, 47)]),
286+
("soundness_3", &[4, 2, 14, 120, 5, 10, 50, 55], &[(0, 5), (1, 3), (2, 15), (3, 121), (4, 6), (5, 11), (6, 51), (7, 56)]),
287+
("soundness_4", &[5, 10, 10, 3, 4, 19, 20, 1], &[(0, 6), (1, 11), (2, 11), (3, 4), (4, 5), (5, 20), (6, 50), (7, 0), (7, 2)]),
288+
("soundness_5", &[3, 4, 7, 19, 49, 28, 1, 3], &[(0, 4), (1, 5), (2, 8), (3, 20), (4, 50), (5, 29), (6, 0), (6, 2), (7, 4)]),
289+
];
290+
291+
let to_input = |v: &[u32]| v.iter().copied().map(F::new).collect::<Vec<_>>();
292+
293+
for &(name, valid, perturbations) in cases {
294+
let path = format!("{}/{}.py", test_data_dir(), name);
295+
let bytecode = compile_program(&ProgramSource::Filepath(path));
296+
297+
try_execute_bytecode(&bytecode, &to_input(valid), &ExecutionWitness::default(), false)
298+
.unwrap_or_else(|err| panic!("{name}: valid input {valid:?} must succeed, got {err:?}"));
299+
300+
for &(idx, bad_value) in perturbations {
301+
let mut input = valid.to_vec();
302+
input[idx] = bad_value;
303+
let res = try_execute_bytecode(&bytecode, &to_input(&input), &ExecutionWitness::default(), false);
304+
assert!(
305+
res.is_err(),
306+
"{name}: perturbation p[{idx}]={bad_value} (input {input:?}) unexpectedly succeeded",
307+
);
308+
}
309+
}
310+
}
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
from snark_lib import *
2+
3+
4+
def main():
5+
p = 0
6+
a = p[0]
7+
b = p[1]
8+
c = p[2]
9+
d = p[3]
10+
e = p[4]
11+
f = p[5]
12+
g = p[6]
13+
h = p[7]
14+
15+
assert double(a) == b
16+
assert square_plus_one(a) == d
17+
assert a + c == 10
18+
assert e < 10
19+
assert f <= 20
20+
21+
acc: Mut = 0
22+
for i in unroll(0, 4):
23+
acc = acc + p[i]
24+
assert acc == g
25+
26+
if h == 1:
27+
assert a + b == 9
28+
else:
29+
assert a == 0
30+
return
31+
32+
33+
@inline
34+
def double(x):
35+
return x + x
36+
37+
38+
def square_plus_one(x):
39+
return x * x + 1
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
from snark_lib import *
2+
3+
4+
def main():
5+
p = 0
6+
n = p[0]
7+
sum_range = p[1]
8+
x = p[2]
9+
y = p[3]
10+
prod_xy = p[4]
11+
outer = p[5]
12+
inner_bound = p[6]
13+
v = p[7]
14+
15+
assert n == 5
16+
17+
s: Mut = 0
18+
for i in range(0, 5):
19+
s = s + i
20+
assert s == sum_range
21+
22+
assert mul(x, y) == prod_xy
23+
24+
nested: Mut = 0
25+
for i in unroll(0, 3):
26+
for j in unroll(0, 3):
27+
nested = nested + i * j
28+
assert nested == outer
29+
30+
assert v < inner_bound
31+
assert inner_bound == v + 1
32+
return
33+
34+
35+
def mul(a, b):
36+
return a * b
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
from snark_lib import *
2+
3+
4+
def main():
5+
p = 0
6+
mode = p[0]
7+
x = p[1]
8+
y = p[2]
9+
expected = p[3]
10+
secondary = p[4]
11+
flag = p[5]
12+
offset = p[6]
13+
total = p[7]
14+
15+
computed: Imu
16+
match mode:
17+
case 0:
18+
computed = add_op(x, y)
19+
case 1:
20+
computed = sub_op(x, y)
21+
case 2:
22+
computed = mul_op(x, y)
23+
case 3:
24+
computed = combined(x, y)
25+
assert computed == expected
26+
27+
adjusted: Imu
28+
if flag == 0:
29+
adjusted = bump(secondary, 1)
30+
elif flag == 1:
31+
adjusted = bump(secondary, 10)
32+
else:
33+
adjusted = bump(secondary, 100)
34+
assert adjusted == offset
35+
36+
assert total == expected + offset
37+
return
38+
39+
40+
def add_op(a, b):
41+
return a + b
42+
43+
44+
def sub_op(a, b):
45+
return a - b
46+
47+
48+
def mul_op(a, b):
49+
return a * b
50+
51+
52+
def combined(a, b):
53+
return mul_op(a, b) + add_op(a, b)
54+
55+
56+
@inline
57+
def bump(v, k):
58+
return v + k
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
from snark_lib import *
2+
3+
4+
def main():
5+
p = 0
6+
n = p[0]
7+
seed = p[1]
8+
sum_expected = p[2]
9+
prod_expected = p[3]
10+
max_val = p[4]
11+
upper = p[5]
12+
w = p[6]
13+
expected_final = p[7]
14+
15+
assert n == 4
16+
17+
arr = Array(4)
18+
for i in unroll(0, 4):
19+
arr[i] = seed + i
20+
21+
s: Mut = 0
22+
for i in range(0, 4):
23+
s = s + arr[i]
24+
assert s == sum_expected
25+
26+
prod: Mut = 1
27+
for i in unroll(0, 4):
28+
prod = times(prod, arr[i])
29+
assert prod == prod_expected
30+
31+
assert max_val < upper
32+
assert upper <= 100
33+
assert upper == max_val + 5
34+
assert w + max_val == expected_final
35+
return
36+
37+
38+
@inline
39+
def times(a, b):
40+
return a * b
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
from snark_lib import *
2+
3+
4+
def main():
5+
p = 0
6+
n = p[0]
7+
expected_sum_pos = p[1]
8+
expected_sum_neg = p[2]
9+
x = p[3]
10+
y = p[4]
11+
expected_pipeline = p[5]
12+
threshold = p[6]
13+
threshold_check = p[7]
14+
15+
assert n == 5
16+
17+
markers = Array(5)
18+
for i in unroll(0, 5):
19+
markers[i] = i
20+
21+
sum_pos: Mut = 0
22+
sum_neg: Mut = 0
23+
for i in range(0, 5):
24+
m = markers[i]
25+
if m == 0:
26+
sum_neg = sum_neg + 10
27+
else:
28+
sum_pos = sum_pos + m
29+
assert sum_pos == expected_sum_pos
30+
assert sum_neg == expected_sum_neg
31+
32+
assert pipeline(x, y) == expected_pipeline
33+
34+
if threshold_check == 1:
35+
assert threshold < 50
36+
else:
37+
assert threshold == 0
38+
39+
assert threshold_check * (1 - threshold_check) == 0
40+
return
41+
42+
43+
@inline
44+
def pipeline(a, b):
45+
return wrapper(a, b) + a
46+
47+
48+
def wrapper(a, b):
49+
return inner(a, b) + b
50+
51+
52+
@inline
53+
def inner(a, b):
54+
return a * b
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
from snark_lib import *
2+
3+
4+
def main():
5+
p = 0
6+
seed = p[0]
7+
n = p[1]
8+
last_write = p[2]
9+
match_tally = p[3]
10+
pipeline_squared = p[4]
11+
paired = p[5]
12+
flag = p[6]
13+
alt = p[7]
14+
15+
assert n == 4
16+
17+
counter: Mut = 0
18+
for i in range(0, 4):
19+
counter = 2 * i + 1
20+
assert counter == last_write
21+
22+
acc: Mut = seed
23+
for i in range(0, 4):
24+
match i:
25+
case 0:
26+
acc = acc + 1
27+
case 1:
28+
acc = acc + 3
29+
case 2:
30+
acc = acc + 5
31+
case 3:
32+
acc = acc + 7
33+
assert acc == match_tally
34+
35+
assert sqr_via_pipeline(seed + n) == pipeline_squared
36+
37+
assert paired_sum(seed, n) == paired
38+
39+
chosen: Imu
40+
if flag == 1:
41+
chosen = seed
42+
else:
43+
chosen = seed * 2
44+
assert chosen == alt
45+
46+
assert flag * (1 - flag) == 0
47+
return
48+
49+
50+
@inline
51+
def sqr_via_pipeline(x):
52+
return mul_boxed(x, x)
53+
54+
55+
def mul_boxed(a, b):
56+
return a * b
57+
58+
59+
def paired_sum(a, b):
60+
total: Mut = 0
61+
for i in range(0, 4):
62+
total = total + a + b
63+
return total

0 commit comments

Comments
 (0)