Skip to content

Commit e36d256

Browse files
committed
GROOVY-11894: Provide a NullChecker for groovy-typecheckers (some combinations of NullChecker and @Requires/@ensures weren't working)
1 parent bf84d92 commit e36d256

6 files changed

Lines changed: 335 additions & 1 deletion

File tree

src/main/java/org/codehaus/groovy/transform/stc/StaticTypesMarker.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,5 +57,9 @@ public enum StaticTypesMarker {
5757
/** used to store the condition expression type of the switch-case statement */
5858
SWITCH_CONDITION_EXPRESSION_TYPE,
5959
/** used to store the result of {@link StaticTypeCheckingVisitor#getType} */
60-
TYPE
60+
TYPE,
61+
/** indicates a parameter or method return is known to be non-null (e.g., inferred from {@code @Requires}/{@code @Ensures} contracts) */
62+
INFERRED_NON_NULL,
63+
/** list of {@code return null} statements recorded on a method before its body is rewritten, so a downstream checker can still report them as non-null violations */
64+
INFERRED_NON_NULL_RETURN_VIOLATIONS
6165
}

subprojects/groovy-contracts/src/main/java/org/apache/groovy/contracts/ast/visitor/AnnotationClosureVisitor.java

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
import org.codehaus.groovy.ast.expr.BooleanExpression;
5050
import org.codehaus.groovy.ast.expr.CastExpression;
5151
import org.codehaus.groovy.ast.expr.ClosureExpression;
52+
import org.codehaus.groovy.ast.expr.ConstantExpression;
5253
import org.codehaus.groovy.ast.expr.ConstructorCallExpression;
5354
import org.codehaus.groovy.ast.expr.Expression;
5455
import org.codehaus.groovy.ast.expr.MethodCallExpression;
@@ -59,10 +60,12 @@
5960
import org.codehaus.groovy.ast.expr.VariableExpression;
6061
import org.codehaus.groovy.ast.stmt.BlockStatement;
6162
import org.codehaus.groovy.ast.stmt.EmptyStatement;
63+
import org.codehaus.groovy.ast.stmt.ReturnStatement;
6264
import org.codehaus.groovy.control.SourceUnit;
6365
import org.codehaus.groovy.control.io.ReaderSource;
6466
import org.codehaus.groovy.syntax.Token;
6567
import org.codehaus.groovy.syntax.Types;
68+
import org.codehaus.groovy.transform.stc.StaticTypesMarker;
6669
import org.objectweb.asm.Opcodes;
6770

6871
import java.util.ArrayList;
@@ -71,6 +74,7 @@
7174
import java.util.Iterator;
7275
import java.util.List;
7376
import java.util.Map;
77+
import java.util.Set;
7478

7579
import static org.codehaus.groovy.ast.tools.GeneralUtils.args;
7680
import static org.codehaus.groovy.ast.tools.GeneralUtils.callX;
@@ -175,6 +179,8 @@ public void visitConstructorOrMethod(MethodNode methodNode, boolean isConstructo
175179
replaceWithClosureClassReference(annotationNode, methodNode);
176180
}
177181

182+
recordNonNullReturnViolations(methodNode);
183+
178184
markProcessed(methodNode);
179185

180186
super.visitConstructorOrMethod(methodNode, isConstructor);
@@ -201,6 +207,8 @@ private void replaceWithClosureClassReference(AnnotationNode annotationNode, Met
201207
List<BooleanExpression> booleanExpressions = ExpressionUtils.getBooleanExpression(closureExpression);
202208
if (booleanExpressions == null || booleanExpressions.isEmpty()) return;
203209

210+
inferNonNullFromContract(methodNode, annotationNode, booleanExpressions);
211+
204212
boolean isConstructor = methodNode.isConstructor();
205213
boolean isPostcondition = AnnotationUtils.hasAnnotationOfType(annotationNode.getClassNode(), POSTCONDITION_TYPE_NAME);
206214

@@ -276,6 +284,100 @@ private void markProcessed(ASTNode someNode) {
276284
someNode.setNodeMetaData(PROCESSED, Boolean.TRUE);
277285
}
278286

287+
/**
288+
* Records non-null facts derivable from top-level {@code x != null} conjuncts in a
289+
* {@code @Requires} or {@code @Ensures} closure, as {@link StaticTypesMarker#INFERRED_NON_NULL}
290+
* metadata on the matching {@link Parameter} (or the {@link MethodNode} itself when the closure
291+
* references {@code result} in a postcondition).
292+
*/
293+
private static void inferNonNullFromContract(MethodNode methodNode, AnnotationNode annotationNode, List<BooleanExpression> booleanExpressions) {
294+
String simpleName = annotationNode.getClassNode().getNameWithoutPackage();
295+
boolean isEnsures = "Ensures".equals(simpleName);
296+
if (!("Requires".equals(simpleName) || isEnsures)) return;
297+
for (BooleanExpression booleanExpression : booleanExpressions) {
298+
if (booleanExpression == null) continue;
299+
collectNonNullFacts(booleanExpression.getExpression(), methodNode, isEnsures);
300+
}
301+
}
302+
303+
private static void collectNonNullFacts(Expression expr, MethodNode methodNode, boolean isEnsures) {
304+
if (!(expr instanceof BinaryExpression bin)) return;
305+
int op = bin.getOperation().getType();
306+
if (op == Types.LOGICAL_AND) {
307+
collectNonNullFacts(bin.getLeftExpression(), methodNode, isEnsures);
308+
collectNonNullFacts(bin.getRightExpression(), methodNode, isEnsures);
309+
return;
310+
}
311+
if (op != Types.COMPARE_NOT_EQUAL && op != Types.COMPARE_NOT_IDENTICAL) return;
312+
String name = matchVarAgainstNull(bin.getLeftExpression(), bin.getRightExpression());
313+
if (name == null) name = matchVarAgainstNull(bin.getRightExpression(), bin.getLeftExpression());
314+
if (name == null) return;
315+
if (isEnsures && "result".equals(name)) {
316+
methodNode.putNodeMetaData(StaticTypesMarker.INFERRED_NON_NULL, Boolean.TRUE);
317+
return;
318+
}
319+
for (Parameter p : methodNode.getParameters()) {
320+
if (p.getName().equals(name)) {
321+
p.putNodeMetaData(StaticTypesMarker.INFERRED_NON_NULL, Boolean.TRUE);
322+
return;
323+
}
324+
}
325+
}
326+
327+
private static String matchVarAgainstNull(Expression maybeVar, Expression maybeNull) {
328+
if (!(maybeNull instanceof ConstantExpression) || !((ConstantExpression) maybeNull).isNullExpression()) return null;
329+
if (maybeVar instanceof VariableExpression ve) return ve.getName();
330+
return null;
331+
}
332+
333+
private static final Set<String> NONNULL_ANNO_SIMPLE_NAMES = Set.of("NonNull", "NotNull", "Nonnull");
334+
335+
/**
336+
* Records {@code return null} statements on a method whose return is effectively {@code @NonNull}
337+
* and whose body will be rewritten by the contracts transform (postcondition or class invariant).
338+
* The stashed list is read later by a downstream checker (e.g. NullChecker) which would otherwise
339+
* no longer see the literal null after the rewrite.
340+
*/
341+
private static void recordNonNullReturnViolations(MethodNode methodNode) {
342+
if (methodNode.isVoidMethod() || methodNode.isAbstract() || methodNode.getCode() == null) return;
343+
if (!isEffectivelyNonNullReturn(methodNode)) return;
344+
if (!willHavePostconditionRewrite(methodNode)) return;
345+
List<ASTNode> violations = null;
346+
for (ReturnStatement rs : AssertStatementCreationUtility.getReturnStatements(methodNode)) {
347+
Expression expr = rs.getExpression();
348+
if (expr instanceof ConstantExpression ce && ce.isNullExpression()) {
349+
if (violations == null) violations = new ArrayList<>();
350+
violations.add(rs);
351+
}
352+
}
353+
if (violations != null) {
354+
methodNode.putNodeMetaData(StaticTypesMarker.INFERRED_NON_NULL_RETURN_VIOLATIONS, violations);
355+
}
356+
}
357+
358+
private static boolean isEffectivelyNonNullReturn(MethodNode method) {
359+
if (Boolean.TRUE.equals(method.getNodeMetaData(StaticTypesMarker.INFERRED_NON_NULL))) return true;
360+
for (AnnotationNode a : method.getAnnotations()) {
361+
if (NONNULL_ANNO_SIMPLE_NAMES.contains(a.getClassNode().getNameWithoutPackage())) return true;
362+
}
363+
return false;
364+
}
365+
366+
private static boolean willHavePostconditionRewrite(MethodNode method) {
367+
for (AnnotationNode a : method.getAnnotations()) {
368+
String name = a.getClassNode().getName();
369+
if ("groovy.contracts.Ensures".equals(name) || "groovy.contracts.EnsuresConditions".equals(name)) return true;
370+
}
371+
ClassNode cls = method.getDeclaringClass();
372+
if (cls != null) {
373+
for (AnnotationNode a : cls.getAnnotations()) {
374+
String name = a.getClassNode().getName();
375+
if ("groovy.contracts.Invariant".equals(name) || "groovy.contracts.Invariants".equals(name)) return true;
376+
}
377+
}
378+
return false;
379+
}
380+
279381
//--------------------------------------------------------------------------
280382

281383
static class ClosureExpressionValidator extends ClassCodeVisitorSupport {

subprojects/groovy-typecheckers/src/main/groovy/groovy/typecheckers/NullChecker.groovy

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,14 @@ class NullChecker extends GroovyTypeCheckingExtensionSupport.TypeCheckingDSL {
118118
private CheckingVisitor makeVisitor(boolean flowSensitive, MethodNode method) {
119119
boolean classNonNullByDefault = method.declaringClass != null && hasNonNullByDefaultAnno(method.declaringClass)
120120
boolean methodNonNull = method.returnType != VOID_TYPE && (hasNonNullAnno(method) || (classNonNullByDefault && !hasNullableAnno(method)))
121+
if (methodNonNull) {
122+
def stash = method.getNodeMetaData(StaticTypesMarker.INFERRED_NON_NULL_RETURN_VIOLATIONS)
123+
if (stash instanceof List) {
124+
stash.each { node ->
125+
addStaticTypeError("Cannot return null from @NonNull method '${method.name}'", node)
126+
}
127+
}
128+
}
121129
def initialNullable = method.parameters.findAll { hasNullableAnno(it) } as Set<Variable>
122130

123131
new CheckingVisitor() {
@@ -348,6 +356,7 @@ class NullChecker extends GroovyTypeCheckingExtensionSupport.TypeCheckingDSL {
348356
}
349357

350358
private static boolean hasNonNullAnno(AnnotatedNode node) {
359+
if (node.getNodeMetaData(StaticTypesMarker.INFERRED_NON_NULL) == Boolean.TRUE) return true
351360
node.annotations?.any { it.classNode?.nameWithoutPackage in NONNULL_ANNOS } ?: false
352361
}
353362

subprojects/groovy-typecheckers/src/spec/doc/typecheckers.adoc

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -765,6 +765,50 @@ We would see the following error at compile-time:
765765
include::../test/NullCheckerTest.groovy[tags=nonnull_by_default_message,indent=0]
766766
----
767767
768+
=== Integration with @Requires and @Ensures
769+
770+
When `groovy-contracts` is on the classpath, `NullChecker` recognises null-safety
771+
facts expressed through Design-by-Contract annotations. Top-level `x != null`
772+
conjuncts inside a `@Requires` closure mark the corresponding parameter as
773+
effectively `@NonNull`; a `result != null` conjunct inside an `@Ensures` closure
774+
marks the method return the same way. Both orderings (`x != null` and
775+
`null != x`) are recognised. Disjunctions (`||`) are deliberately not used for
776+
inference, since they do not establish non-nullness of any individual operand.
777+
778+
[source,groovy]
779+
----
780+
include::../test/NullCheckerTest.groovy[tags=requires_integration,indent=0]
781+
----
782+
783+
We would see the following error at compile-time:
784+
785+
[source]
786+
----
787+
include::../test/NullCheckerTest.groovy[tags=requires_integration_message,indent=0]
788+
----
789+
790+
The same applies on the return side with `@Ensures`:
791+
792+
[source,groovy]
793+
----
794+
include::../test/NullCheckerTest.groovy[tags=ensures_integration,indent=0]
795+
----
796+
797+
We would see the following error at compile-time:
798+
799+
[source]
800+
----
801+
include::../test/NullCheckerTest.groovy[tags=ensures_integration_message,indent=0]
802+
----
803+
804+
This coexists with explicit `@NonNull` annotations: a method marked `@NonNull`
805+
that also carries any `@Ensures` postcondition, or that lives in a class with
806+
an `@Invariant` class invariant, still has its literal `return null` statements
807+
flagged. Without this coordination, the `groovy-contracts` AST transform would
808+
rewrite the return statement before the type-checking pass runs, hiding the
809+
literal from the checker. The runtime postcondition continues to guard non-literal
810+
cases.
811+
768812
=== Flow-sensitive mode with strict option
769813
770814
`NullChecker` only flags issues involving annotated code. `NullChecker(strict: true)` adds

subprojects/groovy-typecheckers/src/spec/test/NullCheckerTest.groovy

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -416,4 +416,54 @@ class NullCheckerTest {
416416
''')
417417
}
418418

419+
@Test
420+
void testRequiresIntegration() {
421+
def err = shouldFail('''
422+
import groovy.transform.TypeChecked
423+
import groovy.contracts.Requires
424+
425+
// tag::requires_integration[]
426+
@TypeChecked(extensions='groovy.typecheckers.NullChecker')
427+
class Greeter {
428+
@Requires({ name != null })
429+
static String greet(name) { "Hello, $name!" }
430+
431+
static void main(String[] args) {
432+
greet(null) // caught at compile time
433+
}
434+
}
435+
// end::requires_integration[]
436+
''')
437+
def expectedError = '''\
438+
# tag::requires_integration_message[]
439+
[Static type checking] - Cannot pass null to @NonNull parameter 'name' of 'greet'
440+
# end::requires_integration_message[]
441+
'''
442+
assert err.message.contains(expectedError.readLines()[1].trim())
443+
}
444+
445+
@Test
446+
void testEnsuresIntegration() {
447+
def err = shouldFail('''
448+
import groovy.transform.TypeChecked
449+
import groovy.contracts.Ensures
450+
451+
// tag::ensures_integration[]
452+
@TypeChecked(extensions='groovy.typecheckers.NullChecker')
453+
class Greeter {
454+
@Ensures({ result != null })
455+
String greet() {
456+
return null // caught at compile time
457+
}
458+
}
459+
// end::ensures_integration[]
460+
''')
461+
def expectedError = '''\
462+
# tag::ensures_integration_message[]
463+
[Static type checking] - Cannot return null from @NonNull method 'greet'
464+
# end::ensures_integration_message[]
465+
'''
466+
assert err.message.contains(expectedError.readLines()[1].trim())
467+
}
468+
419469
}

0 commit comments

Comments
 (0)