Skip to content

Commit cf368d7

Browse files
authored
Use declared parameter types for Whitebox.invokeMethod reflection (#945)
Resolve the target method's declared parameter types from the AST instead of using arg.getClass() which returns the runtime concrete class and fails when the method parameter is an interface/parent type. Three-tier resolution strategy: 1. Resolve the target method declaration and use its parameter types 2. Fall back to the argument's compile-time type 3. Fall back to arg.getClass() when no type info is available
1 parent e679ace commit cf368d7

2 files changed

Lines changed: 278 additions & 13 deletions

File tree

src/main/java/org/openrewrite/java/testing/mockito/PowerMockWhiteboxToJavaReflection.java

Lines changed: 110 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import org.openrewrite.java.tree.*;
2525
import org.openrewrite.marker.Markers;
2626

27+
import java.util.ArrayList;
2728
import java.util.List;
2829

2930
import static java.util.Collections.emptyList;
@@ -89,13 +90,31 @@ public J.Block visitBlock(J.Block block, ExecutionContext ctx) {
8990
continue;
9091
}
9192
Cursor blockCursor = new Cursor(getCursor().getParentOrThrow(), b);
92-
String template = buildReplacementTemplate(stmt, mi, blockCursor);
93+
JavaType.Method resolvedMethod = INVOKE_METHOD.matches(mi) ?
94+
resolveTargetMethod(mi.getArguments()) : null;
95+
String template = buildReplacementTemplate(stmt, mi, blockCursor, resolvedMethod);
9396
if (template != null) {
94-
Object[] templateArgs = buildTemplateArgs(mi);
97+
Object[] templateArgs = buildTemplateArgs(mi, resolvedMethod);
98+
99+
List<String> templateImports = new ArrayList<>();
100+
templateImports.add("java.lang.reflect.Field");
101+
templateImports.add("java.lang.reflect.Method");
102+
if (resolvedMethod != null) {
103+
for (JavaType paramType : resolvedMethod.getParameterTypes()) {
104+
if (paramType instanceof JavaType.FullyQualified) {
105+
JavaType.FullyQualified fq = (JavaType.FullyQualified) paramType;
106+
if (!"java.lang".equals(fq.getPackageName())) {
107+
templateImports.add(fq.getFullyQualifiedName());
108+
maybeAddImport(fq.getFullyQualifiedName());
109+
}
110+
}
111+
}
112+
}
113+
95114
b = JavaTemplate.builder(template)
96115
.contextSensitive()
97116
.javaParser(JavaParser.fromJavaVersion())
98-
.imports("java.lang.reflect.Field", "java.lang.reflect.Method")
117+
.imports(templateImports.toArray(new String[0]))
99118
.build()
100119
.apply(
101120
new Cursor(getCursor().getParentOrThrow(), b),
@@ -117,7 +136,8 @@ public J.Block visitBlock(J.Block block, ExecutionContext ctx) {
117136
return b;
118137
}
119138

120-
private @Nullable String buildReplacementTemplate(Statement statement, J.MethodInvocation mi, Cursor scope) {
139+
private @Nullable String buildReplacementTemplate(Statement statement, J.MethodInvocation mi,
140+
Cursor scope, JavaType.@Nullable Method resolvedMethod) {
121141
List<Expression> args = mi.getArguments();
122142

123143
if (SET_INTERNAL_STATE.matches(mi) && args.size() == 3) {
@@ -127,12 +147,12 @@ public J.Block visitBlock(J.Block block, ExecutionContext ctx) {
127147
return buildGetInternalStateTemplate(args, statement, scope);
128148
}
129149
if (INVOKE_METHOD.matches(mi) && args.size() >= 2) {
130-
return buildInvokeMethodTemplate(args, statement, scope);
150+
return buildInvokeMethodTemplate(args, statement, scope, resolvedMethod);
131151
}
132152
return null;
133153
}
134154

135-
private Object[] buildTemplateArgs(J.MethodInvocation mi) {
155+
private Object[] buildTemplateArgs(J.MethodInvocation mi, JavaType.@Nullable Method resolvedMethod) {
136156
List<Expression> args = mi.getArguments();
137157

138158
if (SET_INTERNAL_STATE.matches(mi) && args.size() == 3) {
@@ -144,7 +164,7 @@ private Object[] buildTemplateArgs(J.MethodInvocation mi) {
144164
return new Object[]{args.get(0), args.get(1), args.get(0)};
145165
}
146166
if (INVOKE_METHOD.matches(mi) && args.size() >= 2) {
147-
return buildInvokeMethodArgs(args);
167+
return buildInvokeMethodArgs(args, resolvedMethod);
148168
}
149169
return new Object[0];
150170
}
@@ -181,7 +201,8 @@ private Object[] buildTemplateArgs(J.MethodInvocation mi) {
181201
return prefix + varName + ".get(#{any(java.lang.Object)});";
182202
}
183203

184-
private @Nullable String buildInvokeMethodTemplate(List<Expression> args, Statement statement, Cursor scope) {
204+
private @Nullable String buildInvokeMethodTemplate(List<Expression> args, Statement statement,
205+
Cursor scope, JavaType.@Nullable Method resolvedMethod) {
185206
String methodName = extractStringLiteral(args.get(1));
186207
if (methodName == null) {
187208
return null;
@@ -192,7 +213,12 @@ private Object[] buildTemplateArgs(J.MethodInvocation mi) {
192213
StringBuilder sb = new StringBuilder();
193214
sb.append("Method ").append(varName).append(" = #{any(java.lang.Object)}.getClass().getDeclaredMethod(#{any(java.lang.String)}");
194215
for (int i = 2; i < args.size(); i++) {
195-
sb.append(", #{any(java.lang.Object)}.getClass()");
216+
String classLiteral = getParamClassLiteral(args, i, resolvedMethod);
217+
if (classLiteral != null) {
218+
sb.append(", ").append(classLiteral);
219+
} else {
220+
sb.append(", #{any(java.lang.Object)}.getClass()");
221+
}
196222
}
197223
sb.append(");\n");
198224

@@ -219,14 +245,22 @@ private Object[] buildTemplateArgs(J.MethodInvocation mi) {
219245
return sb.toString();
220246
}
221247

222-
private Object[] buildInvokeMethodArgs(List<Expression> args) {
248+
private Object[] buildInvokeMethodArgs(List<Expression> args, JavaType.@Nullable Method resolvedMethod) {
223249
int extraArgs = args.size() - 2;
224-
Object[] result = new Object[2 + extraArgs + 1 + extraArgs];
250+
int unresolvedCount = 0;
251+
for (int i = 2; i < args.size(); i++) {
252+
if (getParamClassLiteral(args, i, resolvedMethod) == null) {
253+
unresolvedCount++;
254+
}
255+
}
256+
Object[] result = new Object[2 + unresolvedCount + 1 + extraArgs];
225257
int idx = 0;
226258
result[idx++] = args.get(0); // target for getDeclaredMethod
227259
result[idx++] = args.get(1); // methodName
228260
for (int i = 2; i < args.size(); i++) {
229-
result[idx++] = args.get(i); // arg.getClass() for getDeclaredMethod
261+
if (getParamClassLiteral(args, i, resolvedMethod) == null) {
262+
result[idx++] = args.get(i); // arg.getClass() fallback for getDeclaredMethod
263+
}
230264
}
231265
result[idx++] = args.get(0); // target for invoke
232266
for (int i = 2; i < args.size(); i++) {
@@ -235,6 +269,56 @@ private Object[] buildInvokeMethodArgs(List<Expression> args) {
235269
return result;
236270
}
237271

272+
/**
273+
* Get the class literal for a parameter at the given argument index.
274+
* Prefers the resolved method's declared parameter type, falls back to the argument's
275+
* compile-time type, and returns null if neither is available.
276+
*/
277+
private @Nullable String getParamClassLiteral(List<Expression> args, int argIndex,
278+
JavaType.@Nullable Method resolvedMethod) {
279+
if (resolvedMethod != null) {
280+
String literal = classLiteralFromType(resolvedMethod.getParameterTypes().get(argIndex - 2));
281+
if (literal != null) {
282+
return literal;
283+
}
284+
}
285+
return getClassLiteral(args.get(argIndex));
286+
}
287+
288+
/**
289+
* Resolve the target method from the first argument's type and the method name.
290+
* Returns null if the method cannot be unambiguously resolved (not found, overloaded,
291+
* or missing type information).
292+
*/
293+
private JavaType.@Nullable Method resolveTargetMethod(List<Expression> args) {
294+
if (args.size() <= 2) {
295+
return null;
296+
}
297+
String methodName = extractStringLiteral(args.get(1));
298+
if (methodName == null) {
299+
return null;
300+
}
301+
JavaType targetType = args.get(0).getType();
302+
if (!(targetType instanceof JavaType.FullyQualified)) {
303+
return null;
304+
}
305+
int expectedParamCount = args.size() - 2;
306+
JavaType.Method match = null;
307+
for (JavaType.FullyQualified current = (JavaType.FullyQualified) targetType;
308+
current != null; current = current.getSupertype()) {
309+
for (JavaType.Method method : current.getMethods()) {
310+
if (method.getName().equals(methodName) &&
311+
method.getParameterTypes().size() == expectedParamCount) {
312+
if (match != null) {
313+
return null; // ambiguous overload
314+
}
315+
match = method;
316+
}
317+
}
318+
}
319+
return match;
320+
}
321+
238322
private J.@Nullable MethodInvocation extractWhiteboxInvocation(Statement statement) {
239323
if (statement instanceof J.MethodInvocation) {
240324
J.MethodInvocation mi = (J.MethodInvocation) statement;
@@ -264,6 +348,20 @@ private Object[] buildInvokeMethodArgs(List<Expression> args) {
264348
return null;
265349
}
266350

351+
private @Nullable String getClassLiteral(Expression expr) {
352+
return classLiteralFromType(expr.getType());
353+
}
354+
355+
private @Nullable String classLiteralFromType(@Nullable JavaType type) {
356+
if (type instanceof JavaType.Primitive) {
357+
return ((JavaType.Primitive) type).getKeyword() + ".class";
358+
}
359+
if (type instanceof JavaType.FullyQualified) {
360+
return ((JavaType.FullyQualified) type).getClassName() + ".class";
361+
}
362+
return null;
363+
}
364+
267365
private @Nullable String getCastType(@Nullable JavaType type) {
268366
if (type instanceof JavaType.FullyQualified) {
269367
return ((JavaType.FullyQualified) type).getClassName();

src/test/java/org/openrewrite/java/testing/mockito/PowerMockWhiteboxToJavaReflectionTest.java

Lines changed: 168 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ void testInvokeWithArgs() {
181181
class MyServiceTest {
182182
void testInvokeWithArgs() throws Exception {
183183
MyService service = new MyService();
184-
Method greetMethod = service.getClass().getDeclaredMethod("greet", "World".getClass());
184+
Method greetMethod = service.getClass().getDeclaredMethod("greet", String.class);
185185
greetMethod.setAccessible(true);
186186
String result = (String) greetMethod.invoke(service, "World");
187187
}
@@ -191,6 +191,173 @@ void testInvokeWithArgs() throws Exception {
191191
);
192192
}
193193

194+
@Test
195+
void invokeMethodWithMultipleArgs() {
196+
//language=java
197+
rewriteRun(
198+
java(
199+
"""
200+
class MyService {
201+
private String combine(String a, String b) { return a + b; }
202+
}
203+
"""
204+
),
205+
java(
206+
"""
207+
import org.powermock.reflect.Whitebox;
208+
209+
class MyServiceTest {
210+
void testInvokeWithMultipleArgs() {
211+
MyService service = new MyService();
212+
String result = Whitebox.invokeMethod(service, "combine", "Hello", "World");
213+
}
214+
}
215+
""",
216+
"""
217+
import java.lang.reflect.Method;
218+
219+
class MyServiceTest {
220+
void testInvokeWithMultipleArgs() throws Exception {
221+
MyService service = new MyService();
222+
Method combineMethod = service.getClass().getDeclaredMethod("combine", String.class, String.class);
223+
combineMethod.setAccessible(true);
224+
String result = (String) combineMethod.invoke(service, "Hello", "World");
225+
}
226+
}
227+
"""
228+
)
229+
);
230+
}
231+
232+
@Test
233+
void invokeMethodWithPrimitiveArg() {
234+
//language=java
235+
rewriteRun(
236+
java(
237+
"""
238+
class MyService {
239+
private int doubleIt(int value) { return value * 2; }
240+
}
241+
"""
242+
),
243+
java(
244+
"""
245+
import org.powermock.reflect.Whitebox;
246+
247+
class MyServiceTest {
248+
void testInvokeWithPrimitive() {
249+
MyService service = new MyService();
250+
Whitebox.invokeMethod(service, "doubleIt", 5);
251+
}
252+
}
253+
""",
254+
"""
255+
import java.lang.reflect.Method;
256+
257+
class MyServiceTest {
258+
void testInvokeWithPrimitive() throws Exception {
259+
MyService service = new MyService();
260+
Method doubleItMethod = service.getClass().getDeclaredMethod("doubleIt", int.class);
261+
doubleItMethod.setAccessible(true);
262+
doubleItMethod.invoke(service, 5);
263+
}
264+
}
265+
"""
266+
)
267+
);
268+
}
269+
270+
@Test
271+
void invokeMethodWithConcreteArgButInterfaceParam() {
272+
//language=java
273+
rewriteRun(
274+
java(
275+
"""
276+
import java.util.List;
277+
278+
class MyService {
279+
private String process(List<String> items) { return items.toString(); }
280+
}
281+
"""
282+
),
283+
java(
284+
"""
285+
import java.util.ArrayList;
286+
import org.powermock.reflect.Whitebox;
287+
288+
class MyServiceTest {
289+
void testInvokeWithConcreteArg() {
290+
MyService service = new MyService();
291+
ArrayList<String> items = new ArrayList<>();
292+
String result = Whitebox.invokeMethod(service, "process", items);
293+
}
294+
}
295+
""",
296+
"""
297+
import java.lang.reflect.Method;
298+
import java.util.ArrayList;
299+
import java.util.List;
300+
301+
class MyServiceTest {
302+
void testInvokeWithConcreteArg() throws Exception {
303+
MyService service = new MyService();
304+
ArrayList<String> items = new ArrayList<>();
305+
Method processMethod = service.getClass().getDeclaredMethod("process", List.class);
306+
processMethod.setAccessible(true);
307+
String result = (String) processMethod.invoke(service, items);
308+
}
309+
}
310+
"""
311+
)
312+
);
313+
}
314+
315+
@Test
316+
void invokeMethodWithInterfaceTypedArg() {
317+
//language=java
318+
rewriteRun(
319+
java(
320+
"""
321+
import java.util.List;
322+
323+
class MyService {
324+
private String process(List<String> items) { return items.toString(); }
325+
}
326+
"""
327+
),
328+
java(
329+
"""
330+
import java.util.ArrayList;
331+
import java.util.List;
332+
import org.powermock.reflect.Whitebox;
333+
334+
class MyServiceTest {
335+
void testInvokeWithInterfaceArg() {
336+
MyService service = new MyService();
337+
List<String> items = new ArrayList<>();
338+
String result = Whitebox.invokeMethod(service, "process", items);
339+
}
340+
}
341+
""",
342+
"""
343+
import java.lang.reflect.Method;
344+
import java.util.ArrayList;
345+
import java.util.List;
346+
347+
class MyServiceTest {
348+
void testInvokeWithInterfaceArg() throws Exception {
349+
MyService service = new MyService();
350+
List<String> items = new ArrayList<>();
351+
Method processMethod = service.getClass().getDeclaredMethod("process", List.class);
352+
processMethod.setAccessible(true);
353+
String result = (String) processMethod.invoke(service, items);
354+
}
355+
}
356+
"""
357+
)
358+
);
359+
}
360+
194361
@Test
195362
void throwsExceptionNotDuplicatedWhenAlreadyPresent() {
196363
//language=java

0 commit comments

Comments
 (0)