Skip to content

Commit 6f99549

Browse files
committed
fix: JavaCallbacks instance must be destroyed _after_ lua_close.
lua_close will destroy any active threads, and the JavaCallbacks instance must be alive during thread destruction.
1 parent eaa531f commit 6f99549

2 files changed

Lines changed: 91 additions & 72 deletions

File tree

src/main/java/net/hollowcube/luau/LuaStateImpl.java

Lines changed: 55 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import java.util.*;
2020
import java.util.regex.Pattern;
2121

22+
import static net.hollowcube.luau.LuaCallbacksImpl.JavaCallbacks.fromCallbacks;
2223
import static net.hollowcube.luau.internal.vm.lua_h.*;
2324
import static net.hollowcube.luau.internal.vm.lualib_h.*;
2425
import static net.hollowcube.luau.internal.vm.luawrap_h.*;
@@ -39,9 +40,9 @@ record LuaStateImpl(MemorySegment L) implements LuaState {
3940
/// no affect on the runtime behavior.
4041
private static final String ASSERT_HANDLER = System.getProperty("luau.assert-handler");
4142
private static final boolean SHOW_COMPLETE_BACKTRACE =
42-
Boolean.getBoolean("luau.show-complete-backtrace");
43+
Boolean.getBoolean("luau.show-complete-backtrace");
4344
private static final boolean NO_BACKTRACE_MERGE =
44-
Boolean.getBoolean("luau.no-backtrace-merge");
45+
Boolean.getBoolean("luau.no-backtrace-merge");
4546

4647
static {
4748
NativeLibraryLoader.loadLibrary("vm");
@@ -61,22 +62,22 @@ record LuaStateImpl(MemorySegment L) implements LuaState {
6162
static final int GLOBALS_INDEX = LUA_GLOBALSINDEX();
6263

6364
private static final Pattern DEFAULT_ERROR_TRACE_REGEX = Pattern.compile("^\\[string \"" +
64-
".*?\"]:\\d+:\\s");
65+
".*?\"]:\\d+:\\s");
6566

6667
private static final MemorySegment UNTAGGED_UDATA_DTOR = luaW_newuserdatadtor$dtor.allocate(
67-
ud -> GlobalRef.unref(ud.get(ValueLayout.JAVA_LONG, 0)),
68-
Arena.global());
68+
ud -> GlobalRef.unref(ud.get(ValueLayout.JAVA_LONG, 0)),
69+
Arena.global());
6970
private static final MemorySegment TAGGED_UDATA_DTOR = lua_Destructor.allocate(
70-
(_, ud) -> GlobalRef.unref(ud.get(ValueLayout.JAVA_LONG, 0)),
71-
Arena.global());
71+
(_, ud) -> GlobalRef.unref(ud.get(ValueLayout.JAVA_LONG, 0)),
72+
Arena.global());
7273
private static final MemorySegment PCALL_ERRFUNC_REF = lua_CFunction.allocate(
73-
(L) -> pcallErrFunc(new LuaStateImpl(L)),
74-
Arena.global());
74+
(L) -> pcallErrFunc(new LuaStateImpl(L)),
75+
Arena.global());
7576
private static final MemorySegment USERTHREAD_CALLBACK = lua_Callbacks.userthread.allocate(
76-
(LP, L) -> userThreadCallback(
77-
LP.equals(MemorySegment.NULL) ? null : new LuaStateImpl(LP),
78-
new LuaStateImpl(L)),
79-
Arena.global());
77+
(LP, L) -> userThreadCallback(
78+
LP.equals(MemorySegment.NULL) ? null : new LuaStateImpl(LP),
79+
new LuaStateImpl(L)),
80+
Arena.global());
8081
private static final MemorySegment LUA_DEBUG_WHAT = Arena.global().allocateFrom("sln");
8182

