Skip to content

Commit 7a4d6e7

Browse files
committed
refactor(luai): enhance Marshal to support generic Go functions
Add support for marshaling generic Go functions into Lua functions, including variadic functions and struct parameters. This improves flexibility and reduces the need for custom Lua function wrappers. Additionally, update related test cases to validate the new functionality.
1 parent 67e6bab commit 7a4d6e7

8 files changed

Lines changed: 443 additions & 16 deletions

File tree

internal/luai/encode.go

Lines changed: 85 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ package luai
1818

1919
import (
2020
"errors"
21+
"fmt"
2122
"reflect"
2223

2324
lua "github.com/yuin/gopher-lua"
@@ -116,12 +117,93 @@ func Marshal(state *lua.LState, v any) (lua.LValue, error) {
116117
}
117118
return table, nil
118119
case reflect.Func:
119-
if reflected.Type().ConvertibleTo(reflect.TypeOf(lua.LGFunction(nil))) {
120+
goFuncType := reflected.Type()
121+
// If it's already an LGFunction, use it directly
122+
if goFuncType.ConvertibleTo(reflect.TypeOf(lua.LGFunction(nil))) {
120123
lf := reflected.Convert(reflect.TypeOf(lua.LGFunction(nil))).Interface().(lua.LGFunction)
121124
return state.NewFunction(lf), nil
122-
} else {
123-
return nil, errors.New("marshal: unsupported function type " + reflected.Type().String())
124125
}
126+
127+
// Generic Go function wrapper
128+
luaFunc := func(L *lua.LState) int {
129+
numIn := goFuncType.NumIn()
130+
actualNumArgs := L.GetTop()
131+
isVariadic := goFuncType.IsVariadic()
132+
133+
expectedMinArgs := numIn
134+
if isVariadic {
135+
expectedMinArgs = numIn - 1
136+
}
137+
138+
if actualNumArgs < expectedMinArgs {
139+
L.RaiseError(fmt.Sprintf("expected at least %d arguments for %s, got %d", expectedMinArgs, goFuncType.String(), actualNumArgs))
140+
return 0 // Should not reach here due to RaiseError
141+
}
142+
if !isVariadic && actualNumArgs != numIn {
143+
L.RaiseError(fmt.Sprintf("expected %d arguments for %s, got %d", numIn, goFuncType.String(), actualNumArgs))
144+
return 0 // Should not reach here due to RaiseError
145+
}
146+
147+
goArgs := make([]reflect.Value, numIn)
148+
for i := 0; i < numIn; i++ {
149+
goArgType := goFuncType.In(i)
150+
151+
if isVariadic && i == numIn-1 { // Last argument of a variadic function
152+
sliceElementType := goArgType.Elem()
153+
variadicLen := actualNumArgs - (numIn - 1)
154+
if variadicLen < 0 {
155+
variadicLen = 0
156+
}
157+
variadicSlice := reflect.MakeSlice(goArgType, variadicLen, variadicLen)
158+
for j := 0; j < variadicLen; j++ {
159+
luaVariadicArg := L.CheckAny(i + 1 + j)
160+
elemPtr := reflect.New(sliceElementType)
161+
err := Unmarshal(luaVariadicArg, elemPtr.Interface()) // Unmarshal is in the same package
162+
if err != nil {
163+
L.Push(lua.LNil)
164+
L.Push(lua.LString(fmt.Sprintf("error unmarshaling variadic argument %d (item %d): %s", i+1, j+1, err.Error())))
165+
return 2
166+
}
167+
variadicSlice.Index(j).Set(elemPtr.Elem())
168+
}
169+
goArgs[i] = variadicSlice
170+
break // All arguments processed for variadic function
171+
} else {
172+
luaArg := L.CheckAny(i + 1)
173+
goArgPtr := reflect.New(goArgType)
174+
err := Unmarshal(luaArg, goArgPtr.Interface())
175+
if err != nil {
176+
L.Push(lua.LNil)
177+
L.Push(lua.LString(fmt.Sprintf("error unmarshaling argument %d: %s", i+1, err.Error())))
178+
return 2
179+
}
180+
goArgs[i] = goArgPtr.Elem()
181+
}
182+
}
183+
184+
var results []reflect.Value
185+
if isVariadic {
186+
results = reflected.CallSlice(goArgs)
187+
} else {
188+
results = reflected.Call(goArgs)
189+
}
190+
191+
if len(results) == 0 {
192+
return 0
193+
}
194+
195+
for _, result := range results {
196+
luaResult, err := Marshal(L, result.Interface())
197+
if err != nil {
198+
L.Push(lua.LNil)
199+
L.Push(lua.LString(fmt.Sprintf("error marshaling result: %s", err.Error())))
200+
return 2
201+
}
202+
L.Push(luaResult)
203+
}
204+
return len(results)
205+
}
206+
return state.NewFunction(luaFunc), nil
125207
default:
126208
return nil, errors.New("marshal: unsupported type " + reflected.Kind().String() + " for reflected ")
127209
}

0 commit comments

Comments
 (0)