diff --git a/liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/ExpressionSimplifier.java b/liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/ExpressionSimplifier.java index c117b05f..07a8432f 100644 --- a/liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/ExpressionSimplifier.java +++ b/liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/ExpressionSimplifier.java @@ -1,5 +1,7 @@ package liquidjava.rj_language.opt; +import liquidjava.processor.context.Context; +import liquidjava.rj_language.Predicate; import java.util.Map; import liquidjava.processor.facade.AliasDTO; @@ -11,6 +13,8 @@ import liquidjava.rj_language.opt.derivation_node.DerivationNode; import liquidjava.rj_language.opt.derivation_node.UnaryDerivationNode; import liquidjava.rj_language.opt.derivation_node.ValDerivationNode; +import liquidjava.smt.SMTEvaluator; +import liquidjava.smt.SMTResult; public class ExpressionSimplifier { @@ -90,6 +94,16 @@ private static ValDerivationNode simplifyValDerivationNode(ValDerivationNode nod return leftSimplified; } + // remove weaker conjuncts (e.g. x > 0 && x > -1 => x > 0) + if (implies(leftSimplified.getValue(), rightSimplified.getValue())) { + return new ValDerivationNode(leftSimplified.getValue(), + new BinaryDerivationNode(leftSimplified, rightSimplified, "&&")); + } + if (implies(rightSimplified.getValue(), leftSimplified.getValue())) { + return new ValDerivationNode(rightSimplified.getValue(), + new BinaryDerivationNode(leftSimplified, rightSimplified, "&&")); + } + // return the conjunction with simplified children Expression newValue = new BinaryExpression(leftSimplified.getValue(), "&&", rightSimplified.getValue()); // only create origin if at least one child has a meaningful origin @@ -191,4 +205,17 @@ private static ValDerivationNode unwrapBooleanLiterals(ValDerivationNode node) { return node; } + + /** + * Checks whether one expression implies another by asking Z3, used to remove weaker conjuncts in the simplification + */ + private static boolean implies(Expression stronger, Expression weaker) { + try { + SMTResult result = new SMTEvaluator().verifySubtype(new Predicate(stronger), new Predicate(weaker), + Context.getInstance()); + return result.isOk(); + } catch (Exception e) { + return false; + } + } } diff --git a/liquidjava-verifier/src/test/java/liquidjava/rj_language/opt/ExpressionSimplifierTest.java b/liquidjava-verifier/src/test/java/liquidjava/rj_language/opt/ExpressionSimplifierTest.java index fde309e4..6ac20359 100644 --- a/liquidjava-verifier/src/test/java/liquidjava/rj_language/opt/ExpressionSimplifierTest.java +++ b/liquidjava-verifier/src/test/java/liquidjava/rj_language/opt/ExpressionSimplifierTest.java @@ -1,6 +1,7 @@ package liquidjava.rj_language.opt; import static org.junit.jupiter.api.Assertions.*; +import static liquidjava.utils.TestUtils.*; import java.util.List; import java.util.Map; @@ -15,7 +16,6 @@ import liquidjava.rj_language.ast.UnaryExpression; import liquidjava.rj_language.ast.Var; import liquidjava.rj_language.opt.derivation_node.BinaryDerivationNode; -import liquidjava.rj_language.opt.derivation_node.DerivationNode; import liquidjava.rj_language.opt.derivation_node.IteDerivationNode; import liquidjava.rj_language.opt.derivation_node.UnaryDerivationNode; import liquidjava.rj_language.opt.derivation_node.ValDerivationNode; @@ -1020,37 +1020,78 @@ void testTwoArgAliasWithNormalExpression() { assertNull(rightNode.getOrigin()); } - /** - * Helper method to compare two derivation nodes recursively - */ - private void assertDerivationEquals(DerivationNode expected, DerivationNode actual, String message) { - if (expected == null && actual == null) - return; - - assertNotNull(expected); - assertEquals(expected.getClass(), actual.getClass(), message + ": node types should match"); - if (expected instanceof ValDerivationNode expectedVal) { - ValDerivationNode actualVal = (ValDerivationNode) actual; - assertEquals(expectedVal.getValue().toString(), actualVal.getValue().toString(), - message + ": values should match"); - assertDerivationEquals(expectedVal.getOrigin(), actualVal.getOrigin(), message + " > origin"); - } else if (expected instanceof BinaryDerivationNode expectedBin) { - BinaryDerivationNode actualBin = (BinaryDerivationNode) actual; - assertEquals(expectedBin.getOp(), actualBin.getOp(), message + ": operators should match"); - assertDerivationEquals(expectedBin.getLeft(), actualBin.getLeft(), message + " > left"); - assertDerivationEquals(expectedBin.getRight(), actualBin.getRight(), message + " > right"); - } else if (expected instanceof VarDerivationNode expectedVar) { - VarDerivationNode actualVar = (VarDerivationNode) actual; - assertEquals(expectedVar.getVar(), actualVar.getVar(), message + ": variables should match"); - } else if (expected instanceof UnaryDerivationNode expectedUnary) { - UnaryDerivationNode actualUnary = (UnaryDerivationNode) actual; - assertEquals(expectedUnary.getOp(), actualUnary.getOp(), message + ": operators should match"); - assertDerivationEquals(expectedUnary.getOperand(), actualUnary.getOperand(), message + " > operand"); - } else if (expected instanceof IteDerivationNode expectedIte) { - IteDerivationNode actualIte = (IteDerivationNode) actual; - assertDerivationEquals(expectedIte.getCondition(), actualIte.getCondition(), message + " > condition"); - assertDerivationEquals(expectedIte.getThenBranch(), actualIte.getThenBranch(), message + " > then"); - assertDerivationEquals(expectedIte.getElseBranch(), actualIte.getElseBranch(), message + " > else"); - } + @Test + void testEntailedConjunctIsRemovedButOriginIsPreserved() { + // Given: b >= 100 && b > 0 + // Expected: b >= 100 (b >= 100 implies b > 0) + + addIntVariableToContext("b"); + Expression b = new Var("b"); + Expression bGe100 = new BinaryExpression(b, ">=", new LiteralInt(100)); + Expression bGt0 = new BinaryExpression(b, ">", new LiteralInt(0)); + Expression fullExpression = new BinaryExpression(bGe100, "&&", bGt0); + + ValDerivationNode result = ExpressionSimplifier.simplify(fullExpression); + + assertNotNull(result); + assertEquals("b >= 100", result.getValue().toString(), + "The weaker conjunct should be removed when implied by the stronger one"); + + ValDerivationNode expectedLeft = new ValDerivationNode(bGe100, null); + ValDerivationNode expectedRight = new ValDerivationNode(bGt0, null); + ValDerivationNode expected = new ValDerivationNode(bGe100, + new BinaryDerivationNode(expectedLeft, expectedRight, "&&")); + + assertDerivationEquals(expected, result, "Entailment simplification should preserve conjunction origin"); + } + + @Test + void testStrictComparisonImpliesNonStrictComparison() { + // Given: x > y && x >= y + // Expected: x > y (x > y implies x >= y) + + addIntVariableToContext("x"); + addIntVariableToContext("y"); + Expression x = new Var("x"); + Expression y = new Var("y"); + Expression xGtY = new BinaryExpression(x, ">", y); + Expression xGeY = new BinaryExpression(x, ">=", y); + Expression fullExpression = new BinaryExpression(xGtY, "&&", xGeY); + + ValDerivationNode result = ExpressionSimplifier.simplify(fullExpression); + + assertNotNull(result); + assertEquals("x > y", result.getValue().toString(), + "The stricter comparison should be kept when it implies the weaker one"); + + ValDerivationNode expectedLeft = new ValDerivationNode(xGtY, null); + ValDerivationNode expectedRight = new ValDerivationNode(xGeY, null); + ValDerivationNode expected = new ValDerivationNode(xGtY, + new BinaryDerivationNode(expectedLeft, expectedRight, "&&")); + + assertDerivationEquals(expected, result, "Strict comparison simplification should preserve conjunction origin"); + } + + @Test + void testEquivalentBoundsKeepOneSide() { + // Given: i >= 0 && 0 <= i + // Expected: 0 <= i (both conjuncts express the same condition) + addIntVariableToContext("i"); + Expression i = new Var("i"); + Expression zeroLeI = new BinaryExpression(new LiteralInt(0), "<=", i); + Expression iGeZero = new BinaryExpression(i, ">=", new LiteralInt(0)); + Expression fullExpression = new BinaryExpression(zeroLeI, "&&", iGeZero); + + ValDerivationNode result = ExpressionSimplifier.simplify(fullExpression); + + assertNotNull(result); + assertEquals("0 <= i", result.getValue().toString(), "Equivalent bounds should collapse to a single conjunct"); + + ValDerivationNode expectedLeft = new ValDerivationNode(zeroLeI, null); + ValDerivationNode expectedRight = new ValDerivationNode(iGeZero, null); + ValDerivationNode expected = new ValDerivationNode(zeroLeI, + new BinaryDerivationNode(expectedLeft, expectedRight, "&&")); + + assertDerivationEquals(expected, result, "Equivalent bounds simplification should preserve conjunction origin"); } } diff --git a/liquidjava-verifier/src/test/java/liquidjava/utils/TestUtils.java b/liquidjava-verifier/src/test/java/liquidjava/utils/TestUtils.java index cd15f869..628ecc40 100644 --- a/liquidjava-verifier/src/test/java/liquidjava/utils/TestUtils.java +++ b/liquidjava-verifier/src/test/java/liquidjava/utils/TestUtils.java @@ -1,13 +1,30 @@ package liquidjava.utils; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; + import java.io.IOException; import java.nio.file.Files; import java.nio.file.Path; import java.util.Optional; import java.util.stream.Stream; +import liquidjava.processor.context.Context; +import liquidjava.rj_language.Predicate; +import liquidjava.rj_language.opt.derivation_node.BinaryDerivationNode; +import liquidjava.rj_language.opt.derivation_node.DerivationNode; +import liquidjava.rj_language.opt.derivation_node.IteDerivationNode; +import liquidjava.rj_language.opt.derivation_node.UnaryDerivationNode; +import liquidjava.rj_language.opt.derivation_node.ValDerivationNode; +import liquidjava.rj_language.opt.derivation_node.VarDerivationNode; +import spoon.Launcher; +import spoon.reflect.factory.Factory; + public class TestUtils { + private final static Factory factory = new Launcher().getFactory(); + private final static Context context = Context.getInstance(); + /** * Determines if the given path indicates that the test should pass * @@ -64,4 +81,48 @@ public static Optional getExpectedErrorFromDirectory(Path dirPath) { } return Optional.empty(); } + + /** + * Helper method to compare two derivation nodes recursively + */ + public static void assertDerivationEquals(DerivationNode expected, DerivationNode actual, String message) { + if (expected == null && actual == null) + return; + + assertNotNull(expected); + assertEquals(expected.getClass(), actual.getClass(), message + ": node types should match"); + if (expected instanceof ValDerivationNode expectedVal) { + ValDerivationNode actualVal = (ValDerivationNode) actual; + assertEquals(expectedVal.getValue().toString(), actualVal.getValue().toString(), + message + ": values should match"); + assertDerivationEquals(expectedVal.getOrigin(), actualVal.getOrigin(), message + " > origin"); + } else if (expected instanceof BinaryDerivationNode expectedBin) { + BinaryDerivationNode actualBin = (BinaryDerivationNode) actual; + assertEquals(expectedBin.getOp(), actualBin.getOp(), message + ": operators should match"); + assertDerivationEquals(expectedBin.getLeft(), actualBin.getLeft(), message + " > left"); + assertDerivationEquals(expectedBin.getRight(), actualBin.getRight(), message + " > right"); + } else if (expected instanceof VarDerivationNode expectedVar) { + VarDerivationNode actualVar = (VarDerivationNode) actual; + assertEquals(expectedVar.getVar(), actualVar.getVar(), message + ": variables should match"); + } else if (expected instanceof UnaryDerivationNode expectedUnary) { + UnaryDerivationNode actualUnary = (UnaryDerivationNode) actual; + assertEquals(expectedUnary.getOp(), actualUnary.getOp(), message + ": operators should match"); + assertDerivationEquals(expectedUnary.getOperand(), actualUnary.getOperand(), message + " > operand"); + } else if (expected instanceof IteDerivationNode expectedIte) { + IteDerivationNode actualIte = (IteDerivationNode) actual; + assertDerivationEquals(expectedIte.getCondition(), actualIte.getCondition(), message + " > condition"); + assertDerivationEquals(expectedIte.getThenBranch(), actualIte.getThenBranch(), message + " > then"); + assertDerivationEquals(expectedIte.getElseBranch(), actualIte.getElseBranch(), message + " > else"); + } + } + + /** + * Helper method to add an integer variable to the context Needed for tests that rely on the SMT-based implication + * checks The simplifier asks Z3 whether one conjunct implies another, so every variable in those expressions must + * be in the context + */ + public static void addIntVariableToContext(String name) { + context.addVarToContext(name, factory.Type().INTEGER_PRIMITIVE, new Predicate(), + factory.Code().createCodeSnippetStatement("")); + } }