Skip to content

Commit ddf467b

Browse files
committed
add parser logic for if statement
1 parent 6bd62dc commit ddf467b

5 files changed

Lines changed: 512 additions & 6 deletions

File tree

src/ast.rs

Lines changed: 227 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,8 @@ pub enum SingleExpressionInner {
239239
Call(Call),
240240
/// Match expression.
241241
Match(Match),
242+
/// If expression.
243+
If(If),
242244
}
243245

244246
/// Call of a user-defined or of a builtin function.
@@ -403,6 +405,38 @@ impl MatchArm {
403405
}
404406
}
405407

408+
#[derive(Clone, Debug)]
409+
pub struct If {
410+
scrutinee: Arc<Expression>,
411+
then_arm: Arc<Expression>,
412+
else_arm: Arc<Expression>,
413+
span: Span,
414+
}
415+
416+
impl If {
417+
/// Access the expression who's output is deconstructed in the `if`.
418+
pub fn scrutinee(&self) -> &Expression {
419+
&self.scrutinee
420+
}
421+
422+
/// Access the branch that handles the `true` portion of the `if`.
423+
pub fn then_arm(&self) -> &Expression {
424+
&self.then_arm
425+
}
426+
427+
/// Access the branch that handles the `false` or `else` portion of the `if`.
428+
pub fn else_arm(&self) -> &Expression {
429+
&self.else_arm
430+
}
431+
432+
/// Access the span of the if statement.
433+
pub fn span(&self) -> &Span {
434+
&self.span
435+
}
436+
}
437+
438+
impl_eq_hash!(If; scrutinee, then_arm, else_arm);
439+
406440
/// Item when analyzing modules.
407441
#[derive(Clone, Debug, Eq, PartialEq, Hash)]
408442
pub enum ModuleItem {
@@ -462,6 +496,7 @@ pub enum ExprTree<'a> {
462496
Single(&'a SingleExpression),
463497
Call(&'a Call),
464498
Match(&'a Match),
499+
If(&'a If),
465500
}
466501

467502
impl TreeLike for ExprTree<'_> {
@@ -502,13 +537,19 @@ impl TreeLike for ExprTree<'_> {
502537
}
503538
S::Call(call) => Tree::Unary(Self::Call(call)),
504539
S::Match(match_) => Tree::Unary(Self::Match(match_)),
540+
S::If(if_) => Tree::Unary(Self::If(if_)),
505541
},
506542
Self::Call(call) => Tree::Nary(call.args().iter().map(Self::Expression).collect()),
507543
Self::Match(match_) => Tree::Nary(Arc::new([
508544
Self::Expression(match_.scrutinee()),
509545
Self::Expression(match_.left().expression()),
510546
Self::Expression(match_.right().expression()),
511547
])),
548+
Self::If(if_) => Tree::Nary(Arc::new([
549+
Self::Expression(if_.scrutinee()),
550+
Self::Expression(if_.then_arm()),
551+
Self::Expression(if_.else_arm()),
552+
])),
512553
}
513554
}
514555
}
@@ -1059,6 +1100,9 @@ impl AbstractSyntaxTree for SingleExpression {
10591100
parse::SingleExpressionInner::Match(match_) => {
10601101
Match::analyze(match_, ty, scope).map(SingleExpressionInner::Match)?
10611102
}
1103+
parse::SingleExpressionInner::If(if_) => {
1104+
If::analyze(if_, ty, scope).map(SingleExpressionInner::If)?
1105+
}
10621106
};
10631107

