Skip to content

Commit cdf7f9f

Browse files
committed
small recursion program speedup
1 parent 32688b2 commit cdf7f9f

2 files changed

Lines changed: 19 additions & 10 deletions

File tree

crates/rec_aggregation/utils.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -81,20 +81,32 @@ def poly_eq_extension(point, n: Const):
8181
return res + (2**n - 1) * DIM
8282

8383

84-
def eq_mle_extension(a, b, n):
84+
@inline
85+
def eq_mle_extension_to(a, b, dst, n):
8586
debug_assert(n < 33)
8687
debug_assert(0 < n)
88+
match_range(n, range(1, 33), lambda i: poly_eq_ee(a, b, dst, i))
89+
return
90+
91+
92+
def eq_mle_extension(a, b, n):
8793
res = Array(DIM)
88-
match_range(n, range(1, 33), lambda i: poly_eq_ee(a, b, res, i))
94+
eq_mle_extension_to(a, b, res, n)
8995
return res
9096

9197

9298
@inline
93-
def eq_mle_base_extension(a, b, n):
99+
def eq_mle_base_extension_to(a, b, dst, n):
94100
debug_assert(n < 33)
95101
debug_assert(0 < n)
102+
match_range(n, range(1, 33), lambda i: poly_eq_be(a, b, dst, i))
103+
return
104+
105+
106+
@inline
107+
def eq_mle_base_extension(a, b, n):
96108
res = Array(DIM)
97-
match_range(n, range(1, 33), lambda i: poly_eq_be(a, b, res, i))
109+
eq_mle_base_extension_to(a, b, res, n)
98110
return res
99111

100112

crates/rec_aggregation/whir.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -113,8 +113,7 @@ def whir_open(
113113
all_ood_recovered_evals = Array(num_oods[0] * DIM)
114114
for i in range(0, num_oods[0]):
115115
expanded_from_univariate = expand_from_univariate_ext(ood_points_commit + i * DIM, n_vars)
116-
ood_rec = eq_mle_extension(expanded_from_univariate, folding_randomness_global, n_vars)
117-
copy_5(ood_rec, all_ood_recovered_evals + i * DIM)
116+
eq_mle_extension_to(expanded_from_univariate, folding_randomness_global, all_ood_recovered_evals + i * DIM, n_vars)
118117
s: Mut = Array(DIM)
119118
dot_product_ee_dynamic(
120119
all_ood_recovered_evals,
@@ -132,8 +131,7 @@ def whir_open(
132131
my_folding_randomness += folding_factors[i] * DIM
133132
for j in range(0, num_oods[i + 1]):
134133
expanded_from_univariate = expand_from_univariate_ext(all_ood_points[i] + j * DIM, n_vars_remaining)
135-
ood_rec = eq_mle_extension(expanded_from_univariate, my_folding_randomness, n_vars_remaining)
136-
copy_5(ood_rec, my_ood_recovered_evals + j * DIM)
134+
eq_mle_extension_to(expanded_from_univariate, my_folding_randomness, my_ood_recovered_evals + j * DIM, n_vars_remaining)
137135
summed_ood = Array(DIM)
138136
dot_product_ee_dynamic(
139137
my_ood_recovered_evals,
@@ -146,8 +144,7 @@ def whir_open(
146144
circle_value_i = all_circle_values[i]
147145
for j in range(0, num_queries[i]): # unroll ?
148146
expanded_from_univariate = expand_from_univariate_base(circle_value_i[j], n_vars_remaining)
149-
temp = eq_mle_base_extension(expanded_from_univariate, my_folding_randomness, n_vars_remaining)
150-
copy_5(temp, s6s + j * DIM)
147+
eq_mle_base_extension_to(expanded_from_univariate, my_folding_randomness, s6s + j * DIM, n_vars_remaining)
151148
s7 = Array(DIM)
152149
dot_product_ee_dynamic(
153150
s6s,

0 commit comments

Comments
 (0)