8283
static LuaState newState(@Nullable MemorySegment allocator) {
@@ -106,12 +107,15 @@ public void close() {
106107
// Remove our reference to a threaddata object
107108
setThreadData(null);
108109

109-
// Remove our reference to the JavaCallbacks object
110+
// Get our reference to the JavaCallbacks object, but keep it so closing threads have it.
110111
final MemorySegment callbacks = lua_callbacks(L);
111-
GlobalRef.unref(lua_Callbacks.userdata(callbacks).address());
112+
final MemorySegment javaCallbacks = lua_Callbacks.userdata(callbacks);
112113

113114
// Finally, close the lua state itself.
114115
lua_close(L);
116+
117+
// Destroy the ref
118+
GlobalRef.unref(javaCallbacks.address());
115119
}
116120

117121
//TODO: test me
@@ -369,9 +373,9 @@ public long toUnsigned(int index) {
369373
public float @Nullable [] toVector(int index) {
370374
final MemorySegment value = lua_tovector(L, index);
371375
return value.equals(MemorySegment.NULL) ? null : new float[]{
372-
value.getAtIndex(ValueLayout.JAVA_FLOAT, 0),
373-
value.getAtIndex(ValueLayout.JAVA_FLOAT, 1),
374-
value.getAtIndex(ValueLayout.JAVA_FLOAT, 2),
376+
value.getAtIndex(ValueLayout.JAVA_FLOAT, 0),
377+
value.getAtIndex(ValueLayout.JAVA_FLOAT, 1),
378+
value.getAtIndex(ValueLayout.JAVA_FLOAT, 2),
375379
};
376380
}
377381

@@ -431,7 +435,7 @@ public short toStringAtomRaw(int index) {
431435
if (atom >= 0) return new LuaString.Atom(atom);
432436

433437
byte[] text = str.reinterpret(lenRef.get(ValueLayout.JAVA_INT, 0))
434-
.toArray(ValueLayout.JAVA_BYTE);
438+
.toArray(ValueLayout.JAVA_BYTE);
435439
return new LuaString.Str(new String(text, StandardCharsets.UTF_8));
436440
}
437441
}
@@ -576,15 +580,15 @@ public void pushLightUserData(long value) {
576580
public void pushLightUserDataTagged(long value, int tag) {
577581
if (tag < 0 || tag > LIGHT_USERDATA_TAG_LIMIT)
578582
throw new LuaError(
579-
"light userdata tag must be between 0 and " + LIGHT_USERDATA_TAG_LIMIT);
583+
"light userdata tag must be between 0 and " + LIGHT_USERDATA_TAG_LIMIT);
580584
lua_pushlightuserdatatagged(L, MemorySegment.ofAddress(value), tag);
581585
}
582586

583587
@Override
584588
public void newUserData(Object value) {
585589
final MemorySegment ud = luaW_newuserdatadtor(L,
586-
ValueLayout.JAVA_LONG.byteSize(),
587-
UNTAGGED_UDATA_DTOR);
590+
ValueLayout.JAVA_LONG.byteSize(),
591+
UNTAGGED_UDATA_DTOR);
588592
propagateException();
589593
ud.set(ValueLayout.JAVA_LONG, 0, GlobalRef.newref(value));
590594
}
@@ -599,7 +603,7 @@ public void newUserDataTagged(Object value, int tag) {
599603
@Override
600604
public void newUserDataTaggedWithMetatable(Object value, int tag) {
601605
final MemorySegment ud = luaW_newuserdatataggedwithmetatable(L,
602-
ValueLayout.JAVA_LONG.byteSize(), tag);
606+
ValueLayout.JAVA_LONG.byteSize(), tag);
603607
propagateException();
604608
ud.set(ValueLayout.JAVA_LONG, 0, GlobalRef.newref(value));
605609
}
@@ -622,7 +626,7 @@ public void pushFunction(LuaFunc func) {
622626
// The switch is here as an exhaustivity check :)
623627
switch (func) {
624628
case LuaFuncImpl(
625-
MemorySegment funcRef, MemorySegment debugNameRef, _
629+
MemorySegment funcRef, MemorySegment debugNameRef, _
626630
) -> luaW_pushcclosurek(L, funcRef, debugNameRef, 0, MemorySegment.NULL);
627631
}
628632
}
@@ -744,7 +748,7 @@ public void load(String chunkName, byte[] data) {
744748
final MemorySegment chunkNameRef = arena.allocateFrom(chunkName);
745749
final MemorySegment bytecodeRef = arena.allocateFrom(ValueLayout.JAVA_BYTE, data);
746750
final LuaStatus status = LuaStatus.byId(luau_load(L, chunkNameRef, bytecodeRef,
747-
data.length, 0));
751+
data.length, 0));
748752
if (status != LuaStatus.OK) {
749753
final String message = toString(-1);
750754
throw new LuaError(status, message);
@@ -816,8 +820,8 @@ public void setThreadData(@Nullable Object data) {
816820
GlobalRef.unref(oldRef.address());
817821

818822
lua_setthreaddata(L, data != null
819-
? MemorySegment.ofAddress(GlobalRef.newref(data))
820-
: MemorySegment.NULL);
823+
? MemorySegment.ofAddress(GlobalRef.newref(data))
824+
: MemorySegment.NULL);
821825
}
822826

