Skip to content

Commit 2081b63

Browse files
committed
Fix looking up private static final Capability fields
1 parent 94f1fbf commit 2081b63

1 file changed

Lines changed: 134 additions & 18 deletions

File tree

src/main/java/org/embeddedt/modernfix/forge/capability/CapabilityProviderDispatcherGenerator.java

Lines changed: 134 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
package org.embeddedt.modernfix.forge.capability;
22

3+
import net.minecraftforge.common.capabilities.Capability;
34
import net.minecraftforge.common.capabilities.ICapabilityProvider;
45
import net.minecraftforge.common.util.LazyOptional;
56
import org.embeddedt.modernfix.ModernFix;
@@ -10,6 +11,7 @@
1011
import org.objectweb.asm.commons.GeneratorAdapter;
1112
import org.objectweb.asm.commons.Method;
1213

14+
import java.lang.reflect.Field;
1315
import java.lang.reflect.Modifier;
1416
import java.lang.invoke.MethodHandle;
1517
import java.lang.invoke.MethodHandles;
@@ -74,6 +76,7 @@ record Hash(int mapIndex, List<Guarded> entries) implements ProviderDispatch {}
7476
private static final String DIRECTION_DESC = "Lnet/minecraft/core/Direction;";
7577
private static final String MAP_DESC = "Ljava/util/Map;";
7678
private static final String MAP_SIGNATURE = "Ljava/util/Map<Lnet/minecraftforge/common/capabilities/Capability<*>;Lnet/minecraftforge/common/capabilities/ICapabilityProvider;>;";
79+
private static final String LOOKUP_DESC = "Ljava/lang/invoke/MethodHandles$Lookup;";
7780

7881
/**
7982
* Gets or generates a constructor MethodHandle for the given capability provider types.
@@ -117,6 +120,14 @@ private static MethodHandle generateClass(List<Class<? extends ICapabilityProvid
117120
int generatedClassId = classCounter.incrementAndGet();
118121
String className = "org.embeddedt.modernfix.forge.capability.CapabilityDispatcher$Generated$" + generatedClassId;
119122

123+
List<ProviderDispatch> dispatches = optimizeDispatches(buildDispatchList(providerTypes, analysisResults));
124+
125+
// Assign a stable index to every unique CapabilityRef across all dispatches.
126+
// We resolve the actual Capability<?> instances here (in Java) so the generated
127+
// <clinit> only needs simple classDataAt calls - no reflection bytecode needed.
128+
LinkedHashMap<CapabilityRef, Integer> capRefIndices = collectCapabilityRefs(dispatches);
129+
List<Capability<?>> capValues = resolveCapabilityValues(capRefIndices);
130+
120131
ModernFix.LOGGER.debug("Generating capability dispatcher #{} for types: [{}]", () -> generatedClassId, () -> {
121132
StringBuilder sb = new StringBuilder();
122133
for (int i = 0; i < providerTypes.size(); i++) {
@@ -126,11 +137,14 @@ private static MethodHandle generateClass(List<Class<? extends ICapabilityProvid
126137
return sb;
127138
});
128139

129-
byte[] classBytes = generateClassBytes(className, providerTypes, analysisResults);
140+
byte[] classBytes = generateClassBytes(className, providerTypes, dispatches, capRefIndices);
130141

131-
// Define the hidden class
132-
MethodHandles.Lookup hiddenLookup = lookup.defineHiddenClass(
142+
// Define the hidden class, injecting the resolved Capability instances as class data.
143+
// The generated <clinit> retrieves them via MethodHandles.classDataAt so it never
144+
// needs to perform reflection itself - private fields are handled transparently here.
145+
MethodHandles.Lookup hiddenLookup = lookup.defineHiddenClassWithClassData(
133146
classBytes,
147+
capValues,
134148
true,
135149
MethodHandles.Lookup.ClassOption.NESTMATE
136150
);
@@ -154,6 +168,47 @@ private static MethodHandle generateClass(List<Class<? extends ICapabilityProvid
154168
}
155169
}
156170

171+
/**
172+
* Collects all unique {@link CapabilityRef}s referenced by {@code dispatches} in encounter order,
173+
* assigning each a stable list index for use with {@code classDataAt}.
174+
*/
175+
private static LinkedHashMap<CapabilityRef, Integer> collectCapabilityRefs(List<ProviderDispatch> dispatches) {
176+
LinkedHashMap<CapabilityRef, Integer> result = new LinkedHashMap<>();
177+
for (ProviderDispatch dispatch : dispatches) {
178+
if (dispatch instanceof ProviderDispatch.Guarded g) {
179+
result.putIfAbsent(g.capability(), result.size());
180+
} else if (dispatch instanceof ProviderDispatch.Hash hash) {
181+
for (ProviderDispatch.Guarded g : hash.entries()) {
182+
result.putIfAbsent(g.capability(), result.size());
183+
}
184+
}
185+
}
186+
return result;
187+
}
188+
189+
/**
190+
* Resolves the actual {@link Capability} instances for all refs at class-generation time.
191+
* Uses reflection (with {@code setAccessible}) so private fields are handled without any
192+
* reflection bytecode appearing in the generated class.
193+
*/
194+
private static List<Capability<?>> resolveCapabilityValues(LinkedHashMap<CapabilityRef, Integer> capRefIndices) {
195+
@SuppressWarnings("unchecked")
196+
Capability<?>[] caps = new Capability[capRefIndices.size()];
197+
for (Map.Entry<CapabilityRef, Integer> entry : capRefIndices.entrySet()) {
198+
CapabilityRef ref = entry.getKey();
199+
try {
200+
Class<?> clazz = Class.forName(ref.owner().replace('/', '.'), false,
201+
CapabilityProviderDispatcherGenerator.class.getClassLoader());
202+
Field field = clazz.getDeclaredField(ref.fieldName());
203+
field.setAccessible(true);
204+
caps[entry.getValue()] = (Capability<?>) field.get(null);
205+
} catch (ReflectiveOperationException e) {
206+
throw new RuntimeException("Failed to resolve capability field " + ref, e);
207+
}
208+
}
209+
return Arrays.asList(caps);
210+
}
211+
157212
/**
158213
* Build the dispatch list describing how each provider should be handled.
159214
*/
@@ -261,21 +316,22 @@ private static LinkedHashMap<Integer, String> collectProviderFields(List<Provide
261316
return fields;
262317
}
263318

