2424import org .openrewrite .java .tree .*;
2525import org .openrewrite .marker .Markers ;
2626
27+ import java .util .ArrayList ;
2728import java .util .List ;
2829
2930import static java .util .Collections .emptyList ;
@@ -89,13 +90,31 @@ public J.Block visitBlock(J.Block block, ExecutionContext ctx) {
8990 continue ;
9091 }
9192 Cursor blockCursor = new Cursor (getCursor ().getParentOrThrow (), b );
92- String template = buildReplacementTemplate (stmt , mi , blockCursor );
93+ JavaType .Method resolvedMethod = INVOKE_METHOD .matches (mi ) ?
94+ resolveTargetMethod (mi .getArguments ()) : null ;
95+ String template = buildReplacementTemplate (stmt , mi , blockCursor , resolvedMethod );
9396 if (template != null ) {
94- Object [] templateArgs = buildTemplateArgs (mi );
97+ Object [] templateArgs = buildTemplateArgs (mi , resolvedMethod );
98+
99+ List <String > templateImports = new ArrayList <>();
100+ templateImports .add ("java.lang.reflect.Field" );
101+ templateImports .add ("java.lang.reflect.Method" );
102+ if (resolvedMethod != null ) {
103+ for (JavaType paramType : resolvedMethod .getParameterTypes ()) {
104+ if (paramType instanceof JavaType .FullyQualified ) {
105+ JavaType .FullyQualified fq = (JavaType .FullyQualified ) paramType ;
106+ if (!"java.lang" .equals (fq .getPackageName ())) {
107+ templateImports .add (fq .getFullyQualifiedName ());
108+ maybeAddImport (fq .getFullyQualifiedName ());
109+ }
110+ }
111+ }
112+ }
113+
95114 b = JavaTemplate .builder (template )
96115 .contextSensitive ()
97116 .javaParser (JavaParser .fromJavaVersion ())
98- .imports ("java.lang.reflect.Field" , "java.lang.reflect.Method" )
117+ .imports (templateImports . toArray ( new String [ 0 ]) )
99118 .build ()
100119 .apply (
101120 new Cursor (getCursor ().getParentOrThrow (), b ),
@@ -117,7 +136,8 @@ public J.Block visitBlock(J.Block block, ExecutionContext ctx) {
117136 return b ;
118137 }
119138
120- private @ Nullable String buildReplacementTemplate (Statement statement , J .MethodInvocation mi , Cursor scope ) {
139+ private @ Nullable String buildReplacementTemplate (Statement statement , J .MethodInvocation mi ,
140+ Cursor scope , JavaType .@ Nullable Method resolvedMethod ) {
121141 List <Expression > args = mi .getArguments ();
122142
123143 if (SET_INTERNAL_STATE .matches (mi ) && args .size () == 3 ) {
@@ -127,12 +147,12 @@ public J.Block visitBlock(J.Block block, ExecutionContext ctx) {
127147 return buildGetInternalStateTemplate (args , statement , scope );
128148 }
129149 if (INVOKE_METHOD .matches (mi ) && args .size () >= 2 ) {
130- return buildInvokeMethodTemplate (args , statement , scope );
150+ return buildInvokeMethodTemplate (args , statement , scope , resolvedMethod );
131151 }
132152 return null ;
133153 }
134154
135- private Object [] buildTemplateArgs (J .MethodInvocation mi ) {
155+ private Object [] buildTemplateArgs (J .MethodInvocation mi , JavaType . @ Nullable Method resolvedMethod ) {
136156 List <Expression > args = mi .getArguments ();
137157
138158 if (SET_INTERNAL_STATE .matches (mi ) && args .size () == 3 ) {
@@ -144,7 +164,7 @@ private Object[] buildTemplateArgs(J.MethodInvocation mi) {
144164 return new Object []{args .get (0 ), args .get (1 ), args .get (0 )};
145165 }
146166 if (INVOKE_METHOD .matches (mi ) && args .size () >= 2 ) {
147- return buildInvokeMethodArgs (args );
167+ return buildInvokeMethodArgs (args , resolvedMethod );
148168 }
149169 return new Object [0 ];
150170 }
@@ -181,7 +201,8 @@ private Object[] buildTemplateArgs(J.MethodInvocation mi) {
181201 return prefix + varName + ".get(#{any(java.lang.Object)});" ;
182202 }
183203
184- private @ Nullable String buildInvokeMethodTemplate (List <Expression > args , Statement statement , Cursor scope ) {
204+ private @ Nullable String buildInvokeMethodTemplate (List <Expression > args , Statement statement ,
205+ Cursor scope , JavaType .@ Nullable Method resolvedMethod ) {
185206 String methodName = extractStringLiteral (args .get (1 ));
186207 if (methodName == null ) {
187208 return null ;
@@ -192,7 +213,12 @@ private Object[] buildTemplateArgs(J.MethodInvocation mi) {
192213 StringBuilder sb = new StringBuilder ();
193214 sb .append ("Method " ).append (varName ).append (" = #{any(java.lang.Object)}.getClass().getDeclaredMethod(#{any(java.lang.String)}" );
194215 for (int i = 2 ; i < args .size (); i ++) {
195- sb .append (", #{any(java.lang.Object)}.getClass()" );
216+ String classLiteral = getParamClassLiteral (args , i , resolvedMethod );
217+ if (classLiteral != null ) {
218+ sb .append (", " ).append (classLiteral );
219+ } else {
220+ sb .append (", #{any(java.lang.Object)}.getClass()" );
221+ }
196222 }
197223 sb .append (");\n " );
198224
@@ -219,14 +245,22 @@ private Object[] buildTemplateArgs(J.MethodInvocation mi) {
219245 return sb .toString ();
220246 }
221247
222- private Object [] buildInvokeMethodArgs (List <Expression > args ) {
248+ private Object [] buildInvokeMethodArgs (List <Expression > args , JavaType . @ Nullable Method resolvedMethod ) {
223249 int extraArgs = args .size () - 2 ;
224- Object [] result = new Object [2 + extraArgs + 1 + extraArgs ];
250+ int unresolvedCount = 0 ;
251+ for (int i = 2 ; i < args .size (); i ++) {
252+ if (getParamClassLiteral (args , i , resolvedMethod ) == null ) {
253+ unresolvedCount ++;
254+ }
255+ }
256+ Object [] result = new Object [2 + unresolvedCount + 1 + extraArgs ];
225257 int idx = 0 ;
226258 result [idx ++] = args .get (0 ); // target for getDeclaredMethod
227259 result [idx ++] = args .get (1 ); // methodName
228260 for (int i = 2 ; i < args .size (); i ++) {
229- result [idx ++] = args .get (i ); // arg.getClass() for getDeclaredMethod
261+ if (getParamClassLiteral (args , i , resolvedMethod ) == null ) {
262+ result [idx ++] = args .get (i ); // arg.getClass() fallback for getDeclaredMethod
263+ }
230264 }
231265 result [idx ++] = args .get (0 ); // target for invoke
232266 for (int i = 2 ; i < args .size (); i ++) {
@@ -235,6 +269,56 @@ private Object[] buildInvokeMethodArgs(List<Expression> args) {
235269 return result ;
236270 }
237271
272+ /**
273+ * Get the class literal for a parameter at the given argument index.
274+ * Prefers the resolved method's declared parameter type, falls back to the argument's
275+ * compile-time type, and returns null if neither is available.
276+ */
277+ private @ Nullable String getParamClassLiteral (List <Expression > args , int argIndex ,
278+ JavaType .@ Nullable Method resolvedMethod ) {
279+ if (resolvedMethod != null ) {
280+ String literal = classLiteralFromType (resolvedMethod .getParameterTypes ().get (argIndex - 2 ));
281+ if (literal != null ) {
282+ return literal ;
283+ }
284+ }
285+ return getClassLiteral (args .get (argIndex ));
286+ }
287+
288+ /**
289+ * Resolve the target method from the first argument's type and the method name.
290+ * Returns null if the method cannot be unambiguously resolved (not found, overloaded,
291+ * or missing type information).
292+ */
293+ private JavaType .@ Nullable Method resolveTargetMethod (List <Expression > args ) {
294+ if (args .size () <= 2 ) {
295+ return null ;
296+ }
297+ String methodName = extractStringLiteral (args .get (1 ));
298+ if (methodName == null ) {
299+ return null ;
300+ }
301+ JavaType targetType = args .get (0 ).getType ();
302+ if (!(targetType instanceof JavaType .FullyQualified )) {
303+ return null ;
304+ }
305+ int expectedParamCount = args .size () - 2 ;
306+ JavaType .Method match = null ;
307+ for (JavaType .FullyQualified current = (JavaType .FullyQualified ) targetType ;
308+ current != null ; current = current .getSupertype ()) {
309+ for (JavaType .Method method : current .getMethods ()) {
310+ if (method .getName ().equals (methodName ) &&
311+ method .getParameterTypes ().size () == expectedParamCount ) {
312+ if (match != null ) {
313+ return null ; // ambiguous overload
314+ }
315+ match = method ;
316+ }
317+ }
318+ }
319+ return match ;
320+ }
321+
238322 private J .@ Nullable MethodInvocation extractWhiteboxInvocation (Statement statement ) {
239323 if (statement instanceof J .MethodInvocation ) {
240324 J .MethodInvocation mi = (J .MethodInvocation ) statement ;
@@ -264,6 +348,20 @@ private Object[] buildInvokeMethodArgs(List<Expression> args) {
264348 return null ;
265349 }
266350
351+ private @ Nullable String getClassLiteral (Expression expr ) {
352+ return classLiteralFromType (expr .getType ());
353+ }
354+
355+ private @ Nullable String classLiteralFromType (@ Nullable JavaType type ) {
356+ if (type instanceof JavaType .Primitive ) {
357+ return ((JavaType .Primitive ) type ).getKeyword () + ".class" ;
358+ }
359+ if (type instanceof JavaType .FullyQualified ) {
360+ return ((JavaType .FullyQualified ) type ).getClassName () + ".class" ;
361+ }
362+ return null ;
363+ }
364+
267365 private @ Nullable String getCastType (@ Nullable JavaType type ) {
268366 if (type instanceof JavaType .FullyQualified ) {
269367 return ((JavaType .FullyQualified ) type ).getClassName ();
0 commit comments