823827
//TODO: test me
@@ -880,7 +884,7 @@ public boolean checkStack(int size) {
880884
public void checkStack(int size, @Nullable String message) {
881885
if (!checkStack(size)) {
882886
final String msg = message != null ? "stack overflow (%s)".formatted(message) :
883-
"stack overflow";
887+
"stack overflow";
884888
throw new LuaError(msg);
885889
}
886890
}
@@ -1258,8 +1262,7 @@ private static void userThreadCallback(@Nullable LuaStateImpl parent, LuaState t
12581262

12591263
// Always call the java handle if set.
12601264
final MemorySegment threadL = ((LuaStateImpl) thread).L;
1261-
final LuaCallbacks.UserThread callback = JavaCallbacks.fromCallbacks(
1262-
lua_callbacks(threadL)).userThread;
1265+
final LuaCallbacks.UserThread callback = fromCallbacks(lua_callbacks(threadL)).userThread;
12631266
if (callback != null) callback.userThread(parent, thread);
12641267
}
12651268

@@ -1306,8 +1309,8 @@ private static int pcallErrFunc(LuaState state) {
13061309
}
13071310

13081311
static StackTraceElement[] mergeBacktrace(
1309-
LuaState state, StackTraceElement[] javaTrace,
1310-
boolean startInLua
1312+
LuaState state, StackTraceElement[] javaTrace,
1313+
boolean startInLua
13111314
) {
13121315
if (NO_BACKTRACE_MERGE) return javaTrace;
13131316

@@ -1337,16 +1340,16 @@ private static boolean isDowncall(StackTraceElement elem) {
13371340
// lua_h.lua_pcall is our downcall marker, we expect no other downcalls to occur.
13381341
// At every downcall point, we need to get the 'next' lua trace segment.
13391342
if (lua_h.class.getName().equals(elem.getClassName())
1340-
&& "lua_pcall".equals(elem.getMethodName())) return true;
1343+
&& "lua_pcall".equals(elem.getMethodName())) return true;
13411344
if (LuaStateImpl.class.getName().equals(elem.getClassName())
1342-
&& "resume".equals(elem.getMethodName())) return true;
1345+
&& "resume".equals(elem.getMethodName())) return true;
13431346
return false;
13441347
}
13451348

13461349
private static int readLuaTracePart(
1347-
MemorySegment L, MemorySegment luaElem,
1348-
List<StackTraceElement> mergedTrace,
1349-
int index
1350+
MemorySegment L, MemorySegment luaElem,
1351+
List<StackTraceElement> mergedTrace,
1352+
int index
13501353
) {
13511354
while (lua_getinfo(L, index++, LUA_DEBUG_WHAT, luaElem) != 0) {
13521355
char what = (char) lua_Debug.what(luaElem).get(ValueLayout.JAVA_BYTE, 0);
@@ -1363,14 +1366,14 @@ private static int readLuaTracePart(
13631366
int currentLine = !isLua ? -1 : lua_Debug.currentline(luaElem);
13641367

13651368
mergedTrace.add(new StackTraceElement(
1366-
// declaring class
1367-
"lua",
1368-
// method name
1369-
name,
1370-
// file name
1371-
Objects.requireNonNullElse(source, "<native>"),
1372-
// line number
1373-
currentLine));
1369+
// declaring class
1370+
"lua",
1371+
// method name
1372+
name,
1373+
// file name
1374+
Objects.requireNonNullElse(source, "<native>"),
1375+
// line number
1376+
currentLine));
13741377
}
13751378
return index;
13761379
}
@@ -1380,15 +1383,15 @@ private static boolean shouldExcludeElement(StackTraceElement elem) {
13801383

13811384
class Exclusions {
13821385
static final Set<String> SET = Set.of(lua_h.class.getName() + "-lua_pcall",
1383-
LuaFuncImpl.CFunctionWrapper.class.getName() + "-apply",
1384-
LuaStateImpl.class.getName() +
1385-
"-propagateException", LuaStateImpl.class.getName() +
1386-
"-propagateExceptionInner",
1387-
LuaStateImpl.class.getName() + "-pcallErrFunc",
1388-
LuaStateImpl.class.getName() + "-lambda$static$2");
1386+
LuaFuncImpl.CFunctionWrapper.class.getName() + "-apply",
1387+
LuaStateImpl.class.getName() +
1388+
"-propagateException", LuaStateImpl.class.getName() +
1389+
"-propagateExceptionInner",
1390+
LuaStateImpl.class.getName() + "-pcallErrFunc",
1391+
LuaStateImpl.class.getName() + "-lambda$static$2");
13891392
}
13901393
return Exclusions.SET.contains("%s-%s".formatted(elem.getClassName(),
1391-
elem.getMethodName()));
1394+
elem.getMethodName()));
13921395
}
13931396

