Skip to content

Commit 43461d9

Browse files
committed
improve precompile arguments
1 parent e353417 commit 43461d9

14 files changed

Lines changed: 254 additions & 314 deletions

File tree

crates/lean_compiler/src/a_simplify_lang.rs

Lines changed: 37 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,10 @@ use crate::{
44
parser::{ConstArrayValue, parse_program},
55
};
66
use backend::PrimeCharacteristicRing;
7-
use lean_vm::{Boolean, BooleanExpr, CustomHint, EXT_OP_FUNCTIONS, FunctionName, SourceLocation, Table, TableT};
7+
use lean_vm::{
8+
Boolean, BooleanExpr, CustomHint, ExtensionOpMode, FunctionName, PrecompileArgs, PrecompileCompTimeArgs,
9+
SourceLocation, Table, TableT,
10+
};
811
use std::{
912
collections::{BTreeMap, BTreeSet},
1013
fmt::{Display, Formatter},
@@ -56,6 +59,8 @@ impl From<Var> for VarOrConstMallocAccess {
5659
}
5760
}
5861

62+
pub type SimplePrecompile = PrecompileArgs<SimpleExpr, ConstExpression>;
63+
5964
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
6065
pub enum SimpleLine {
6166
Match {
@@ -92,10 +97,7 @@ pub enum SimpleLine {
9297
FunctionRet {
9398
return_data: Vec<SimpleExpr>,
9499
},
95-
Precompile {
96-
table: Table,
97-
args: Vec<SimpleExpr>,
98-
},
100+
Precompile(SimplePrecompile),
99101
Panic {
100102
message: Option<String>,
101103
},
@@ -155,7 +157,7 @@ impl SimpleLine {
155157
| Self::RawAccess { .. }
156158
| Self::FunctionCall { .. }
157159
| Self::FunctionRet { .. }
158-
| Self::Precompile { .. }
160+
| Self::Precompile(..)
159161
| Self::Panic { .. }
160162
| Self::CustomHint(..)
161163
| Self::Print { .. }
@@ -180,7 +182,7 @@ impl SimpleLine {
180182
| Self::RawAccess { .. }
181183
| Self::FunctionCall { .. }
182184
| Self::FunctionRet { .. }
183-
| Self::Precompile { .. }
185+
| Self::Precompile(..)
184186
| Self::Panic { .. }
185187
| Self::CustomHint(..)
186188
| Self::Print { .. }
@@ -204,9 +206,8 @@ impl SimpleLine {
204206
Self::Match { value, .. } => vec![value],
205207
Self::IfNotZero { condition, .. } => vec![condition],
206208
Self::HintMAlloc { size, .. } => vec![size],
207-
Self::Precompile { args, .. } | Self::FunctionCall { args, .. } | Self::CustomHint(_, args) => {
208-
args.iter().collect()
209-
}
209+
Self::Precompile(precompile) => precompile.operand_exprs().to_vec(),
210+
Self::FunctionCall { args, .. } | Self::CustomHint(_, args) => args.iter().collect(),
210211
Self::FunctionRet { return_data } => return_data.iter().collect(),
211212
Self::Print { content, .. } => content.iter().collect(),
212213
Self::DebugAssert(boolean, _) => vec![&boolean.left, &boolean.right],
@@ -2117,11 +2118,7 @@ fn simplify_lines(
21172118

21182119
// Special handling for extension_op precompile
21192120
// Signature: func(ptr_a, ptr_b, ptr_res) or func(ptr_a, ptr_b, ptr_res, length)
2120-
if let Some(mode) = EXT_OP_FUNCTIONS
2121-
.iter()
2122-
.find(|(name, _)| *name == function_name.as_str())
2123-
.map(|(_, mode)| *mode)
2124-
{
2121+
if let Some(mode) = ExtensionOpMode::from_name(function_name) {
21252122
if !targets.is_empty() {
21262123
return Err(format!(
21272124
"Precompile {function_name} should not return values, at {location}"
@@ -2133,22 +2130,24 @@ fn simplify_lines(
21332130
args.len()
21342131
));
21352132
}
2136-
let mut simplified_args: Vec<SimpleExpr> = args[..3]
2133+
let simplified_args = args[..3]
21372134
.iter()
21382135
.map(|arg| simplify_expr(ctx, state, const_malloc, arg, &mut res))
21392136
.collect::<Result<Vec<_>, _>>()?;
2140-
// Inject size (aux_1) and mode (aux_2)
2141-
let size: SimpleExpr = if args.len() == 4 {
2137+
2138+
let size = if args.len() == 4 {
21422139
simplify_expr(ctx, state, const_malloc, &args[3], &mut res)?
2140+
.as_constant()
2141+
.expect("extension op size must be a constant")
21432142
} else {
2144-
SimpleExpr::one()
2143+
ConstExpression::one()
21452144
};
2146-
simplified_args.push(size);
2147-
simplified_args.push(SimpleExpr::Constant(mode.into()));
2148-
res.push(SimpleLine::Precompile {
2149-
table: Table::extension_op(),
2150-
args: simplified_args,
2151-
});
2145+
res.push(SimpleLine::Precompile(PrecompileArgs {
2146+
arg_0: simplified_args[0].clone(),
2147+
arg_1: simplified_args[1].clone(),
2148+
res: simplified_args[2].clone(),
2149+
data: PrecompileCompTimeArgs::ExtensionOp { size, mode },
2150+
}));
21522151
continue;
21532152
}
21542153

@@ -2159,14 +2158,22 @@ fn simplify_lines(
21592158
"Precompile {function_name} should not return values, at {location}"
21602159
));
21612160
}
2161+
if args.len() != 3 {
2162+
return Err(format!(
2163+
"Precompile {function_name} expects 3 arguments (ptr_a, ptr_b, ptr_res), got {}, at {location}",
2164+
args.len()
2165+
));
2166+
}
21622167
let simplified_args = args
21632168
.iter()
21642169
.map(|arg| simplify_expr(ctx, state, const_malloc, arg, &mut res))
21652170
.collect::<Result<Vec<_>, _>>()?;
2166-
res.push(SimpleLine::Precompile {
2167-
table: Table::poseidon16(),
2168-
args: simplified_args,
2169-
});
2171+
res.push(SimpleLine::Precompile(PrecompileArgs {
2172+
arg_0: simplified_args[0].clone(),
2173+
arg_1: simplified_args[1].clone(),
2174+
res: simplified_args[2].clone(),
2175+
data: PrecompileCompTimeArgs::Poseidon16,
2176+
}));
21702177
continue;
21712178
}
21722179

@@ -3962,16 +3969,7 @@ impl SimpleLine {
39623969
.join(", ");
39633970
format!("return {return_data_str}")
39643971
}
3965-
Self::Precompile {
3966-
table: precompile,
3967-
args,
3968-
} => {
3969-
format!(
3970-
"{}({})",
3971-
&precompile.name(),
3972-
args.iter().map(|arg| format!("{arg}")).collect::<Vec<_>>().join(", ")
3973-
)
3974-
}
3972+
Self::Precompile(precompile) => format!("{precompile}"),
39753973
Self::Print { line_info: _, content } => {
39763974
let content_str = content.iter().map(|c| format!("{c}")).collect::<Vec<_>>().join(", ");
39773975
format!("print({content_str})")

crates/lean_compiler/src/b_compile_intermediate.rs

Lines changed: 23 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -508,14 +508,9 @@ fn compile_lines(
508508
return Ok(instructions);
509509
}
510510

511-
SimpleLine::Precompile { table, args, .. } => {
512-
match table {
513-
Table::ExtensionOp(_) => assert_eq!(args.len(), 5),
514-
Table::Poseidon16(_) => assert_eq!(args.len(), 3),
515-
Table::Execution(_) => unreachable!(),
516-
}
517-
// if arg_c is constant, create a variable (in memory) to hold it
518-
let arg_c = if let SimpleExpr::Constant(cst) = &args[2] {
511+
SimpleLine::Precompile(precompile) => {
512+
// if res is constant, create a variable (in memory) to hold it
513+
let res = if let SimpleExpr::Constant(cst) = &precompile.res {
519514
instructions.push(IntermediateInstruction::Computation {
520515
operation: Operation::Add,
521516
arg_a: IntermediateValue::Constant(cst.clone()),
@@ -528,27 +523,25 @@ fn compile_lines(
528523
compiler.stack_pos += 1;
529524
IntermediateValue::MemoryAfterFp { offset: offset.into() }
530525
} else {
531-
try_precompile_fp_relative(&args[2], compiler)
532-
.unwrap_or_else(|| IntermediateValue::from_simple_expr(&args[2], compiler))
526+
try_precompile_fp_relative(&precompile.res, compiler)
527+
.unwrap_or_else(|| IntermediateValue::from_simple_expr(&precompile.res, compiler))
533528
};
534-
let (arg_a, arg_b) = match (
535-
try_precompile_fp_relative(&args[0], compiler),
536-
try_precompile_fp_relative(&args[1], compiler),
529+
let (left, right) = match (
530+
try_precompile_fp_relative(&precompile.arg_0, compiler),
531+
try_precompile_fp_relative(&precompile.arg_1, compiler),
537532
) {
538533
(Some(a), Some(b)) => (a, b),
539534
_ => (
540-
IntermediateValue::from_simple_expr(&args[0], compiler),
541-
IntermediateValue::from_simple_expr(&args[1], compiler),
535+
IntermediateValue::from_simple_expr(&precompile.arg_0, compiler),
536+
IntermediateValue::from_simple_expr(&precompile.arg_1, compiler),
542537
),
543538
};
544-
instructions.push(IntermediateInstruction::Precompile {
545-
table: *table,
546-
arg_a,
547-
arg_b,
548-
arg_c,
549-
aux_1: args.get(3).unwrap_or(&SimpleExpr::zero()).as_constant().unwrap(),
550-
aux_2: args.get(4).unwrap_or(&SimpleExpr::zero()).as_constant().unwrap(),
551-
});
539+
instructions.push(IntermediateInstruction::Precompile(PrecompileArgs {
540+
arg_0: left,
541+
arg_1: right,
542+
res,
543+
data: precompile.data.clone(),
544+
}));
552545
}
553546

554547
SimpleLine::FunctionRet { return_data } => {
@@ -886,20 +879,21 @@ fn collect_use_info(
886879
}
887880
}
888881

889-
if let SimpleLine::Precompile { args, .. } = line {
890-
// args[0] & args[1] count only if both are fp-rel-capable
891-
let both_capable = args[..2]
882+
if let SimpleLine::Precompile(precompile) = line {
883+
let exprs = precompile.operand_exprs();
884+
// exprs[0] & exprs[1] count only if both are fp-rel-capable
885+
let both_capable = exprs[..2]
892886
.iter()
893887
.all(|a| matches!(a, SimpleExpr::Memory(VarOrConstMallocAccess::Var(v)) if fp_rel_capable.contains(v)));
894888
if both_capable {
895-
for arg in &args[..2] {
889+
for arg in &exprs[..2] {
896890
if let SimpleExpr::Memory(VarOrConstMallocAccess::Var(v)) = arg {
897891
*fp_rel_uses.entry(v.clone()).or_default() += 1;
898892
}
899893
}
900894
}
901-
// args[2]: independently fp-rel-capable
902-
if let Some(SimpleExpr::Memory(VarOrConstMallocAccess::Var(v))) = args.get(2)
895+
// exprs[2]: independently fp-rel-capable
896+
if let SimpleExpr::Memory(VarOrConstMallocAccess::Var(v)) = exprs[2]
903897
&& fp_rel_capable.contains(v)
904898
{
905899
*fp_rel_uses.entry(v.clone()).or_default() += 1;

crates/lean_compiler/src/c_compile_final.rs

Lines changed: 12 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ impl IntermediateInstruction {
2121
| Self::Deref { .. }
2222
| Self::JumpIfNotZero { .. }
2323
| Self::Jump { .. }
24-
| Self::Precompile { .. } => false,
24+
| Self::Precompile(..) => false,
2525
}
2626
}
2727
}
@@ -275,22 +275,17 @@ fn compile_block(
275275
let one = ConstExpression::one().into();
276276
codegen_jump(hints, low_level_bytecode, one, dest, updated_fp)
277277
}
278-
IntermediateInstruction::Precompile {
279-
table,
280-
arg_a,
281-
arg_b,
282-
arg_c,
283-
aux_1,
284-
aux_2,
285-
} => {
286-
low_level_bytecode.push(Instruction::Precompile {
287-
table,
288-
arg_a: arg_a.try_into_mem_or_fp_or_constant(compiler).unwrap(),
289-
arg_b: arg_b.try_into_mem_or_fp_or_constant(compiler).unwrap(),
290-
arg_c: arg_c.try_into_mem_or_fp_or_constant(compiler).unwrap(),
291-
aux_1: eval_const_expression_usize(&aux_1, compiler),
292-
aux_2: eval_const_expression_usize(&aux_2, compiler),
293-
});
278+
IntermediateInstruction::Precompile(precompile) => {
279+
let data = precompile
280+
.data
281+
.map_size(|size| eval_const_expression_usize(&size, compiler));
282+
let args = PrecompileArgs {
283+
arg_0: precompile.arg_0.try_into_mem_or_fp_or_constant(compiler).unwrap(),
284+
arg_1: precompile.arg_1.try_into_mem_or_fp_or_constant(compiler).unwrap(),
285+
res: precompile.res.try_into_mem_or_fp_or_constant(compiler).unwrap(),
286+
data,
287+
};
288+
low_level_bytecode.push(Instruction::Precompile(args));
294289
}
295290
IntermediateInstruction::CustomHint(hint, args) => {
296291
let hint = Hint::Custom(

crates/lean_compiler/src/instruction_encoder.rs

Lines changed: 10 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -46,41 +46,27 @@ pub fn field_representation(instr: &Instruction) -> [F; N_INSTRUCTION_COLUMNS] {
4646
set_nu_b(&mut fields, dest);
4747
set_nu_c(&mut fields, updated_fp);
4848
}
49-
Instruction::Precompile {
50-
table,
51-
arg_a,
52-
arg_b,
53-
arg_c,
54-
aux_1,
55-
aux_2,
56-
} => {
57-
let precompile_data = match *table {
58-
Table::Poseidon16(_) => POSEIDON_PRECOMPILE_DATA,
59-
Table::ExtensionOp(_) => {
60-
let size = *aux_1;
61-
let mode = *aux_2;
62-
assert!(
63-
EXT_OP_FUNCTIONS.iter().any(|(_, m)| *m == mode),
64-
"invalid extension_op mode={mode}"
65-
);
66-
assert!(size >= 1, "invalid extension_op size={size}");
67-
mode + EXT_OP_LEN_MULTIPLIER * size
49+
Instruction::Precompile(precompile) => {
50+
let precompile_data = match &precompile.data {
51+
PrecompileCompTimeArgs::Poseidon16 => POSEIDON_PRECOMPILE_DATA,
52+
PrecompileCompTimeArgs::ExtensionOp { size, mode } => {
53+
assert!(*size >= 1, "invalid extension_op size={size}");
54+
mode.flag_encoding() + EXT_OP_LEN_MULTIPLIER * size
6855
}
69-
_ => unreachable!("unknown precompile table"),
7056
};
7157
fields[instr_idx(COL_PRECOMPILE_DATA)] = F::from_usize(precompile_data);
72-
match (arg_a, arg_b) {
58+
match (precompile.arg_0, precompile.arg_1) {
7359
(MemOrFpOrConstant::FpRelative { offset: off_a }, MemOrFpOrConstant::FpRelative { offset: off_b }) => {
7460
fields[instr_idx(COL_FLAG_AB_FP)] = F::ONE;
75-
fields[instr_idx(COL_OPERAND_A)] = F::from_usize(*off_a);
76-
fields[instr_idx(COL_OPERAND_B)] = F::from_usize(*off_b);
61+
fields[instr_idx(COL_OPERAND_A)] = F::from_usize(off_a);
62+
fields[instr_idx(COL_OPERAND_B)] = F::from_usize(off_b);
7763
}
7864
(a, b) => {
7965
set_nu_a(&mut fields, &a.as_mem_or_constant());
8066
set_nu_b(&mut fields, &b.as_mem_or_constant());
8167
}
8268
}
83-
set_nu_c(&mut fields, arg_c);
69+
set_nu_c(&mut fields, &precompile.res);
8470
}
8571
}
8672
fields

crates/lean_compiler/src/ir/instruction.rs

Lines changed: 3 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use super::value::IntermediateValue;
22
use crate::lang::{ConstExpression, MathOperation};
3-
use lean_vm::{BooleanExpr, CustomHint, Operation, SourceLocation, Table, TableT};
3+
use lean_vm::{BooleanExpr, CustomHint, Operation, PrecompileArgs, SourceLocation};
44
use std::fmt::{Display, Formatter};
55

66
/// Core instruction type for the intermediate representation.
@@ -27,14 +27,7 @@ pub enum IntermediateInstruction {
2727
dest: IntermediateValue,
2828
updated_fp: Option<IntermediateValue>,
2929
},
30-
Precompile {
31-
table: Table,
32-
arg_a: IntermediateValue,
33-
arg_b: IntermediateValue,
34-
arg_c: IntermediateValue,
35-
aux_1: ConstExpression,
36-
aux_2: ConstExpression,
37-
},
30+
Precompile(PrecompileArgs<IntermediateValue, ConstExpression>),
3831
// HINTS (does not appears in the final bytecode)
3932
Inverse {
4033
// If the value is zero, it will return zero.
@@ -151,16 +144,7 @@ impl Display for IntermediateInstruction {
151144
write!(f, "jump_if_not_zero {condition} to {dest}")
152145
}
153146
}
154-
Self::Precompile {
155-
table,
156-
arg_a,
157-
arg_b,
158-
arg_c,
159-
aux_1,
160-
aux_2,
161-
} => {
162-
write!(f, "{}({arg_a}, {arg_b}, {arg_c}, {aux_1}, {aux_2})", table.name())
163-
}
147+
Self::Precompile(precompile) => write!(f, "{precompile}"),
164148
Self::Inverse { arg, res_offset } => {
165149
write!(f, "m[fp + {res_offset}] = inverse({arg})")
166150
}

0 commit comments

Comments
 (0)