3333import org .openrewrite .java .tree .TypeUtils ;
3434
3535import java .util .ArrayList ;
36+ import java .util .Arrays ;
3637import java .util .List ;
3738import java .util .Set ;
3839
@@ -85,6 +86,12 @@ public TreeVisitor<?, ExecutionContext> getVisitor() {
8586 MethodMatcher assertThatMatcher = new MethodMatcher ("org.assertj.core.api.Assertions assertThat(..)" );
8687 MethodMatcher chainedAssertMatcher = new MethodMatcher ("java..* " + chainedAssertion + "(..)" );
8788 MethodMatcher assertToReplace = new MethodMatcher ("org.assertj.core.api.* " + this .assertToReplace + "(..)" );
89+ List <MethodMatcher > intermediateMatchers = Arrays .asList (
90+ new MethodMatcher ("org.assertj.core.api.* as(..)" ),
91+ new MethodMatcher ("org.assertj.core.api.* describedAs(..)" ),
92+ new MethodMatcher ("org.assertj.core.api.* withFailMessage(..)" ),
93+ new MethodMatcher ("org.assertj.core.api.* overridingErrorMessage(..)" )
94+ );
8895
8996 return new JavaIsoVisitor <ExecutionContext >() {
9097 @ Override
@@ -96,9 +103,19 @@ public J.MethodInvocation visitMethodInvocation(J.MethodInvocation methodInvocat
96103 return mi ;
97104 }
98105
99- // assertThat has method call
100- J .MethodInvocation assertThat = (J .MethodInvocation ) mi .getSelect ();
101- if (!assertThatMatcher .matches (assertThat ) || !(assertThat .getArguments ().get (0 ) instanceof J .MethodInvocation )) {
106+ // Walk past intermediate methods (as, describedAs, etc.) to find assertThat
107+ List <J .MethodInvocation > intermediates = new ArrayList <>();
108+ J .MethodInvocation current = (J .MethodInvocation ) mi .getSelect ();
109+ while (!assertThatMatcher .matches (current )) {
110+ if (isIntermediate (current ) && current .getSelect () instanceof J .MethodInvocation ) {
111+ intermediates .add (current );
112+ current = (J .MethodInvocation ) current .getSelect ();
113+ } else {
114+ return mi ;
115+ }
116+ }
117+ J .MethodInvocation assertThat = current ;
118+ if (!(assertThat .getArguments ().get (0 ) instanceof J .MethodInvocation )) {
102119 return mi ;
103120 }
104121
@@ -142,11 +159,30 @@ public J.MethodInvocation visitMethodInvocation(J.MethodInvocation methodInvocat
142159 arguments .add (actual );
143160
144161 String template = getStringTemplateAndAppendArguments (assertThatArg , mi , arguments );
145- return JavaTemplate .builder (String .format (template , dedicatedAssertion ))
162+ J . MethodInvocation result = JavaTemplate .builder (String .format (template , dedicatedAssertion ))
146163 .contextSensitive ()
147164 .javaParser (JavaParser .fromJavaVersion ().classpathFromResources (ctx , "junit-jupiter-api-5" , "assertj-core-3" ))
148165 .build ()
149166 .apply (getCursor (), mi .getCoordinates ().replace (), arguments .toArray ());
167+
168+ // Splice intermediate methods (as, describedAs, etc.) back into the chain
169+ if (!intermediates .isEmpty ()) {
170+ Expression chain = result .getSelect ();
171+ for (int i = intermediates .size () - 1 ; i >= 0 ; i --) {
172+ chain = intermediates .get (i ).withSelect (chain );
173+ }
174+ result = result .withSelect (chain );
175+ }
176+ return result ;
177+ }
178+
179+ private boolean isIntermediate (J .MethodInvocation method ) {
180+ for (MethodMatcher matcher : intermediateMatchers ) {
181+ if (matcher .matches (method )) {
182+ return true ;
183+ }
184+ }
185+ return false ;
150186 }
151187
152188 private String getStringTemplateAndAppendArguments (J .MethodInvocation assertThatArg , J .MethodInvocation methodToReplace , List <Expression > arguments ) {
0 commit comments