13941397
static @Nullable String stripDefaultErrorPrefix(@Nullable String raw) {

src/test/java/net/hollowcube/luau/TestLuaState.java

Lines changed: 36 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -188,10 +188,10 @@ void readBuffer(LuaState state) {
188188
state.setGlobal("theBuffer");
189189

190190
eval(
191-
state,
192-
"""
193-
buffer.writei32(theBuffer, 32, buffer.len(theBuffer))
191+
state,
194192
"""
193+
buffer.writei32(theBuffer, 32, buffer.len(theBuffer))
194+
"""
195195
);
196196

197197
state.getGlobal("theBuffer");
@@ -284,13 +284,13 @@ void emptyFunctionCall(LuaState state, Arena arena) {
284284
@Test
285285
void functionParamReturn(LuaState state, Arena arena) {
286286
var func = LuaFunc.wrap(
287-
L -> {
288-
boolean b = L.type(1) == LuaType.TABLE;
289-
L.pushString(b ? "yes" : "no");
290-
return 1;
291-
},
292-
"func",
293-
arena
287+
L -> {
288+
boolean b = L.type(1) == LuaType.TABLE;
289+
L.pushString(b ? "yes" : "no");
290+
return 1;
291+
},
292+
"func",
293+
arena
294294
);
295295

296296
state.pushFunction(func);
@@ -313,12 +313,12 @@ private static class MockLuaFunc {
313313

314314
public MockLuaFunc(Arena arena) {
315315
this.ref = LuaFunc.wrap(
316-
_ -> {
317-
callCount.incrementAndGet();
318-
return 0;
319-
},
320-
"mockFunc",
321-
arena
316+
_ -> {
317+
callCount.incrementAndGet();
318+
return 0;
319+
},
320+
"mockFunc",
321+
arena
322322
);
323323
}
324324

@@ -328,9 +328,9 @@ public void assertCalled() {
328328

329329
public void assertCalled(int times) {
330330
assertEquals(
331-
times,
332-
callCount.get(),
333-
"was not called " + times + " times"
331+
times,
332+
callCount.get(),
333+
"was not called " + times + " times"
334334
);
335335
}
336336
}
@@ -442,7 +442,7 @@ void gcCategoryManipulation(LuaState state) {
442442
state.setMemCat(42);
443443
assertEquals(0, state.totalBytes(42));
444444
state.newUserData(
445-
"this shouldnt count, it should only be 8 bytes because java owns this string"
445+
"this shouldnt count, it should only be 8 bytes because java owns this string"
446446
);
447447
assertEquals(32, state.totalBytes(42));
448448
}
@@ -462,4 +462,20 @@ void threadData(LuaState state) {
462462
}
463463

464464
//TODO test all the check and opt methods
465+
466+
@Test
467+
void regressionCloseWithThreadsDestroyingCallbacksEarly() {
468+
// Covers a regression where the JavaCallbacks instance is destroyed too early,
469+
// so threads closed from a LuaState#close did not have access to the instance.
470+
// Not throwing is the expected output
471+
472+
var state = LuaState.newState();
473+
state.openLibs();
474+
var thread = state.newThread();
475+
eval(thread, "local x = 1 + 2");
476+
477+
// Thread left alive
478+
479+
assertDoesNotThrow(state::close);
480+
}
465481
}

0 commit comments

Comments
 (0)