Skip to content

Commit fe34771

Browse files
committed
Separate methods per state
1 parent 922302c commit fe34771

1 file changed

Lines changed: 77 additions & 4 deletions

File tree

src/LuaEngine/ALETemplate.h

Lines changed: 77 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,32 @@ extern "C"
1818
#include "ALEUtility.h"
1919
#include "SharedDefines.h"
2020

21+
enum MethodRegisterState
22+
{
23+
METHOD_REG_MAP = 0,
24+
METHOD_REG_WORLD = 1,
25+
METHOD_REG_ALL = 2
26+
};
27+
28+
struct ALEGlobalRegister
29+
{
30+
const char* name;
31+
int(*func)(lua_State*);
32+
MethodRegisterState regState;
33+
34+
ALEGlobalRegister(const char* name, int(*f)(lua_State*), MethodRegisterState state = METHOD_REG_ALL)
35+
: name(name), func(f), regState(state) {}
36+
37+
ALEGlobalRegister(const char* name, MethodRegisterState state = METHOD_REG_ALL)
38+
: name(name), func(nullptr), regState(state) {}
39+
};
40+
2141
class ALEGlobal
2242
{
2343
public:
2444
static int thunk(lua_State* L)
2545
{
26-
luaL_Reg* l = static_cast<luaL_Reg*>(lua_touserdata(L, lua_upvalueindex(1)));
46+
ALEGlobalRegister* l = static_cast<ALEGlobalRegister*>(lua_touserdata(L, lua_upvalueindex(1)));
2747
int top = lua_gettop(L);
2848
int expected = l->func(L);
2949
int args = lua_gettop(L) - top;
@@ -36,15 +56,38 @@ class ALEGlobal
3656
return expected;
3757
}
3858

39-
static void SetMethods(ALE* E, luaL_Reg* methodTable)
59+
static int MethodWrongState(lua_State* L)
60+
{
61+
luaL_error(L, "attempt to call method '%s' that is not available in this state", lua_tostring(L, lua_upvalueindex(1)));
62+
return 0;
63+
}
64+
65+
static void SetMethods(ALE* E, ALEGlobalRegister* methodTable)
4066
{
4167
ASSERT(E);
4268
ASSERT(methodTable);
4369

4470
lua_pushglobaltable(E->L);
4571

46-
for (; methodTable && methodTable->name && methodTable->func; ++methodTable)
72+
for (; methodTable && methodTable->name; ++methodTable)
4773
{
74+
if (methodTable->regState != METHOD_REG_ALL)
75+
{
76+
bool isMapState = (E->GetStateMapId() != ALE_GLOBAL_STATE);
77+
if ((!isMapState && methodTable->regState == METHOD_REG_MAP) ||
78+
(isMapState && methodTable->regState == METHOD_REG_WORLD))
79+
{
80+
lua_pushstring(E->L, methodTable->name);
81+
lua_pushstring(E->L, methodTable->name);
82+
lua_pushcclosure(E->L, MethodWrongState, 1);
83+
lua_rawset(E->L, -3);
84+
continue;
85+
}
86+
}
87+
88+
if (!methodTable->func)
89+
continue;
90+
4891
lua_pushstring(E->L, methodTable->name);
4992
lua_pushlightuserdata(E->L, (void*)methodTable);
5093
lua_pushcclosure(E->L, thunk, 1);
@@ -117,6 +160,13 @@ struct ALERegister
117160
{
118161
const char* name;
119162
int(*mfunc)(lua_State*, T*);
163+
MethodRegisterState regState;
164+
165+
ALERegister(const char* name, int(*func)(lua_State*, T*), MethodRegisterState state = METHOD_REG_ALL)
166+
: name(name), mfunc(func), regState(state) {}
167+
168+
ALERegister(const char* name, MethodRegisterState state = METHOD_REG_ALL)
169+
: name(name), mfunc(nullptr), regState(state) {}
120170
};
121171

122172
template<typename T>
@@ -241,8 +291,25 @@ class ALETemplate
241291
lua_rawget(E->L, LUA_REGISTRYINDEX);
242292
ASSERT(lua_istable(E->L, -1));
243293

244-
for (; methodTable && methodTable->name && methodTable->mfunc; ++methodTable)
294+
for (; methodTable && methodTable->name; ++methodTable)
245295
{
296+
if (methodTable->regState != METHOD_REG_ALL)
297+
{
298+
bool isMapState = (E->GetStateMapId() != ALE_GLOBAL_STATE);
299+
if ((!isMapState && methodTable->regState == METHOD_REG_MAP) ||
300+
(isMapState && methodTable->regState == METHOD_REG_WORLD))
301+
{
302+
lua_pushstring(E->L, methodTable->name);
303+
lua_pushstring(E->L, methodTable->name);
304+
lua_pushcclosure(E->L, MethodWrongState, 1);
305+
lua_rawset(E->L, -3);
306+
continue;
307+
}
308+
}
309+
310+
if (!methodTable->mfunc)
311+
continue;
312+
246313
lua_pushstring(E->L, methodTable->name);
247314
lua_pushlightuserdata(E->L, (void*)methodTable);
248315
lua_pushcclosure(E->L, CallMethod, 1);
@@ -322,6 +389,12 @@ class ALETemplate
322389
return 0;
323390
}
324391

392+
static int MethodWrongState(lua_State* L)
393+
{
394+
luaL_error(L, "attempt to call method '%s' that is not available in this state", lua_tostring(L, lua_upvalueindex(1)));
395+
return 0;
396+
}
397+
325398
static int CallMethod(lua_State* L)
326399
{
327400
T* obj = ALE::CHECKOBJ<T>(L, 1); // get self

0 commit comments

Comments
 (0)