4949import org .codehaus .groovy .ast .expr .BooleanExpression ;
5050import org .codehaus .groovy .ast .expr .CastExpression ;
5151import org .codehaus .groovy .ast .expr .ClosureExpression ;
52+ import org .codehaus .groovy .ast .expr .ConstantExpression ;
5253import org .codehaus .groovy .ast .expr .ConstructorCallExpression ;
5354import org .codehaus .groovy .ast .expr .Expression ;
5455import org .codehaus .groovy .ast .expr .MethodCallExpression ;
5960import org .codehaus .groovy .ast .expr .VariableExpression ;
6061import org .codehaus .groovy .ast .stmt .BlockStatement ;
6162import org .codehaus .groovy .ast .stmt .EmptyStatement ;
63+ import org .codehaus .groovy .ast .stmt .ReturnStatement ;
6264import org .codehaus .groovy .control .SourceUnit ;
6365import org .codehaus .groovy .control .io .ReaderSource ;
6466import org .codehaus .groovy .syntax .Token ;
6567import org .codehaus .groovy .syntax .Types ;
68+ import org .codehaus .groovy .transform .stc .StaticTypesMarker ;
6669import org .objectweb .asm .Opcodes ;
6770
6871import java .util .ArrayList ;
7174import java .util .Iterator ;
7275import java .util .List ;
7376import java .util .Map ;
77+ import java .util .Set ;
7478
7579import static org .codehaus .groovy .ast .tools .GeneralUtils .args ;
7680import static org .codehaus .groovy .ast .tools .GeneralUtils .callX ;
@@ -175,6 +179,8 @@ public void visitConstructorOrMethod(MethodNode methodNode, boolean isConstructo
175179 replaceWithClosureClassReference (annotationNode , methodNode );
176180 }
177181
182+ recordNonNullReturnViolations (methodNode );
183+
178184 markProcessed (methodNode );
179185
180186 super .visitConstructorOrMethod (methodNode , isConstructor );
@@ -201,6 +207,8 @@ private void replaceWithClosureClassReference(AnnotationNode annotationNode, Met
201207 List <BooleanExpression > booleanExpressions = ExpressionUtils .getBooleanExpression (closureExpression );
202208 if (booleanExpressions == null || booleanExpressions .isEmpty ()) return ;
203209
210+ inferNonNullFromContract (methodNode , annotationNode , booleanExpressions );
211+
204212 boolean isConstructor = methodNode .isConstructor ();
205213 boolean isPostcondition = AnnotationUtils .hasAnnotationOfType (annotationNode .getClassNode (), POSTCONDITION_TYPE_NAME );
206214
@@ -276,6 +284,100 @@ private void markProcessed(ASTNode someNode) {
276284 someNode .setNodeMetaData (PROCESSED , Boolean .TRUE );
277285 }
278286
287+ /**
288+ * Records non-null facts derivable from top-level {@code x != null} conjuncts in a
289+ * {@code @Requires} or {@code @Ensures} closure, as {@link StaticTypesMarker#INFERRED_NON_NULL}
290+ * metadata on the matching {@link Parameter} (or the {@link MethodNode} itself when the closure
291+ * references {@code result} in a postcondition).
292+ */
293+ private static void inferNonNullFromContract (MethodNode methodNode , AnnotationNode annotationNode , List <BooleanExpression > booleanExpressions ) {
294+ String simpleName = annotationNode .getClassNode ().getNameWithoutPackage ();
295+ boolean isEnsures = "Ensures" .equals (simpleName );
296+ if (!("Requires" .equals (simpleName ) || isEnsures )) return ;
297+ for (BooleanExpression booleanExpression : booleanExpressions ) {
298+ if (booleanExpression == null ) continue ;
299+ collectNonNullFacts (booleanExpression .getExpression (), methodNode , isEnsures );
300+ }
301+ }
302+
303+ private static void collectNonNullFacts (Expression expr , MethodNode methodNode , boolean isEnsures ) {
304+ if (!(expr instanceof BinaryExpression bin )) return ;
305+ int op = bin .getOperation ().getType ();
306+ if (op == Types .LOGICAL_AND ) {
307+ collectNonNullFacts (bin .getLeftExpression (), methodNode , isEnsures );
308+ collectNonNullFacts (bin .getRightExpression (), methodNode , isEnsures );
309+ return ;
310+ }
311+ if (op != Types .COMPARE_NOT_EQUAL && op != Types .COMPARE_NOT_IDENTICAL ) return ;
312+ String name = matchVarAgainstNull (bin .getLeftExpression (), bin .getRightExpression ());
313+ if (name == null ) name = matchVarAgainstNull (bin .getRightExpression (), bin .getLeftExpression ());
314+ if (name == null ) return ;
315+ if (isEnsures && "result" .equals (name )) {
316+ methodNode .putNodeMetaData (StaticTypesMarker .INFERRED_NON_NULL , Boolean .TRUE );
317+ return ;
318+ }
319+ for (Parameter p : methodNode .getParameters ()) {
320+ if (p .getName ().equals (name )) {
321+ p .putNodeMetaData (StaticTypesMarker .INFERRED_NON_NULL , Boolean .TRUE );
322+ return ;
323+ }
324+ }
325+ }
326+
327+ private static String matchVarAgainstNull (Expression maybeVar , Expression maybeNull ) {
328+ if (!(maybeNull instanceof ConstantExpression ) || !((ConstantExpression ) maybeNull ).isNullExpression ()) return null ;
329+ if (maybeVar instanceof VariableExpression ve ) return ve .getName ();
330+ return null ;
331+ }
332+
333+ private static final Set <String > NONNULL_ANNO_SIMPLE_NAMES = Set .of ("NonNull" , "NotNull" , "Nonnull" );
334+
335+ /**
336+ * Records {@code return null} statements on a method whose return is effectively {@code @NonNull}
337+ * and whose body will be rewritten by the contracts transform (postcondition or class invariant).
338+ * The stashed list is read later by a downstream checker (e.g. NullChecker) which would otherwise
339+ * no longer see the literal null after the rewrite.
340+ */
341+ private static void recordNonNullReturnViolations (MethodNode methodNode ) {
342+ if (methodNode .isVoidMethod () || methodNode .isAbstract () || methodNode .getCode () == null ) return ;
343+ if (!isEffectivelyNonNullReturn (methodNode )) return ;
344+ if (!willHavePostconditionRewrite (methodNode )) return ;
345+ List <ASTNode > violations = null ;
346+ for (ReturnStatement rs : AssertStatementCreationUtility .getReturnStatements (methodNode )) {
347+ Expression expr = rs .getExpression ();
348+ if (expr instanceof ConstantExpression ce && ce .isNullExpression ()) {
349+ if (violations == null ) violations = new ArrayList <>();
350+ violations .add (rs );
351+ }
352+ }
353+ if (violations != null ) {
354+ methodNode .putNodeMetaData (StaticTypesMarker .INFERRED_NON_NULL_RETURN_VIOLATIONS , violations );
355+ }
356+ }
357+
358+ private static boolean isEffectivelyNonNullReturn (MethodNode method ) {
359+ if (Boolean .TRUE .equals (method .getNodeMetaData (StaticTypesMarker .INFERRED_NON_NULL ))) return true ;
360+ for (AnnotationNode a : method .getAnnotations ()) {
361+ if (NONNULL_ANNO_SIMPLE_NAMES .contains (a .getClassNode ().getNameWithoutPackage ())) return true ;
362+ }
363+ return false ;
364+ }
365+
366+ private static boolean willHavePostconditionRewrite (MethodNode method ) {
367+ for (AnnotationNode a : method .getAnnotations ()) {
368+ String name = a .getClassNode ().getName ();
369+ if ("groovy.contracts.Ensures" .equals (name ) || "groovy.contracts.EnsuresConditions" .equals (name )) return true ;
370+ }
371+ ClassNode cls = method .getDeclaringClass ();
372+ if (cls != null ) {
373+ for (AnnotationNode a : cls .getAnnotations ()) {
374+ String name = a .getClassNode ().getName ();
375+ if ("groovy.contracts.Invariant" .equals (name ) || "groovy.contracts.Invariants" .equals (name )) return true ;
376+ }
377+ }
378+ return false ;
379+ }
380+
279381 //--------------------------------------------------------------------------
280382
281383 static class ClosureExpressionValidator extends ClassCodeVisitorSupport {
0 commit comments