diff --git a/table.go b/table.go index ddf14dd8..b655edc2 100644 --- a/table.go +++ b/table.go @@ -324,27 +324,44 @@ func (tb *LTable) RawGetString(key string) LValue { // ForEach iterates over this table of elements, yielding each in turn to a given function. func (tb *LTable) ForEach(cb func(LValue, LValue)) { + tb.ForEachWithError(func(key, value LValue) error { + cb(key, value) + return nil + }) +} + +// ForEachWithError iterates over this table of elements, yielding each in turn +// to a given function. If it receives a non-nil error from its callback, it +// breaks and passes it back to its caller. +func (tb *LTable) ForEachWithError(cb func(LValue, LValue) error) error { if tb.array != nil { for i, v := range tb.array { if v != LNil { - cb(LNumber(i+1), v) + if err := cb(LNumber(i+1), v); err != nil { + return err + } } } } if tb.strdict != nil { for k, v := range tb.strdict { if v != LNil { - cb(LString(k), v) + if err := cb(LString(k), v); err != nil { + return err + } } } } if tb.dict != nil { for k, v := range tb.dict { if v != LNil { - cb(k, v) + if err := cb(k, v); err != nil { + return err + } } } } + return nil } // This function is equivalent to lua_next ( http://www.lua.org/manual/5.1/manual.html#lua_next ). diff --git a/table_test.go b/table_test.go index 6acbbb2c..50ffefc4 100644 --- a/table_test.go +++ b/table_test.go @@ -1,6 +1,7 @@ package lua import ( + "fmt" "testing" ) @@ -231,3 +232,29 @@ func TestTableForEach(t *testing.T) { } }) } + +func TestTableForEachWithError(t *testing.T) { + tbl := newLTable(0, 0) + tbl.Append(LNumber(1)) + tbl.Append(LNumber(2)) + tbl.Append(LNumber(3)) + tbl.Append(LNil) + tbl.Append(LNumber(5)) + + tbl.RawSetH(LString("a"), LString("a")) + tbl.RawSetH(LString("b"), LString("b")) + tbl.RawSetH(LString("c"), LString("c")) + + tbl.RawSetH(LTrue, LString("true")) + tbl.RawSetH(LFalse, LString("false")) + + testError := fmt.Errorf("test error") + runCount := 0 + err := tbl.ForEachWithError(func(key, value LValue) error { + runCount += 1 + return testError + }) + + errorIfNotEqual(t, testError, err) + errorIfNotEqual(t, 1, runCount) +}