10641108
Ok(Self {
@@ -1426,6 +1470,28 @@ impl AbstractSyntaxTree for Match {
14261470
}
14271471
}
14281472

1473+
impl AbstractSyntaxTree for If {
1474+
type From = parse::If;
1475+
1476+
fn analyze(from: &Self::From, ty: &ResolvedType, scope: &mut Scope) -> Result<Self, RichError> {
1477+
let scrutinee =
1478+
Expression::analyze(from.scrutinee(), &ResolvedType::boolean(), scope).map(Arc::new)?;
1479+
scope.push_scope();
1480+
let ast_then = Expression::analyze(from.then_arm(), ty, scope).map(Arc::new)?;
1481+
scope.pop_scope();
1482+
scope.push_scope();
1483+
let ast_else = Expression::analyze(from.else_arm(), ty, scope).map(Arc::new)?;
1484+
scope.pop_scope();
1485+
1486+
Ok(Self {
1487+
scrutinee,
1488+
then_arm: ast_then,
1489+
else_arm: ast_else,
1490+
span: *from.as_ref(),
1491+
})
1492+
}
1493+
}
1494+
14291495
fn analyze_named_module(
14301496
name: ModuleName,
14311497
from: &parse::ModuleProgram,
@@ -1559,6 +1625,12 @@ impl AsRef<Span> for Match {
15591625
}
15601626
}
15611627

1628+
impl AsRef<Span> for If {
1629+
fn as_ref(&self) -> &Span {
1630+
&self.span
1631+
}
1632+
}
1633+
15621634
impl AsRef<Span> for Module {
15631635
fn as_ref(&self) -> &Span {
15641636
&self.span
@@ -1570,3 +1642,158 @@ impl AsRef<Span> for ModuleAssignment {
15701642
&self.span
15711643
}
15721644
}
1645+
1646+
#[cfg(test)]
1647+
mod test {
1648+
use super::*;
1649+
use crate::parse::{self, ParseFromStr};
1650+
use crate::types::UIntType;
1651+
1652+
/// Helper to check if an expression is a constant, unwrapping blocks if needed
1653+
fn is_constant_expr(expr: &Expression) -> bool {
1654+
match expr.inner() {
1655+
ExpressionInner::Single(single) => {
1656+
matches!(single.inner(), SingleExpressionInner::Constant(_))
1657+
}
1658+
ExpressionInner::Block(_, Some(inner_expr)) => is_constant_expr(inner_expr),
1659+
_ => false,
1660+
}
1661+
}
1662+
1663+
/// Helper to check if an expression is a block with statements
1664+
fn is_block_with_statements(expr: &Expression) -> bool {
1665+
matches!(expr.inner(), ExpressionInner::Block(stmts, Some(_)) if !stmts.is_empty())
1666+
}
1667+
1668+
fn parse_if(input: &str) -> parse::If {
1669+
// Parse the if expression
1670+
let parsed_expr = parse::Expression::parse_from_str(input).expect("Failed to parse");
1671+
1672+
// Extract the parsed If from the expression
1673+
let parsed_if = match parsed_expr.inner() {
1674+
parse::ExpressionInner::Single(single) => match single.inner() {
1675+
parse::SingleExpressionInner::If(if_) => if_.clone(),
1676+
_ => panic!("Expected If expression"),
1677+
},
1678+
_ => panic!("Expected Single expression"),
1679+
};
1680+
parsed_if
1681+
}
1682+
1683+
#[test]
1684+
fn test_if_expression_analyze() {
1685+
let input = "if true { 0 } else { 1 }";
1686+
1687+
let parsed_if = &parse_if(input);
1688+
1689+
// Analyze the if expression with u8 as the expected type
1690+
let expected_type = ResolvedType::from(UIntType::U8);
1691+
let mut scope = Scope::default();
1692+
let ast_if = If::analyze(parsed_if, &expected_type, &mut scope)
1693+
.expect("Failed to analyze If expression");
1694+
1695+
// Verify the structure
1696+
assert_eq!(
1697+
ast_if.scrutinee().ty(),
1698+
&ResolvedType::boolean(),
1699+
"Scrutinee should be boolean type"
1700+
);
1701+
assert_eq!(
1702+
ast_if.then_arm().ty(),
1703+
&expected_type,
1704+
"Then arm should have u8 type"
1705+
);
1706+
assert_eq!(
1707+
ast_if.else_arm().ty(),
1708+
&expected_type,
1709+
"Else arm should have u8 type"
1710+
);
1711+
1712+
// Verify scrutinee is a boolean constant
1713+
match ast_if.scrutinee().inner() {
1714+
ExpressionInner::Single(single) => match single.inner() {
1715+
SingleExpressionInner::Constant(_) => {
1716+
// Boolean constant verified
1717+
}
1718+
_ => panic!("Expected boolean constant for scrutinee"),
1719+
},
1720+
_ => panic!("Expected single expression for scrutinee"),
1721+
}
1722+
1723+
// Verify both arms are constants (may be wrapped in blocks)
1724+
assert!(
1725+
is_constant_expr(ast_if.then_arm()),
1726+
"Then arm should be a constant"
1727+
);
1728+
assert!(
1729+
is_constant_expr(ast_if.else_arm()),
1730+
"Else arm should be a constant"
1731+
);
1732+
}
1733+
1734+
#[test]
1735+
fn test_if_expression_with_complex_arms() {
1736+
let input = "if false { let x: u8 = 5; x } else { 10 }";
1737+
1738+
let parsed_expr = parse::Expression::parse_from_str(input).expect("Failed to parse");
1739+
let expected_type = ResolvedType::from(UIntType::U8);
1740+
1741+
// Analyze the entire expression (which will handle the if internally)
1742+
let ast_expr = Expression::analyze_const(&parsed_expr, &expected_type)
1743+
.expect("Failed to analyze expression");
1744+
1745+
// Verify the expression is an If
1746+
match ast_expr.inner() {
1747+
ExpressionInner::Single(single) => match single.inner() {
1748+
SingleExpressionInner::If(ast_if) => {
1749+
assert_eq!(ast_if.scrutinee().ty(), &ResolvedType::boolean());
1750+
assert_eq!(ast_if.then_arm().ty(), &expected_type);
1751+
assert_eq!(ast_if.else_arm().ty(), &expected_type);
1752+
1753+
// Verify then arm is a block with statements and else arm is a constant
1754+
assert!(
1755+
is_block_with_statements(ast_if.then_arm()),
1756+
"Then arm should be a block with statements"
1757+
);
1758+
assert!(
1759+
is_constant_expr(ast_if.else_arm()),
1760+
"Else arm should be a constant"
1761+
);
1762+
}
1763+
_ => panic!("Expected If expression"),
1764+
},
1765+
_ => panic!("Expected Single expression"),
1766+
}
1767+
}
1768+
1769+
#[test]
1770+
fn test_if_valid_parse_but_invalid_ast() {
1771+
let input = "if false { let x: u8 = 5; } else { 10 }";
1772+
1773+
let parsed_if = &parse_if(input);
1774+
let expected_type = ResolvedType::from(UIntType::U8);
1775+
let mut scope = Scope::default();
1776+
let ast_if_result = If::analyze(parsed_if, &expected_type, &mut scope);
1777+
1778+
assert!(ast_if_result
1779+
.err()
1780+
.map(|e| matches!(e.error(), Error::ExpressionTypeMismatch(..)))
1781+
.unwrap());
1782+
}
1783+
1784+
#[test]
1785+
fn test_if_valid_parse_but_invalid_scrutinee() {
1786+
let input = "if (()) { 1 } else { 10 }";
1787+
1788+
let parsed_if = &parse_if(input);
1789+
let expected_type = ResolvedType::from(UIntType::U8);
1790+
let mut scope = Scope::default();
1791+
let ast_if_result = If::analyze(parsed_if, &expected_type, &mut scope);
1792+
1793+
// Expected type of scrutinee is `bool`
1794+
assert!(ast_if_result
1795+
.err()
1796+
.map(|e| matches!(e.error(), Error::ExpressionUnexpectedType(..)))
1797+
.unwrap());
1798+
}
1799+
}

src/compile/mod.rs

Lines changed: 76 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ use simplicity::{types, Cmr, FailEntropy};
1212
use self::builtins::array_fold;
1313
use crate::array::{BTreeSlice, Partition};
1414
use crate::ast::{
15-
Call, CallName, Expression, ExpressionInner, Match, Program, SingleExpression,
15+
Call, CallName, Expression, ExpressionInner, If, Match, Program, SingleExpression,
1616
SingleExpressionInner, Statement,
1717
};
1818
use crate::debug::CallTracker;
@@ -355,6 +355,7 @@ impl SingleExpression {
355355
}
356356
SingleExpressionInner::Call(call) => call.compile(scope)?,
357357
SingleExpressionInner::Match(match_) => match_.compile(scope)?,
358+
SingleExpressionInner::If(if_) => if_.compile(scope)?,
358359
};
359360

360361
scope
@@ -680,3 +681,77 @@ impl Match {
680681
input.comp(&output).with_span(self)
681682
}
682683
}
684+
685+
impl If {
686+
fn compile<'brand>(
687+
&self,
688+
scope: &mut Scope<'brand>,
689+
) -> Result<PairBuilder<ProgNode<'brand>>, RichError> {
690+
scope.push_scope();
691+
scope.insert(Pattern::Ignore);
692+
let then_arm = self.then_arm().compile(scope)?;
693+
scope.pop_scope();
694+
scope.push_scope();
695+
scope.insert(Pattern::Ignore);
696+
let else_arm = self.else_arm().compile(scope)?;
697+
scope.pop_scope();
698+
699+
let scrutinee = self.scrutinee().compile(scope)?;
700+
let input = scrutinee.pair(PairBuilder::iden(scope.ctx()));
701+
// Left = false, right = true
702+
let output = ProgNode::case(else_arm.as_ref(), then_arm.as_ref()).with_span(self)?;
703+
input.comp(&output).with_span(self)
704+
}
705+
}
706+
707+
#[cfg(test)]
708+
mod tests {
709+
use std::sync::Arc;
710+
711+
use super::*;
712+
use crate::parse::ParseFromStr;
713+
use crate::witness::Arguments;
714+
use crate::{ast, parse};
715+
716+
fn compile_program(
717+
input: &str,
718+
) -> Result<Arc<named::CommitNode<Elements>>, crate::error::RichError> {
719+
let parse_program = parse::Program::parse_from_str(input).expect("Failed to parse");
720+
let ast_program = ast::Program::analyze(&parse_program).expect("Failed to analyze");
721+
ast_program.compile(Arguments::default(), false)
722+
}
723+
724+
#[test]
725+
fn match_equivalent_to_if_compiles() {
726+
// The same logic expressed using `match`, which is known to compile correctly.
727+
// Used as a baseline to confirm the test infrastructure works.
728+
let input_match = r#"fn main() {
729+
let x: u16 = 2;
730+
let _s: (bool, u16) = match true {
731+
true => jet::add_16(x, 2),
732+
false => jet::add_16(x, 3),
733+
};
734+
}"#;
735+
let match_node = compile_program(input_match).expect("Match expression should compile");
736+
// Verifies that an if expression with non-unit arms compiles correctly.
737+
//
738+
// This works because in Simplicity types are binary tries: `1 × A = A`
739+
// definitionally (unit contributes zero bits), so the bool scrutinee
740+
// `Either<1, 1>` pairs correctly with the `case` combinator's type
741+
// `(1+1) × Input`.
742+
let input_if = r#"fn main() {
743+
let x: u16 = 2;
744+
let _u: (bool, u16) = if true {
745+
jet::add_16(x, 2)
746+
} else {
747+
jet::add_16(x, 3)
748+
};
749+
}"#;
750+
let if_node = compile_program(input_if).expect("If expression should compile");
751+
752+
assert_eq!(
753+
match_node.display_expr().to_string(),
754+
if_node.display_expr().to_string()
755+
);
756+
}
757+
}

0 commit comments

Comments
 (0)