Skip to content

Commit b3d1518

Browse files
committed
Simplify Expressions Using Derivation Nodes
This allows simplification steps to be expanded by keeping track of their origin nodes, for both constant propagation and constant folding
1 parent 0350e9a commit b3d1518

15 files changed

Lines changed: 507 additions & 490 deletions

liquidjava-verifier/src/main/java/liquidjava/errors/ErrorHandler.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ public static <T> void printError(CtElement var, String moreInfo, Predicate expe
4141
// all message
4242
sb.append(sbtitle.toString() + "\n\n");
4343
sb.append("Type expected:" + expectedType.toString() + "\n");
44-
sb.append("Refinement found:\n" + cSMT.simplify() + "\n");
44+
sb.append("Refinement found:\n" + cSMT.simplify().getValue() + "\n");
4545
sb.append(printMap(map));
4646
sb.append("Location: " + var.getPosition() + "\n");
4747
sb.append("______________________________________________________\n");

liquidjava-verifier/src/main/java/liquidjava/rj_language/Predicate.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@
2121
import liquidjava.rj_language.ast.LiteralReal;
2222
import liquidjava.rj_language.ast.UnaryExpression;
2323
import liquidjava.rj_language.ast.Var;
24-
import liquidjava.rj_language.opt.DerivationNode;
24+
import liquidjava.rj_language.opt.derivation_node.DerivationNode;
25+
import liquidjava.rj_language.opt.derivation_node.ValDerivationNode;
2526
import liquidjava.rj_language.opt.ExpressionSimplifier;
2627
import liquidjava.rj_language.parsing.ParsingException;
2728
import liquidjava.rj_language.parsing.RefinementsParser;
@@ -214,7 +215,7 @@ public Expression getExpression() {
214215
return exp;
215216
}
216217

217-
public DerivationNode simplify() {
218+
public ValDerivationNode simplify() {
218219
return ExpressionSimplifier.simplify(exp.clone());
219220
}
220221

liquidjava-verifier/src/main/java/liquidjava/rj_language/ast/Expression.java

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
package liquidjava.rj_language.ast;
22

3-
import com.microsoft.z3.Expr;
43
import java.util.ArrayList;
54
import java.util.List;
65
import java.util.Map;
6+
7+
import com.microsoft.z3.Expr;
8+
79
import liquidjava.processor.context.Context;
810
import liquidjava.processor.facade.AliasDTO;
911
import liquidjava.rj_language.ast.typing.TypeInfer;
@@ -47,6 +49,10 @@ public void setChild(int index, Expression element) {
4749
children.set(index, element);
4850
}
4951

52+
public boolean isLiteral() {
53+
return this instanceof LiteralInt || this instanceof LiteralReal || this instanceof LiteralBoolean;
54+
}
55+
5056
/**
5157
* Substitutes the expression first given expression by the second
5258
*

liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/ConstantFolding.java

Lines changed: 139 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -2,135 +2,186 @@
22

33
import liquidjava.rj_language.ast.BinaryExpression;
44
import liquidjava.rj_language.ast.Expression;
5+
import liquidjava.rj_language.ast.GroupExpression;
56
import liquidjava.rj_language.ast.LiteralBoolean;
67
import liquidjava.rj_language.ast.LiteralInt;
78
import liquidjava.rj_language.ast.LiteralReal;
89
import liquidjava.rj_language.ast.UnaryExpression;
10+
import liquidjava.rj_language.opt.derivation_node.BinaryDerivationNode;
11+
import liquidjava.rj_language.opt.derivation_node.DerivationNode;
12+
import liquidjava.rj_language.opt.derivation_node.UnaryDerivationNode;
13+
import liquidjava.rj_language.opt.derivation_node.ValDerivationNode;
914

1015
public class ConstantFolding {
1116

12-
public static Expression fold(Expression exp) {
13-
// recursively simplify in all children
14-
if (exp.hasChildren()) {
15-
for (int i = 0; i < exp.getChildren().size(); i++) {
16-
Expression child = exp.getChildren().get(i);
17-
Expression propagatedChild = fold(child);
18-
exp.setChild(i, propagatedChild);
19-
}
20-
}
21-
22-
// try to fold the current expression
17+
public static ValDerivationNode fold(ValDerivationNode node) {
18+
Expression exp = node.getValue();
2319
if (exp instanceof BinaryExpression) {
24-
return foldBinaryExpression((BinaryExpression) exp);
20+
return foldBinary(node);
2521
}
2622
if (exp instanceof UnaryExpression) {
27-
return foldUnaryExpression((UnaryExpression) exp);
23+
return foldUnary(node);
24+
}
25+
if (exp instanceof GroupExpression) {
26+
GroupExpression group = (GroupExpression) exp;
27+
if (group.getChildren().size() == 1) {
28+
return fold(new ValDerivationNode(group.getChildren().get(0), node.getOrigin()));
29+
}
2830
}
29-
return exp;
31+
return node;
3032
}
3133

32-
private static Expression foldBinaryExpression(BinaryExpression binExp) {
33-
Expression left = binExp.getFirstOperand();
34-
Expression right = binExp.getSecondOperand();
34+
private static ValDerivationNode foldBinary(ValDerivationNode node) {
35+
BinaryExpression binExp = (BinaryExpression) node.getValue();
36+
DerivationNode parent = node.getOrigin();
37+
38+
// fold child nodes
39+
ValDerivationNode leftNode;
40+
ValDerivationNode rightNode;
41+
if (parent instanceof BinaryDerivationNode) {
42+
// has origin (from constant propagation)
43+
BinaryDerivationNode binaryOrigin = (BinaryDerivationNode) parent;
44+
leftNode = fold(binaryOrigin.getLeft());
45+
rightNode = fold(binaryOrigin.getRight());
46+
} else {
47+
// no origin
48+
leftNode = fold(new ValDerivationNode(binExp.getFirstOperand(), null));
49+
rightNode = fold(new ValDerivationNode(binExp.getSecondOperand(), null));
50+
}
51+
52+
Expression left = leftNode.getValue();
53+
Expression right = rightNode.getValue();
3554
String op = binExp.getOperator();
55+
binExp.setChild(0, left);
56+
binExp.setChild(1, right);
3657

37-
// arithmetic operations with integer literals
58+
// int and int
3859
if (left instanceof LiteralInt && right instanceof LiteralInt) {
3960
int l = ((LiteralInt) left).getValue();
4061
int r = ((LiteralInt) right).getValue();
41-
42-
return switch (op) {
43-
case "+" -> new LiteralInt(l + r);
44-
case "-" -> new LiteralInt(l - r);
45-
case "*" -> new LiteralInt(l * r);
46-
case "/" -> r != 0 ? new LiteralInt(l / r) : binExp;
47-
case "%" -> r != 0 ? new LiteralInt(l % r) : binExp;
48-
case "<" -> new LiteralBoolean(l < r);
49-
case "<=" -> new LiteralBoolean(l <= r);
50-
case ">" -> new LiteralBoolean(l > r);
51-
case ">=" -> new LiteralBoolean(l >= r);
52-
case "==" -> new LiteralBoolean(l == r);
53-
case "!=" -> new LiteralBoolean(l != r);
54-
default -> binExp;
62+
Expression res = switch (op) {
63+
case "+" -> new LiteralInt(l + r);
64+
case "-" -> new LiteralInt(l - r);
65+
case "*" -> new LiteralInt(l * r);
66+
case "/" -> r != 0 ? new LiteralInt(l / r) : null;
67+
case "%" -> r != 0 ? new LiteralInt(l % r) : null;
68+
case "<" -> new LiteralBoolean(l < r);
69+
case "<=" -> new LiteralBoolean(l <= r);
70+
case ">" -> new LiteralBoolean(l > r);
71+
case ">=" -> new LiteralBoolean(l >= r);
72+
case "==" -> new LiteralBoolean(l == r);
73+
case "!=" -> new LiteralBoolean(l != r);
74+
default -> null;
5575
};
76+
if (res != null)
77+
return new ValDerivationNode(res, new BinaryDerivationNode(leftNode, rightNode, op));
5678
}
57-
58-
// arithmetic operations with real literals
59-
if (left instanceof LiteralReal && right instanceof LiteralReal) {
79+
// real and real
80+
else if (left instanceof LiteralReal && right instanceof LiteralReal) {
6081
double l = ((LiteralReal) left).getValue();
6182
double r = ((LiteralReal) right).getValue();
62-
return switch (op) {
63-
case "+" -> new LiteralReal(l + r);
64-
case "-" -> new LiteralReal(l - r);
65-
case "*" -> new LiteralReal(l * r);
66-
case "/" -> r != 0.0 ? new LiteralReal(l / r) : binExp;
67-
case "%" -> r != 0.0 ? new LiteralReal(l % r) : binExp;
68-
case "<" -> new LiteralBoolean(l < r);
69-
case "<=" -> new LiteralBoolean(l <= r);
70-
case ">" -> new LiteralBoolean(l > r);
71-
case ">=" -> new LiteralBoolean(l >= r);
72-
case "==" -> new LiteralBoolean(l == r);
73-
case "!=" -> new LiteralBoolean(l != r);
74-
default -> binExp;
83+
Expression res = switch (op) {
84+
case "+" -> new LiteralReal(l + r);
85+
case "-" -> new LiteralReal(l - r);
86+
case "*" -> new LiteralReal(l * r);
87+
case "/" -> r != 0.0 ? new LiteralReal(l / r) : null;
88+
case "%" -> r != 0.0 ? new LiteralReal(l % r) : null;
89+
case "<" -> new LiteralBoolean(l < r);
90+
case "<=" -> new LiteralBoolean(l <= r);
91+
case ">" -> new LiteralBoolean(l > r);
92+
case ">=" -> new LiteralBoolean(l >= r);
93+
case "==" -> new LiteralBoolean(l == r);
94+
case "!=" -> new LiteralBoolean(l != r);
95+
default -> null;
7596
};
97+
if (res != null)
98+
return new ValDerivationNode(res, new BinaryDerivationNode(leftNode, rightNode, op));
7699
}
77100

78-
// mixed integer and real operations
79-
if ((left instanceof LiteralInt && right instanceof LiteralReal) || (left instanceof LiteralReal && right instanceof LiteralInt)) {
101+
// mixed int and real
102+
else if ((left instanceof LiteralInt && right instanceof LiteralReal)
103+
|| (left instanceof LiteralReal && right instanceof LiteralInt)) {
80104
double l = left instanceof LiteralInt ? ((LiteralInt) left).getValue() : ((LiteralReal) left).getValue();
81105
double r = right instanceof LiteralInt ? ((LiteralInt) right).getValue() : ((LiteralReal) right).getValue();
82-
return switch (op) {
83-
case "+" -> new LiteralReal(l + r);
84-
case "-" -> new LiteralReal(l - r);
85-
case "*" -> new LiteralReal(l * r);
86-
case "/" -> r != 0.0 ? new LiteralReal(l / r) : binExp;
87-
case "%" -> r != 0.0 ? new LiteralReal(l % r) : binExp;
88-
case "<" -> new LiteralBoolean(l < r);
89-
case "<=" -> new LiteralBoolean(l <= r);
90-
case ">" -> new LiteralBoolean(l > r);
91-
case ">=" -> new LiteralBoolean(l >= r);
92-
case "==" -> new LiteralBoolean(l == r);
93-
case "!=" -> new LiteralBoolean(l != r);
94-
default -> binExp;
106+
Expression res = switch (op) {
107+
case "+" -> new LiteralReal(l + r);
108+
case "-" -> new LiteralReal(l - r);
109+
case "*" -> new LiteralReal(l * r);
110+
case "/" -> r != 0.0 ? new LiteralReal(l / r) : null;
111+
case "%" -> r != 0.0 ? new LiteralReal(l % r) : null;
112+
case "<" -> new LiteralBoolean(l < r);
113+
case "<=" -> new LiteralBoolean(l <= r);
114+
case ">" -> new LiteralBoolean(l > r);
115+
case ">=" -> new LiteralBoolean(l >= r);
116+
case "==" -> new LiteralBoolean(l == r);
117+
case "!=" -> new LiteralBoolean(l != r);
118+
default -> null;
95119
};
120+
if (res != null)
121+
return new ValDerivationNode(res, new BinaryDerivationNode(leftNode, rightNode, op));
96122
}
97-
98-
// boolean operations with boolean literals
99-
if (left instanceof LiteralBoolean && right instanceof LiteralBoolean) {
123+
// bool and bool
124+
else if (left instanceof LiteralBoolean && right instanceof LiteralBoolean) {
100125
boolean l = ((LiteralBoolean) left).isBooleanTrue();
101126
boolean r = ((LiteralBoolean) right).isBooleanTrue();
102-
return switch (op) {
103-
case "&&" -> new LiteralBoolean(l && r);
104-
case "||" -> new LiteralBoolean(l || r);
105-
case "-->" -> new LiteralBoolean(!l || r);
106-
case "==" -> new LiteralBoolean(l == r);
107-
case "!=" -> new LiteralBoolean(l != r);
108-
default -> binExp;
127+
Expression res = switch (op) {
128+
case "&&" -> new LiteralBoolean(l && r);
129+
case "||" -> new LiteralBoolean(l || r);
130+
case "-->" -> new LiteralBoolean(!l || r);
131+
case "==" -> new LiteralBoolean(l == r);
132+
case "!=" -> new LiteralBoolean(l != r);
133+
default -> null;
109134
};
135+
if (res != null)
136+
return new ValDerivationNode(res, new BinaryDerivationNode(leftNode, rightNode, op));
110137
}
111-
// no folding, return original
112-
return binExp;
138+
139+
// no folding
140+
DerivationNode origin = (leftNode.getOrigin() != null || rightNode.getOrigin() != null)
141+
? new BinaryDerivationNode(leftNode, rightNode, op) : null;
142+
return new ValDerivationNode(binExp, origin);
113143
}
114144

115-
private static Expression foldUnaryExpression(UnaryExpression unaryExp) {
116-
Expression operand = unaryExp.getChildren().get(0);
145+
private static ValDerivationNode foldUnary(ValDerivationNode node) {
146+
UnaryExpression unaryExp = (UnaryExpression) node.getValue();
147+
DerivationNode parent = node.getOrigin();
148+
149+
// fold child node
150+
ValDerivationNode operandNode;
151+
if (parent instanceof UnaryDerivationNode) {
152+
// has origin (from constant propagation)
153+
UnaryDerivationNode unaryOrigin = (UnaryDerivationNode) parent;
154+
operandNode = fold(unaryOrigin.getOperand());
155+
} else {
156+
// no origin
157+
operandNode = fold(new ValDerivationNode(unaryExp.getChildren().get(0), null));
158+
}
159+
Expression operand = operandNode.getValue();
117160
String operator = unaryExp.getOp();
118-
if (operator.equals("!") && operand instanceof LiteralBoolean) {
119-
// !true -> false, !false -> true
161+
unaryExp.setChild(0, operand);
162+
163+
// unary not
164+
if ("!".equals(operator) && operand instanceof LiteralBoolean) {
165+
// !true => false, !false => true
120166
boolean value = ((LiteralBoolean) operand).isBooleanTrue();
121-
return new LiteralBoolean(!value);
167+
Expression res = new LiteralBoolean(!value);
168+
return new ValDerivationNode(res, new UnaryDerivationNode(operandNode, operator));
122169
}
123-
if (operator.equals("-")) {
124-
// -(x) = -x
170+
// unary minus
171+
if ("-".equals(operator)) {
172+
// -(x) => -x
125173
if (operand instanceof LiteralInt) {
126-
int value = ((LiteralInt) operand).getValue();
127-
return new LiteralInt(-value);
174+
Expression res = new LiteralInt(-((LiteralInt) operand).getValue());
175+
return new ValDerivationNode(res, new UnaryDerivationNode(operandNode, operator));
128176
}
129177
if (operand instanceof LiteralReal) {
130-
double value = ((LiteralReal) operand).getValue();
131-
return new LiteralReal(-value);
178+
Expression res = new LiteralReal(-((LiteralReal) operand).getValue());
179+
return new ValDerivationNode(res, new UnaryDerivationNode(operandNode, operator));
132180
}
133181
}
134-
return unaryExp;
182+
183+
// no folding
184+
DerivationNode origin = operandNode.getOrigin() != null ? new UnaryDerivationNode(operandNode, operator) : null;
185+
return new ValDerivationNode(unaryExp, origin);
135186
}
136-
}
187+
}

0 commit comments

Comments
 (0)