Skip to content

Commit b560484

Browse files
authored
Merge pull request #480 from dgageot/schema-rework
JSON Schema rework
2 parents 3dcd387 + 8260062 commit b560484

34 files changed

Lines changed: 1346 additions & 1191 deletions

examples/golibrary/tool/main.go

Lines changed: 23 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@ import (
44
"context"
55
"encoding/json"
66
"fmt"
7-
"log"
7+
"os/signal"
8+
"syscall"
89

910
"github.com/docker/cagent/pkg/agent"
1011
latest "github.com/docker/cagent/pkg/config/v2"
@@ -16,13 +17,22 @@ import (
1617
"github.com/docker/cagent/pkg/tools"
1718
)
1819

19-
func addNumbers(_ context.Context, toolCall tools.ToolCall) (*tools.ToolCallResult, error) {
20-
type params struct {
21-
A int `json:"a"`
22-
B int `json:"b"`
20+
func main() {
21+
ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
22+
defer cancel()
23+
24+
if err := run(ctx); err != nil {
25+
fmt.Println(err)
2326
}
27+
}
28+
29+
type AddNumbersArgs struct {
30+
A int `json:"a"`
31+
B int `json:"b"`
32+
}
2433

25-
var p params
34+
func addNumbers(_ context.Context, toolCall tools.ToolCall) (*tools.ToolCallResult, error) {
35+
var p AddNumbersArgs
2636
if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &p); err != nil {
2737
return nil, err
2838
}
@@ -34,9 +44,7 @@ func addNumbers(_ context.Context, toolCall tools.ToolCall) (*tools.ToolCallResu
3444
}, nil
3545
}
3646

37-
func main() {
38-
ctx := context.Background()
39-
47+
func run(ctx context.Context) error {
4048
llm, err := openai.NewClient(
4149
ctx,
4250
&latest.ModelConfig{
@@ -46,24 +54,14 @@ func main() {
4654
environment.NewDefaultProvider(ctx),
4755
)
4856
if err != nil {
49-
log.Fatal(err)
57+
return err
5058
}
5159

5260
toolAddNumbers := tools.Tool{
5361
Name: "add",
5462
Description: "Add two numbers",
55-
Parameters: tools.FunctionParameters{
56-
Type: "object",
57-
Properties: map[string]any{
58-
"a": map[string]any{
59-
"type": "number",
60-
},
61-
"b": map[string]any{
62-
"type": "number",
63-
},
64-
},
65-
},
66-
Handler: addNumbers,
63+
Parameters: tools.MustSchemaFor[AddNumbersArgs](),
64+
Handler: addNumbers,
6765
}
6866

6967
calculator := agent.New(
@@ -77,15 +75,16 @@ func main() {
7775

7876
rt, err := runtime.New(calculatorTeam)
7977
if err != nil {
80-
log.Fatal(err)
78+
return err
8179
}
8280

8381
sess := session.New(session.WithUserMessage("", "What is 1 + 2?"))
8482

8583
messages, err := rt.Run(ctx, sess)
8684
if err != nil {
87-
log.Fatal(err)
85+
return err
8886
}
8987

9088
fmt.Println(messages[len(messages)-1].Message.Content)
89+
return nil
9190
}

pkg/codemode/codemode.go

Lines changed: 19 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -26,19 +26,25 @@ Available tools/functions:
2626
2727
`
2828

29-
type RunToolsWithJavascriptArgs struct {
30-
Script string `json:"script"`
29+
func Wrap(toolsets []tools.ToolSet) tools.ToolSet {
30+
return &codeModeTool{
31+
toolsets: toolsets,
32+
}
3133
}
3234

33-
type tool struct {
35+
type codeModeTool struct {
3436
toolsets []tools.ToolSet
3537
}
3638

37-
func (c *tool) Instructions() string {
39+
type RunToolsWithJavascriptArgs struct {
40+
Script string `json:"script" jsonschema:"Script to execute"`
41+
}
42+
43+
func (c *codeModeTool) Instructions() string {
3844
return ""
3945
}
4046

41-
func (c *tool) Tools(ctx context.Context) ([]tools.Tool, error) {
47+
func (c *codeModeTool) Tools(ctx context.Context) ([]tools.Tool, error) {
4248
var functionsDoc []string
4349

4450
for _, toolset := range c.toolsets {
@@ -55,19 +61,7 @@ func (c *tool) Tools(ctx context.Context) ([]tools.Tool, error) {
5561
return []tools.Tool{{
5662
Name: "run_tools_with_javascript",
5763
Description: prompt + strings.Join(functionsDoc, "\n"),
58-
Annotations: tools.ToolAnnotations{
59-
Title: "Run tools with Javascript",
60-
},
61-
Parameters: tools.FunctionParameters{
62-
Type: "object",
63-
Required: []string{"script"},
64-
Properties: map[string]any{
65-
"script": map[string]any{
66-
"type": "string",
67-
"description": "script to execute",
68-
},
69-
},
70-
},
64+
Parameters: tools.MustSchemaFor[RunToolsWithJavascriptArgs](),
7165
Handler: func(ctx context.Context, toolCall tools.ToolCall) (*tools.ToolCallResult, error) {
7266
var args RunToolsWithJavascriptArgs
7367
if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &args); err != nil {
@@ -83,10 +77,13 @@ func (c *tool) Tools(ctx context.Context) ([]tools.Tool, error) {
8377
Output: output,
8478
}, nil
8579
},
80+
Annotations: tools.ToolAnnotations{
81+
Title: "Run tools with Javascript",
82+
},
8683
}}, nil
8784
}
8885

89-
func (c *tool) Start(ctx context.Context) error {
86+
func (c *codeModeTool) Start(ctx context.Context) error {
9087
for _, t := range c.toolsets {
9188
if err := t.Start(ctx); err != nil {
9289
return err
@@ -96,7 +93,7 @@ func (c *tool) Start(ctx context.Context) error {
9693
return nil
9794
}
9895

99-
func (c *tool) Stop() error {
96+
func (c *codeModeTool) Stop() error {
10097
var errs []error
10198

10299
for _, t := range c.toolsets {
@@ -108,16 +105,10 @@ func (c *tool) Stop() error {
108105
return errors.Join(errs...)
109106
}
110107

111-
func (c *tool) SetElicitationHandler(handler tools.ElicitationHandler) {
108+
func (c *codeModeTool) SetElicitationHandler(handler tools.ElicitationHandler) {
112109
// No-op, this tool does not use elicitation
113110
}
114111

115-
func (c *tool) SetOAuthSuccessHandler(handler func()) {
112+
func (c *codeModeTool) SetOAuthSuccessHandler(handler func()) {
116113
// No-op, this tool does not use OAuth
117114
}
118-
119-
func Wrap(toolsets []tools.ToolSet) tools.ToolSet {
120-
return &tool{
121-
toolsets: toolsets,
122-
}
123-
}

pkg/codemode/codemode_test.go

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
package codemode
2+
3+
import (
4+
"encoding/json"
5+
"testing"
6+
7+
"github.com/stretchr/testify/assert"
8+
"github.com/stretchr/testify/require"
9+
)
10+
11+
func TestCodeModeTool_Tools(t *testing.T) {
12+
tool := &codeModeTool{}
13+
14+
toolSet, err := tool.Tools(t.Context())
15+
require.NoError(t, err)
16+
require.Len(t, toolSet, 1)
17+
18+
fetchTool := toolSet[0]
19+
assert.Equal(t, "run_tools_with_javascript", fetchTool.Name)
20+
assert.NotNil(t, fetchTool.Handler)
21+
22+
schema, err := json.Marshal(fetchTool.Parameters)
23+
require.NoError(t, err)
24+
assert.JSONEq(t, `{
25+
"type": "object",
26+
"required": [
27+
"script"
28+
],
29+
"properties": {
30+
"script": {
31+
"type": "string",
32+
"description": "Script to execute"
33+
}
34+
},
35+
"additionalProperties": false
36+
}`, string(schema))
37+
}
38+
39+
func TestCodeModeTool_Instructions(t *testing.T) {
40+
tool := &codeModeTool{}
41+
42+
instructions := tool.Instructions()
43+
44+
assert.Empty(t, instructions)
45+
}
46+
47+
func TestCodeModeTool_StartStop(t *testing.T) {
48+
tool := &codeModeTool{}
49+
50+
err := tool.Start(t.Context())
51+
require.NoError(t, err)
52+
53+
err = tool.Stop()
54+
require.NoError(t, err)
55+
}

pkg/codemode/console.go

Lines changed: 18 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -7,35 +7,23 @@ import (
77

88
func console() map[string]any {
99
return map[string]any{
10-
"debug": console_debug,
11-
"error": console_error,
12-
"info": console_info,
13-
"log": console_log,
14-
"trace": console_trace,
15-
"warn": console_warn,
10+
"debug": func(args ...any) {
11+
fmt.Fprintln(os.Stdout, args...)
12+
},
13+
"error": func(args ...any) {
14+
fmt.Fprintln(os.Stdout, args...)
15+
},
16+
"info": func(args ...any) {
17+
fmt.Fprintln(os.Stdout, args...)
18+
},
19+
"log": func(args ...any) {
20+
fmt.Fprintln(os.Stdout, args...)
21+
},
22+
"trace": func(args ...any) {
23+
fmt.Fprintln(os.Stdout, args...)
24+
},
25+
"warn": func(args ...any) {
26+
fmt.Fprintln(os.Stdout, args...)
27+
},
1628
}
1729
}
18-
19-
func console_debug(args ...any) { //nolint:staticcheck // match JavaScript's console method names.
20-
fmt.Fprintln(os.Stdout, args...)
21-
}
22-
23-
func console_error(args ...any) { //nolint:staticcheck // match JavaScript's console method names.
24-
fmt.Fprintln(os.Stdout, args...)
25-
}
26-
27-
func console_info(args ...any) { //nolint:staticcheck // match JavaScript's console method names.
28-
fmt.Fprintln(os.Stdout, args...)
29-
}
30-
31-
func console_log(args ...any) { //nolint:staticcheck // match JavaScript's console method names.
32-
fmt.Fprintln(os.Stdout, args...)
33-
}
34-
35-
func console_trace(args ...any) { //nolint:staticcheck // match JavaScript's console method names.
36-
fmt.Fprintln(os.Stdout, args...)
37-
}
38-
39-
func console_warn(args ...any) { //nolint:staticcheck // match JavaScript's console method names.
40-
fmt.Fprintln(os.Stdout, args...)
41-
}

pkg/codemode/exec.go

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,13 @@ import (
44
"context"
55
"encoding/json"
66
"fmt"
7-
"slices"
87

98
"github.com/dop251/goja"
109

1110
"github.com/docker/cagent/pkg/tools"
1211
)
1312

14-
func (c *tool) runJavascript(ctx context.Context, script string) (string, error) {
13+
func (c *codeModeTool) runJavascript(ctx context.Context, script string) (string, error) {
1514
vm := goja.New()
1615

1716
// Inject console object to the help the LLM debug its own code.
@@ -52,9 +51,9 @@ func callTool(ctx context.Context, tool tools.Tool) func(args map[string]any) (s
5251
return func(args map[string]any) (string, error) {
5352
nonNilArgs := make(map[string]any)
5453
for k, v := range args {
55-
if slices.Contains(tool.Parameters.Required, k) || v != nil {
56-
nonNilArgs[k] = v
57-
}
54+
// if slices.Contains(tool.Parameters.Required, k) || v != nil {
55+
nonNilArgs[k] = v
56+
// }
5857
}
5958

6059
arguments, err := json.Marshal(nonNilArgs)

0 commit comments

Comments
 (0)