264-
private static byte[] generateClassBytes(String className, List<Class<? extends ICapabilityProvider>> providerTypes, List<CapabilityAnalysisResult> analysisResults) {
265-
List<ProviderDispatch> dispatches = optimizeDispatches(buildDispatchList(providerTypes, analysisResults));
266-
319+
private static byte[] generateClassBytes(String className, List<Class<? extends ICapabilityProvider>> providerTypes,
320+
List<ProviderDispatch> dispatches, LinkedHashMap<CapabilityRef, Integer> capRefIndices) {
267321
ClassWriter cw = new ClassWriter(ClassWriter.COMPUTE_FRAMES | ClassWriter.COMPUTE_MAXS) {
268322
@Override
269323
protected ClassLoader getClassLoader() {
270324
return CapabilityProviderDispatcherGenerator.class.getClassLoader();
271325
}
272326
};
273327

328+
String internalName = className.replace('.', '/');
329+
274330
// Class declaration: implements ICapabilityProvider
275331
cw.visit(
276332
V17,
277333
ACC_PUBLIC | ACC_FINAL | ACC_SUPER,
278-
className.replace('.', '/'),
334+
internalName,
279335
null,
280336
"java/lang/Object",
281337
new String[] { "net/minecraftforge/common/capabilities/ICapabilityProvider" }
@@ -301,17 +357,74 @@ protected ClassLoader getClassLoader() {
301357
}
302358
}
303359

360+
// Generate one static final field per unique CapabilityRef.
361+
// These are populated in <clinit> via MethodHandles.classDataAt, which reads the
362+
// Capability<?> instances injected by defineHiddenClassWithClassData. This avoids
363+
// any reflection bytecode in the generated class and handles private fields transparently.
364+
for (Map.Entry<CapabilityRef, Integer> entry : capRefIndices.entrySet()) {
365+
cw.visitField(ACC_PRIVATE | ACC_STATIC | ACC_FINAL,
366+
capRefFieldName(entry.getValue()), CAPABILITY_DESC, null, null).visitEnd();
367+
}
368+
369+
// Generate <clinit> to load capability instances from class data
370+
if (!capRefIndices.isEmpty()) {
371+
generateClinit(cw, internalName, capRefIndices);
372+
}
373+
304374
// Generate constructor
305-
generateConstructor(cw, className, providerFields, dispatches);
375+
generateConstructor(cw, className, providerFields, dispatches, capRefIndices);
306376

307377
// Generate getCapability method with sided parameter
308-
generateGetCapabilityMethod(cw, className, dispatches);
378+
generateGetCapabilityMethod(cw, className, dispatches, capRefIndices);
309379

310380
cw.visitEnd();
311381
return cw.toByteArray();
312382
}
313383

314-
private static void generateConstructor(ClassWriter cw, String className, Map<Integer, String> providerFields, List<ProviderDispatch> dispatches) {
384+
private static String capRefFieldName(int index) {
385+
return "capRef" + index;
386+
}
387+
388+
/**
389+
* Generates {@code <clinit>} that loads each capability from class data injected at define time.
390+
* The bytecode is simply: {@code capRefN = MethodHandles.classDataAt(lookup(), "", Capability.class, N)}.
391+
*/
392+
private static void generateClinit(ClassWriter cw, String internalName, LinkedHashMap<CapabilityRef, Integer> capRefIndices) {
393+
MethodVisitor mv = cw.visitMethod(ACC_STATIC, "<clinit>", "()V", null, null);
394+
mv.visitCode();
395+
396+
for (int i = 0; i < capRefIndices.size(); i++) {
397+
// MethodHandles.lookup()
398+
mv.visitMethodInsn(INVOKESTATIC, "java/lang/invoke/MethodHandles", "lookup",
399+
"()" + LOOKUP_DESC, false);
400+
// "_" (classDataAt requires this exact name)
401+
mv.visitLdcInsn("_");
402+
// Capability.class
403+
mv.visitLdcInsn(Type.getType(CAPABILITY_DESC));
404+
// index
405+
mv.visitLdcInsn(i);
406+
// MethodHandles.classDataAt(lookup, name, type, index) → Object
407+
mv.visitMethodInsn(INVOKESTATIC, "java/lang/invoke/MethodHandles", "classDataAt",
408+
"(" + LOOKUP_DESC + "Ljava/lang/String;Ljava/lang/Class;I)Ljava/lang/Object;", false);
409+
mv.visitTypeInsn(CHECKCAST, "net/minecraftforge/common/capabilities/Capability");
410+
mv.visitFieldInsn(PUTSTATIC, internalName, capRefFieldName(i), CAPABILITY_DESC);
411+
}
412+
413+
mv.visitInsn(RETURN);
414+
mv.visitMaxs(0, 0);
415+
mv.visitEnd();
416+
}
417+
418+
/**
419+
* Emits a load of the capability constant for {@code ref} from the generated class's own static field.
420+
*/
421+
private static void emitCapabilityLoad(MethodVisitor mv, String internalName, CapabilityRef ref,
422+
Map<CapabilityRef, Integer> capRefIndices) {
423+
mv.visitFieldInsn(GETSTATIC, internalName, capRefFieldName(capRefIndices.get(ref)), CAPABILITY_DESC);
424+
}
425+
426+
private static void generateConstructor(ClassWriter cw, String className, Map<Integer, String> providerFields,
427+
List<ProviderDispatch> dispatches, Map<CapabilityRef, Integer> capRefIndices) {
315428
Method constructor = Method.getMethod("void <init>(net.minecraftforge.common.capabilities.ICapabilityProvider[])");
316429
GeneratorAdapter mg = new GeneratorAdapter(ACC_PUBLIC, constructor, null, null, cw);
317430
Type classType = Type.getObjectType(className.replace('.', '/'));
@@ -338,15 +451,16 @@ private static void generateConstructor(ClassWriter cw, String className, Map<In
338451
// Build hash maps
339452
for (ProviderDispatch dispatch : dispatches) {
340453
if (dispatch instanceof ProviderDispatch.Hash hash) {
341-
generateMapConstruction(mg, classType, hash);
454+
generateMapConstruction(mg, classType, hash, capRefIndices);
342455
}
343456
}
344457

345458
mg.returnValue();
346459
mg.endMethod();
347460
}
348461

349-
private static void generateMapConstruction(GeneratorAdapter mg, Type classType, ProviderDispatch.Hash hash) {
462+
private static void generateMapConstruction(GeneratorAdapter mg, Type classType, ProviderDispatch.Hash hash,
463+
Map<CapabilityRef, Integer> capRefIndices) {
350464
List<ProviderDispatch.Guarded> entries = hash.entries();
351465
mg.loadThis(); // for PUTFIELD at the end
352466

@@ -356,7 +470,7 @@ private static void generateMapConstruction(GeneratorAdapter mg, Type classType,
356470
ProviderDispatch.Guarded g = entries.get(i);
357471
mg.dup();
358472
mg.push(i);
359-
mg.visitFieldInsn(GETSTATIC, g.capability().owner(), g.capability().fieldName(), CAPABILITY_DESC);
473+
emitCapabilityLoad(mg, classType.getInternalName(), g.capability(), capRefIndices);
360474
mg.loadArg(0);
361475
mg.push(g.providerIndex());
362476
mg.arrayLoad(Type.getType(ICAP_PROVIDER_DESC));
@@ -370,7 +484,8 @@ private static void generateMapConstruction(GeneratorAdapter mg, Type classType,
370484
mg.putField(classType, "capMap" + hash.mapIndex(), Type.getType(MAP_DESC));
371485
}
372486

373-
private static void generateGetCapabilityMethod(ClassWriter cw, String className, List<ProviderDispatch> dispatches) {
487+
private static void generateGetCapabilityMethod(ClassWriter cw, String className, List<ProviderDispatch> dispatches,
488+
Map<CapabilityRef, Integer> capRefIndices) {
374489
// Method: <T> LazyOptional<T> getCapability(Capability<T>, Direction)
375490
MethodVisitor mv = cw.visitMethod(
376491
ACC_PUBLIC,
@@ -401,7 +516,7 @@ private static void generateGetCapabilityMethod(ClassWriter cw, String className
401516
emitHashDispatch(mv, internalName, getCapDesc, hash, nextLabel);
402517
di++;
403518
} else if (dispatch instanceof ProviderDispatch.Guarded) {
404-
di = emitGuardedDispatch(mv, internalName, getCapDesc, dispatches, di, nextLabel);
519+
di = emitGuardedDispatch(mv, internalName, getCapDesc, dispatches, di, nextLabel, capRefIndices);
405520
} else {
406521
var u = (ProviderDispatch.Unguarded) dispatch;
407522
emitProviderGetCapability(mv, internalName, getCapDesc, u.providerIndex(), u.fieldDesc());
@@ -490,7 +605,8 @@ private static void emitHashDispatch(MethodVisitor mv, String internalName, Stri
490605
* @return the updated dispatch index (past the consumed group)
491606
*/
492607
private static int emitGuardedDispatch(MethodVisitor mv, String internalName, String getCapDesc,
493-
List<ProviderDispatch> dispatches, int di, Label nextLabel) {
608+
List<ProviderDispatch> dispatches, int di, Label nextLabel,
609+
Map<CapabilityRef, Integer> capRefIndices) {
494610
var guarded = (ProviderDispatch.Guarded) dispatches.get(di);
495611

496612
// Peek ahead to collect consecutive Guarded entries with same providerIndex
@@ -507,7 +623,7 @@ private static int emitGuardedDispatch(MethodVisitor mv, String internalName, St
507623
var g = (ProviderDispatch.Guarded) dispatches.get(gi);
508624
CapabilityRef ref = g.capability();
509625
mv.visitVarInsn(ALOAD, 1);
510-
mv.visitFieldInsn(GETSTATIC, ref.owner(), ref.fieldName(), CAPABILITY_DESC);
626+
emitCapabilityLoad(mv, internalName, ref, capRefIndices);
511627
if (gi < groupEnd - 1) {
512628
mv.visitJumpInsn(IF_ACMPEQ, matchLabel);
513629
} else {
@@ -544,4 +660,4 @@ private static String formatAnalysisResult(CapabilityAnalysisResult result) {
544660
}
545661
return result.toString();
546662
}
547-
}
663+
}

0 commit comments

Comments
 (0)