3131import org .objectweb .asm .Handle ;
3232import org .objectweb .asm .Opcodes ;
3333import org .objectweb .asm .Type ;
34+ import org .objectweb .asm .tree .ClassNode ;
3435import org .objectweb .asm .tree .FieldInsnNode ;
3536import org .objectweb .asm .tree .FieldNode ;
37+ import org .objectweb .asm .tree .FrameNode ;
3638import org .objectweb .asm .tree .InnerClassNode ;
3739import org .objectweb .asm .tree .InsnNode ;
3840import org .objectweb .asm .tree .InvokeDynamicInsnNode ;
41+ import org .objectweb .asm .tree .JumpInsnNode ;
42+ import org .objectweb .asm .tree .LabelNode ;
3943import org .objectweb .asm .tree .MethodInsnNode ;
4044import org .objectweb .asm .tree .MethodNode ;
4145import org .objectweb .asm .tree .TypeInsnNode ;
4751import java .util .Set ;
4852
4953// TODO ASM Logging
54+
5055public class Threading_ThreadSafeBlockRendererInjector implements TurboClassTransformer {
5156 private static final Set <String > CLASS_NAMES = new HashSet <>();
5257 private static final Set <String > INTERNAL_NAMES = new HashSet <>();
5358 private static final Map <String , Handle > INITIALIZERS = new HashMap <>();
5459 private static final Map <String , String > SUPPLIERS = new HashMap <>();
60+ private static final Set <String > FACTORIES = new HashSet <>();
5561 private static final String TSBR_InternalName = "com/falsepattern/falsetweaks/api/threading/ThreadSafeBlockRenderer" ;
5662 private static final String ISBR_InternalName = "cpw/mods/fml/client/registry/ISimpleBlockRenderingHandler" ;
5763 private static final Handle LAMBDA_META_FACTORY = new Handle (Opcodes .H_INVOKESTATIC , "java/lang/invoke/LambdaMetafactory" , "metafactory" ,
5864 "(Ljava/lang/invoke/MethodHandles$Lookup;Ljava/lang/String;Ljava/lang/invoke/MethodType;Ljava/lang/invoke/MethodType;Ljava/lang/invoke/MethodHandle;Ljava/lang/invoke/MethodType;)Ljava/lang/invoke/CallSite;" );
5965
66+ public static final String THREAD_SAFE_ANNOTATION_InternalName = "com/falsepattern/falsetweaks/modules/threadedupdates/interop/ThreadSafeISBRH" ;
67+ public static final String THREAD_SAFE_ANNOTATION_DESC = "Lcom/falsepattern/falsetweaks/modules/threadedupdates/interop/ThreadSafeISBRH;" ;
68+ public static final String THREAD_SAFE_FACTORY_InternalName = "com/falsepattern/falsetweaks/modules/threadedupdates/interop/ThreadSafeISBRHFactory" ;
69+
70+ public static final String FACTORY_METHOD_DESC = "()L" + THREAD_SAFE_FACTORY_InternalName + ";" ;
71+ public static final String FACTORY_METHOD_NAME = "newInstance" ;
72+
6073 private static final String [] HARDCODED = new String [] {
6174 "com.carpentersblocks.renderer.BlockHandlerCarpentersBarrier:default!" ,
6275 "com.carpentersblocks.renderer.BlockHandlerCarpentersBed:default!" ,
@@ -101,7 +114,7 @@ public static void addAll(String... entries) {
101114 creatorHandle = new Handle (Opcodes .H_NEWINVOKESPECIAL , internalName , "<init>" , "()V" );
102115 } else {
103116 val genParts = generator .split ("!" );
104- val genInternalName = internalName .replace ('.' , '/' );
117+ val genInternalName = genParts [ 0 ] .replace ('.' , '/' );
105118 val genMethodName = genParts [1 ];
106119 creatorHandle = new Handle (Opcodes .H_INVOKESTATIC , genInternalName , genMethodName , "()L" + internalName + ";" );
107120 }
@@ -128,7 +141,47 @@ public String name() {
128141
129142 @ Override
130143 public boolean shouldTransformClass (@ NotNull String className , @ NotNull ClassNodeHandle classNode ) {
131- return CLASS_NAMES .contains (className );
144+ if (CLASS_NAMES .contains (className )) {
145+ return true ;
146+ } else {
147+ val node = classNode .getNode ();
148+ if (node == null )
149+ return false ;
150+ val anns = node .visibleAnnotations ;
151+ if (anns == null )
152+ return false ;
153+ val internalName = className .replace ('.' , '/' );
154+ for (val ann : anns ) {
155+ if (THREAD_SAFE_ANNOTATION_DESC .equals (ann .desc )) {
156+ boolean perThread = false ;
157+ val values = ann .values ;
158+ if (values != null ) {
159+ val iter = values .iterator ();
160+ while (iter .hasNext ()) {
161+ val name = iter .next ();
162+ val value = iter .next ();
163+ if ("perThread" .equals (name )) {
164+ perThread = (Boolean ) value ;
165+ }
166+ }
167+ }
168+ if (perThread ) {
169+ INITIALIZERS .put (internalName , new Handle (Opcodes .H_NEWINVOKESPECIAL , internalName , "<init>" , "()V" ));
170+ }
171+ return true ;
172+ }
173+ }
174+ val ifcs = node .interfaces ;
175+ if (ifcs == null )
176+ return false ;
177+ for (val ifc : ifcs ) {
178+ if (THREAD_SAFE_FACTORY_InternalName .equals (ifc )) {
179+ FACTORIES .add (className .replace ('.' , '/' ));
180+ return true ;
181+ }
182+ }
183+ }
184+ return false ;
132185 }
133186
134187 @ Override
@@ -141,52 +194,80 @@ public boolean transformClass(@NotNull String className, @NotNull ClassNodeHandl
141194
142195 val internalName = className .replace ('.' , '/' );
143196 cn .interfaces .add (TSBR_InternalName );
144- if (INITIALIZERS .containsKey (internalName )) {
145- cn .innerClasses .add (new InnerClassNode ("java/lang/invoke/MethodHandles$Lookup" , "java/lang/invoke/MethodHandles" , "Lookup" , Opcodes .ACC_PUBLIC | Opcodes .ACC_STATIC | Opcodes .ACC_FINAL ));
146- cn .fields .add (new FieldNode (Opcodes .ACC_PRIVATE | Opcodes .ACC_STATIC | Opcodes .ACC_FINAL , "ft$tlInjected" , "Ljava/lang/ThreadLocal;" , null , null ));
147- boolean staticInitializedFound = false ;
148- for (val method : cn .methods ) {
149- if (!"<clinit>" .equals (method .name ))
150- continue ;
151- staticInitializedFound = true ;
152- injectInstanceCreation (method , internalName );
153- }
154- if (!staticInitializedFound ) {
155- val clinit = new MethodNode (Opcodes .ACC_STATIC , "<clinit>" , "()V" , null , null );
156- cn .methods .add (clinit );
157- injectInstanceCreation (clinit , internalName );
158- clinit .instructions .add (new InsnNode (Opcodes .RETURN ));
159- }
160- }
161197 val getter = new MethodNode (Opcodes .ACC_PUBLIC , "forCurrentThread" , "()L" + ISBR_InternalName + ";" , null , null );
162198 cn .methods .add (getter );
163199 val insnList = getter .instructions ;
164- if (SUPPLIERS .containsKey (internalName )) {
165- val supplier = SUPPLIERS .get (internalName );
166- val parts = supplier .split ("\\ ?" );
167- insnList .add (new MethodInsnNode (Opcodes .INVOKESTATIC , parts [0 ], parts [1 ], "()L" + internalName + ";" , false ));
168- } else if (INITIALIZERS .containsKey (internalName )) {
200+ if (INITIALIZERS .containsKey (internalName )) {
201+ injectThreadLocal (cn , internalName , true );
169202 insnList .add (new FieldInsnNode (Opcodes .GETSTATIC , internalName , "ft$tlInjected" , "Ljava/lang/ThreadLocal;" ));
170203 insnList .add (new MethodInsnNode (Opcodes .INVOKEVIRTUAL , "java/lang/ThreadLocal" , "get" , "()Ljava/lang/Object;" , false ));
171204 insnList .add (new TypeInsnNode (Opcodes .CHECKCAST , ISBR_InternalName ));
205+ getter .maxStack = 1 ;
206+ } else if (FACTORIES .contains (internalName )) {
207+ injectThreadLocal (cn , internalName , false );
208+ insnList .add (new FieldInsnNode (Opcodes .GETSTATIC , internalName , "ft$tlInjected" , "Ljava/lang/ThreadLocal;" ));
209+ insnList .add (new MethodInsnNode (Opcodes .INVOKEVIRTUAL , "java/lang/ThreadLocal" , "get" , "()Ljava/lang/Object;" , false ));
210+ insnList .add (new InsnNode (Opcodes .DUP ));
211+ val nonNull = new LabelNode ();
212+ insnList .add (new JumpInsnNode (Opcodes .IFNONNULL , nonNull ));
213+ insnList .add (new InsnNode (Opcodes .POP ));
214+ insnList .add (new FieldInsnNode (Opcodes .GETSTATIC , internalName , "ft$tlInjected" , "Ljava/lang/ThreadLocal;" ));
215+ insnList .add (new VarInsnNode (Opcodes .ALOAD , 0 ));
216+ insnList .add (new TypeInsnNode (Opcodes .CHECKCAST , THREAD_SAFE_FACTORY_InternalName ));
217+ insnList .add (new MethodInsnNode (Opcodes .INVOKEINTERFACE , THREAD_SAFE_FACTORY_InternalName , FACTORY_METHOD_NAME , FACTORY_METHOD_DESC , true ));
218+ insnList .add (new InsnNode (Opcodes .DUP_X1 ));
219+ insnList .add (new MethodInsnNode (Opcodes .INVOKEVIRTUAL , "java/lang/ThreadLocal" , "set" , "(Ljava/lang/Object;)V" , false ));
220+ insnList .add (nonNull );
221+ insnList .add (new FrameNode (Opcodes .F_SAME1 , 0 , null , 1 , new Object []{"java/lang/Object" }));
222+ insnList .add (new TypeInsnNode (Opcodes .CHECKCAST , ISBR_InternalName ));
223+ getter .maxStack = 3 ;
224+ } else if (SUPPLIERS .containsKey (internalName )) {
225+ val supplier = SUPPLIERS .get (internalName );
226+ val parts = supplier .split ("\\ ?" );
227+ insnList .add (new MethodInsnNode (Opcodes .INVOKESTATIC , parts [0 ], parts [1 ], "()L" + internalName + ";" , false ));
228+ getter .maxStack = 1 ;
172229 } else {
173230 insnList .add (new VarInsnNode (Opcodes .ALOAD , 0 ));
231+ getter .maxStack = 1 ;
174232 }
175233 insnList .add (new InsnNode (Opcodes .ARETURN ));
176- getter .maxStack = 1 ;
177234 getter .maxLocals = 1 ;
178235 return true ;
179236 }
180237
238+ private void injectThreadLocal (ClassNode cn , String internalName , boolean withInitial ) {
239+ if (withInitial ) {
240+ cn .innerClasses .add (new InnerClassNode ("java/lang/invoke/MethodHandles$Lookup" , "java/lang/invoke/MethodHandles" , "Lookup" , Opcodes .ACC_PUBLIC | Opcodes .ACC_STATIC | Opcodes .ACC_FINAL ));
241+ }
242+ cn .fields .add (new FieldNode (Opcodes .ACC_PRIVATE | Opcodes .ACC_STATIC | Opcodes .ACC_FINAL , "ft$tlInjected" , "Ljava/lang/ThreadLocal;" , null , null ));
243+ boolean staticInitializedFound = false ;
244+ for (val method : cn .methods ) {
245+ if (!"<clinit>" .equals (method .name ))
246+ continue ;
247+ staticInitializedFound = true ;
248+ injectThreadLocalCreation (method , internalName , withInitial );
249+ }
250+ if (!staticInitializedFound ) {
251+ val clinit = new MethodNode (Opcodes .ACC_STATIC , "<clinit>" , "()V" , null , null );
252+ clinit .instructions .add (new FrameNode (Opcodes .F_FULL , 0 , new Object [0 ], 0 , new Object [0 ]));
253+ cn .methods .add (clinit );
254+ injectThreadLocalCreation (clinit , internalName , withInitial );
255+ clinit .instructions .add (new InsnNode (Opcodes .RETURN ));
256+ }
257+ }
181258
182- private void injectInstanceCreation (MethodNode method , String internalName ) {
259+ private void injectThreadLocalCreation (MethodNode method , String internalName , boolean withInitial ) {
183260 val insnList = method .instructions .iterator ();
184- insnList .add (new InvokeDynamicInsnNode ("get" , "()Ljava/util/function/Supplier;" ,
185- LAMBDA_META_FACTORY ,
186- Type .getType ("()Ljava/lang/Object;" ),
187- INITIALIZERS .get (internalName ),
188- Type .getType ("()L" + internalName + ";" )));
189- insnList .add (new MethodInsnNode (Opcodes .INVOKESTATIC , "java/lang/ThreadLocal" , "withInitial" , "(Ljava/util/function/Supplier;)Ljava/lang/ThreadLocal;" , false ));
261+ if (withInitial ) {
262+ insnList .add (
263+ new InvokeDynamicInsnNode ("get" , "()Ljava/util/function/Supplier;" , LAMBDA_META_FACTORY , Type .getType ("()Ljava/lang/Object;" ), INITIALIZERS .get (internalName ),
264+ Type .getType ("()L" + internalName + ";" )));
265+ insnList .add (new MethodInsnNode (Opcodes .INVOKESTATIC , "java/lang/ThreadLocal" , "withInitial" , "(Ljava/util/function/Supplier;)Ljava/lang/ThreadLocal;" , false ));
266+ } else {
267+ insnList .add (new TypeInsnNode (Opcodes .NEW , "java/lang/ThreadLocal" ));
268+ insnList .add (new InsnNode (Opcodes .DUP ));
269+ insnList .add (new MethodInsnNode (Opcodes .INVOKESPECIAL , "java/lang/ThreadLocal" , "<init>" , "()V" , false ));
270+ }
190271 insnList .add (new FieldInsnNode (Opcodes .PUTSTATIC , internalName , "ft$tlInjected" , "Ljava/lang/ThreadLocal;" ));
191272 if (method .maxStack == 0 ) {
192273 method .maxStack = 1 ;
0 commit comments