@@ -18,12 +18,32 @@ extern "C"
1818#include " ALEUtility.h"
1919#include " SharedDefines.h"
2020
21+ enum MethodRegisterState
22+ {
23+ METHOD_REG_MAP = 0 ,
24+ METHOD_REG_WORLD = 1 ,
25+ METHOD_REG_ALL = 2
26+ };
27+
28+ struct ALEGlobalRegister
29+ {
30+ const char * name;
31+ int (*func)(lua_State*);
32+ MethodRegisterState regState;
33+
34+ ALEGlobalRegister (const char * name, int (*f)(lua_State*), MethodRegisterState state = METHOD_REG_ALL)
35+ : name(name), func(f), regState(state) {}
36+
37+ ALEGlobalRegister (const char * name, MethodRegisterState state = METHOD_REG_ALL)
38+ : name(name), func(nullptr ), regState(state) {}
39+ };
40+
2141class ALEGlobal
2242{
2343public:
2444 static int thunk (lua_State* L)
2545 {
26- luaL_Reg * l = static_cast <luaL_Reg *>(lua_touserdata (L, lua_upvalueindex (1 )));
46+ ALEGlobalRegister * l = static_cast <ALEGlobalRegister *>(lua_touserdata (L, lua_upvalueindex (1 )));
2747 int top = lua_gettop (L);
2848 int expected = l->func (L);
2949 int args = lua_gettop (L) - top;
@@ -36,15 +56,38 @@ class ALEGlobal
3656 return expected;
3757 }
3858
39- static void SetMethods (ALE* E, luaL_Reg* methodTable)
59+ static int MethodWrongState (lua_State* L)
60+ {
61+ luaL_error (L, " attempt to call method '%s' that is not available in this state" , lua_tostring (L, lua_upvalueindex (1 )));
62+ return 0 ;
63+ }
64+
65+ static void SetMethods (ALE* E, ALEGlobalRegister* methodTable)
4066 {
4167 ASSERT (E);
4268 ASSERT (methodTable);
4369
4470 lua_pushglobaltable (E->L );
4571
46- for (; methodTable && methodTable->name && methodTable-> func ; ++methodTable)
72+ for (; methodTable && methodTable->name ; ++methodTable)
4773 {
74+ if (methodTable->regState != METHOD_REG_ALL)
75+ {
76+ bool isMapState = (E->GetStateMapId () != ALE_GLOBAL_STATE);
77+ if ((!isMapState && methodTable->regState == METHOD_REG_MAP) ||
78+ (isMapState && methodTable->regState == METHOD_REG_WORLD))
79+ {
80+ lua_pushstring (E->L , methodTable->name );
81+ lua_pushstring (E->L , methodTable->name );
82+ lua_pushcclosure (E->L , MethodWrongState, 1 );
83+ lua_rawset (E->L , -3 );
84+ continue ;
85+ }
86+ }
87+
88+ if (!methodTable->func )
89+ continue ;
90+
4891 lua_pushstring (E->L , methodTable->name );
4992 lua_pushlightuserdata (E->L , (void *)methodTable);
5093 lua_pushcclosure (E->L , thunk, 1 );
@@ -117,6 +160,13 @@ struct ALERegister
117160{
118161 const char * name;
119162 int (*mfunc)(lua_State*, T*);
163+ MethodRegisterState regState;
164+
165+ ALERegister (const char * name, int (*func)(lua_State*, T*), MethodRegisterState state = METHOD_REG_ALL)
166+ : name(name), mfunc(func), regState(state) {}
167+
168+ ALERegister (const char * name, MethodRegisterState state = METHOD_REG_ALL)
169+ : name(name), mfunc(nullptr ), regState(state) {}
120170};
121171
122172template <typename T>
@@ -241,8 +291,25 @@ class ALETemplate
241291 lua_rawget (E->L , LUA_REGISTRYINDEX);
242292 ASSERT (lua_istable (E->L , -1 ));
243293
244- for (; methodTable && methodTable->name && methodTable-> mfunc ; ++methodTable)
294+ for (; methodTable && methodTable->name ; ++methodTable)
245295 {
296+ if (methodTable->regState != METHOD_REG_ALL)
297+ {
298+ bool isMapState = (E->GetStateMapId () != ALE_GLOBAL_STATE);
299+ if ((!isMapState && methodTable->regState == METHOD_REG_MAP) ||
300+ (isMapState && methodTable->regState == METHOD_REG_WORLD))
301+ {
302+ lua_pushstring (E->L , methodTable->name );
303+ lua_pushstring (E->L , methodTable->name );
304+ lua_pushcclosure (E->L , MethodWrongState, 1 );
305+ lua_rawset (E->L , -3 );
306+ continue ;
307+ }
308+ }
309+
310+ if (!methodTable->mfunc )
311+ continue ;
312+
246313 lua_pushstring (E->L , methodTable->name );
247314 lua_pushlightuserdata (E->L , (void *)methodTable);
248315 lua_pushcclosure (E->L , CallMethod, 1 );
@@ -322,6 +389,12 @@ class ALETemplate
322389 return 0 ;
323390 }
324391
392+ static int MethodWrongState (lua_State* L)
393+ {
394+ luaL_error (L, " attempt to call method '%s' that is not available in this state" , lua_tostring (L, lua_upvalueindex (1 )));
395+ return 0 ;
396+ }
397+
325398 static int CallMethod (lua_State* L)
326399 {
327400 T* obj = ALE::CHECKOBJ<T>(L, 1 ); // get self
0 commit comments