@@ -89,6 +89,45 @@ pub fn batch_fold_multilinears<
8989 }
9090}
9191
92+ pub fn fold_multilinear_at_bit <
93+ EF : PrimeCharacteristicRing + Copy + Send + Sync ,
94+ IF : Copy + Sub < Output = IF > + Send + Sync ,
95+ OF : Copy + Add < IF , Output = OF > + Send + Sync ,
96+ Mul : Fn ( IF , EF ) -> OF + Sync + Send ,
97+ > (
98+ m : & [ IF ] ,
99+ alpha : EF ,
100+ bit : usize ,
101+ mul_if_of : & Mul ,
102+ ) -> Vec < OF > {
103+ let new_size = m. len ( ) / 2 ;
104+ assert ! ( m. len( ) >= 2 * ( 1 << bit) , "bit out of range for slice length" ) ;
105+ let stride = 1usize << bit;
106+ let lo_mask = stride - 1 ;
107+ let mut res = unsafe { uninitialized_vec ( new_size) } ;
108+
109+ let compute = |new_j : usize | {
110+ let i_hi = new_j >> bit;
111+ let i_lo = new_j & lo_mask;
112+ let i0 = ( i_hi << ( bit + 1 ) ) | i_lo;
113+ let i1 = i0 | stride;
114+ mul_if_of ( m[ i1] - m[ i0] , alpha) + m[ i0]
115+ } ;
116+
117+ if new_size < PARALLEL_THRESHOLD {
118+ for ( new_j, res_v) in res. iter_mut ( ) . enumerate ( ) {
119+ * res_v = compute ( new_j) ;
120+ }
121+ } else {
122+ ( 0 ..new_size)
123+ . into_par_iter ( )
124+ . with_min_len ( PARALLEL_THRESHOLD )
125+ . map ( compute)
126+ . collect_into_vec ( & mut res) ;
127+ }
128+ res
129+ }
130+
92131pub fn fold_multilinear <
93132 EF : PrimeCharacteristicRing + Copy + Send + Sync ,
94133 IF : Copy + Sub < Output = IF > + Send + Sync ,
@@ -116,6 +155,31 @@ pub fn fold_multilinear<
116155 res
117156}
118157
158+ pub fn batch_fold_multilinears_at_bit <
159+ EF : PrimeCharacteristicRing + Copy + Send + Sync ,
160+ IF : Copy + Sub < Output = IF > + Send + Sync ,
161+ OF : Copy + Add < IF , Output = OF > + Send + Sync ,
162+ F : Fn ( IF , EF ) -> OF + Sync + Send ,
163+ > (
164+ polys : & [ & [ IF ] ] ,
165+ alpha : EF ,
166+ bit : usize ,
167+ mul_if_of : F ,
168+ ) -> Vec < Vec < OF > > {
169+ let total_size: usize = polys. iter ( ) . map ( |p| p. len ( ) ) . sum ( ) ;
170+ if total_size < PARALLEL_THRESHOLD {
171+ polys
172+ . iter ( )
173+ . map ( |poly| fold_multilinear_at_bit ( poly, alpha, bit, & mul_if_of) )
174+ . collect ( )
175+ } else {
176+ polys
177+ . par_iter ( )
178+ . map ( |poly| fold_multilinear_at_bit ( poly, alpha, bit, & mul_if_of) )
179+ . collect ( )
180+ }
181+ }
182+
119183/// Returns a vector of uninitialized elements of type `A` with the specified length.
120184/// # Safety
121185/// Entries should be overwritten before use.
0 commit comments