Skip to content

Commit ef8a808

Browse files
authored
Merge pull request #556 from dgageot/fix-code-mode-required
Fix handling of required attributes
2 parents 1185b10 + 2af5edb commit ef8a808

4 files changed

Lines changed: 124 additions & 6 deletions

File tree

pkg/codemode/codemode.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ Available tools/functions:
2828
2929
`
3030

31-
func Wrap(toolsets []tools.ToolSet) tools.ToolSet {
31+
func Wrap(toolsets ...tools.ToolSet) tools.ToolSet {
3232
return &codeModeTool{
3333
toolsets: toolsets,
3434
}

pkg/codemode/codemode_test.go

Lines changed: 110 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
package codemode
22

33
import (
4+
"context"
45
"encoding/json"
56
"fmt"
67
"testing"
78

89
"github.com/stretchr/testify/assert"
910
"github.com/stretchr/testify/require"
11+
12+
"github.com/docker/cagent/pkg/tools"
1013
)
1114

1215
func TestCodeModeTool_Tools(t *testing.T) {
@@ -74,11 +77,117 @@ func TestCodeModeTool_Instructions(t *testing.T) {
7477
}
7578

7679
func TestCodeModeTool_StartStop(t *testing.T) {
77-
tool := &codeModeTool{}
80+
inner := &testToolSet{}
81+
82+
tool := Wrap(inner)
83+
84+
assert.Equal(t, 0, inner.start)
85+
assert.Equal(t, 0, inner.stop)
7886

7987
err := tool.Start(t.Context())
8088
require.NoError(t, err)
89+
assert.Equal(t, 1, inner.start)
90+
assert.Equal(t, 0, inner.stop)
8191

8292
err = tool.Stop(t.Context())
8393
require.NoError(t, err)
94+
assert.Equal(t, 1, inner.start)
95+
assert.Equal(t, 1, inner.stop)
8496
}
97+
98+
func TestCodeModeTool_CallHello(t *testing.T) {
99+
tool := Wrap(&testToolSet{
100+
tools: []tools.Tool{{
101+
Name: "hello_world",
102+
Handler: func(ctx context.Context, toolCall tools.ToolCall) (*tools.ToolCallResult, error) {
103+
return &tools.ToolCallResult{
104+
Output: "Hello, World!",
105+
}, nil
106+
},
107+
}},
108+
})
109+
110+
allTools, err := tool.Tools(t.Context())
111+
require.NoError(t, err)
112+
require.Len(t, allTools, 1)
113+
114+
result, err := allTools[0].Handler(t.Context(), tools.ToolCall{
115+
Function: tools.FunctionCall{
116+
Arguments: `{"script":"return hello_world();"}`,
117+
},
118+
})
119+
require.NoError(t, err)
120+
121+
var scriptResult ScriptResult
122+
err = json.Unmarshal([]byte(result.Output), &scriptResult)
123+
require.NoError(t, err)
124+
125+
require.Equal(t, "Hello, World!", scriptResult.Value)
126+
require.Empty(t, scriptResult.StdErr)
127+
require.Empty(t, scriptResult.StdOut)
128+
}
129+
130+
func TestCodeModeTool_CallEcho(t *testing.T) {
131+
type EchoArgs struct {
132+
Message string `json:"message" jsonschema:"Message to echo"`
133+
}
134+
135+
tool := Wrap(&testToolSet{
136+
tools: []tools.Tool{{
137+
Name: "echo",
138+
Handler: func(ctx context.Context, toolCall tools.ToolCall) (*tools.ToolCallResult, error) {
139+
return &tools.ToolCallResult{
140+
Output: "ECHO",
141+
}, nil
142+
},
143+
Parameters: tools.MustSchemaFor[EchoArgs](),
144+
}},
145+
})
146+
147+
allTools, err := tool.Tools(t.Context())
148+
require.NoError(t, err)
149+
require.Len(t, allTools, 1)
150+
151+
result, err := allTools[0].Handler(t.Context(), tools.ToolCall{
152+
Function: tools.FunctionCall{
153+
Arguments: `{"script":"return echo({'message':'ECHO'});"}`,
154+
},
155+
})
156+
require.NoError(t, err)
157+
158+
var scriptResult ScriptResult
159+
err = json.Unmarshal([]byte(result.Output), &scriptResult)
160+
require.NoError(t, err)
161+
162+
require.Equal(t, "ECHO", scriptResult.Value)
163+
require.Empty(t, scriptResult.StdErr)
164+
require.Empty(t, scriptResult.StdOut)
165+
}
166+
167+
type testToolSet struct {
168+
tools []tools.Tool
169+
start int
170+
stop int
171+
}
172+
173+
func (t *testToolSet) Tools(ctx context.Context) ([]tools.Tool, error) {
174+
return t.tools, nil
175+
}
176+
177+
func (t *testToolSet) Instructions() string {
178+
return ""
179+
}
180+
181+
func (t *testToolSet) Start(context.Context) error {
182+
t.start++
183+
return nil
184+
}
185+
186+
func (t *testToolSet) Stop(context.Context) error {
187+
t.stop++
188+
return nil
189+
}
190+
191+
func (t *testToolSet) SetElicitationHandler(tools.ElicitationHandler) {}
192+
193+
func (t *testToolSet) SetOAuthSuccessHandler(func()) {}

pkg/codemode/exec.go

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"context"
66
"encoding/json"
77
"fmt"
8+
"slices"
89

910
"github.com/dop251/goja"
1011

@@ -66,11 +67,19 @@ func (c *codeModeTool) runJavascript(ctx context.Context, script string) (Script
6667

6768
func callTool(ctx context.Context, tool tools.Tool) func(args map[string]any) (string, error) {
6869
return func(args map[string]any) (string, error) {
70+
var toolArgs struct {
71+
Required []string `json:"required"`
72+
}
73+
74+
if err := tools.ConvertSchema(tool.Parameters, &toolArgs); err != nil {
75+
return "", err
76+
}
77+
6978
nonNilArgs := make(map[string]any)
7079
for k, v := range args {
71-
// if slices.Contains(tool.Parameters.Required, k) || v != nil {
72-
nonNilArgs[k] = v
73-
// }
80+
if slices.Contains(toolArgs.Required, k) || v != nil {
81+
nonNilArgs[k] = v
82+
}
7483
}
7584

7685
arguments, err := json.Marshal(nonNilArgs)

pkg/teamloader/teamloader.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,7 @@ func getToolsForAgent(ctx context.Context, a *latest.AgentConfig, parentDir stri
255255
// This allows the agent to call multiple tools in a single response.
256256
// It also allows to combine the results of multiple tools in a single response.
257257
return []tools.ToolSet{
258-
codemode.Wrap(t),
258+
codemode.Wrap(t...),
259259
}, nil
260260
}
261261

0 commit comments

Comments
 (0)