Skip to content

Commit 9b91655

Browse files
Copilotdgageot
andcommitted
Implement output schema support in code mode
Co-authored-by: dgageot <153495+dgageot@users.noreply.github.com>
1 parent be5f3d6 commit 9b91655

6 files changed

Lines changed: 476 additions & 8 deletions

File tree

pkg/codemode/codemode.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ and manipulate the results before returning them.
1919
Instructions:
2020
- The script has access to all the tools as plain javascript functions.
2121
- "await"/"async" are never needed. All the tool calls are synchronous.
22-
- Every tool function returns a string result.
22+
- Each tool function returns the type specified in its signature below.
2323
- The script must return a string result.
2424
2525
Available tools/functions:

pkg/codemode/exec.go

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"slices"
88

99
"github.com/dop251/goja"
10+
"github.com/google/jsonschema-go/jsonschema"
1011

1112
"github.com/docker/cagent/pkg/tools"
1213
)
@@ -48,8 +49,8 @@ func (c *tool) runJavascript(ctx context.Context, script string) (string, error)
4849
return fmt.Sprintf("%v", result), nil
4950
}
5051

51-
func callTool(ctx context.Context, tool tools.Tool) func(args map[string]any) (string, error) {
52-
return func(args map[string]any) (string, error) {
52+
func callTool(ctx context.Context, tool tools.Tool) func(args map[string]any) (any, error) {
53+
return func(args map[string]any) (any, error) {
5354
nonNilArgs := make(map[string]any)
5455
for k, v := range args {
5556
if slices.Contains(tool.Function.Parameters.Required, k) || v != nil {
@@ -59,7 +60,7 @@ func callTool(ctx context.Context, tool tools.Tool) func(args map[string]any) (s
5960

6061
arguments, err := json.Marshal(nonNilArgs)
6162
if err != nil {
62-
return "", err
63+
return nil, err
6364
}
6465

6566
result, err := tool.Handler(ctx, tools.ToolCall{
@@ -69,9 +70,32 @@ func callTool(ctx context.Context, tool tools.Tool) func(args map[string]any) (s
6970
},
7071
})
7172
if err != nil {
72-
return "", err
73+
return nil, err
74+
}
75+
76+
// If the tool has a string output schema or no schema, return as string
77+
if tool.Function.OutputSchema == nil {
78+
return result.Output, nil
79+
}
80+
81+
// Check if output schema indicates a string type
82+
if s, ok := tool.Function.OutputSchema.(*jsonschema.Schema); ok {
83+
if s.Type == "string" {
84+
return result.Output, nil
85+
}
86+
} else if schemaMap, ok := tool.Function.OutputSchema.(map[string]any); ok {
87+
if schemaType, hasType := schemaMap["type"].(string); hasType && schemaType == "string" {
88+
return result.Output, nil
89+
}
90+
}
91+
92+
// For non-string schemas, try to parse JSON
93+
var parsed any
94+
if err := json.Unmarshal([]byte(result.Output), &parsed); err != nil {
95+
// If JSON parsing fails, return as string (fallback)
96+
return result.Output, nil
7397
}
7498

75-
return result.Output, nil
99+
return parsed, nil
76100
}
77101
}

pkg/codemode/exec_test.go

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
package codemode
2+
3+
import (
4+
"context"
5+
"encoding/json"
6+
"reflect"
7+
"testing"
8+
9+
"github.com/stretchr/testify/assert"
10+
"github.com/stretchr/testify/require"
11+
12+
"github.com/docker/cagent/pkg/tools"
13+
)
14+
15+
func TestCallTool_StringOutput(t *testing.T) {
16+
tool := tools.Tool{
17+
Function: &tools.FunctionDefinition{
18+
Name: "string_tool",
19+
OutputSchema: tools.ToOutputSchemaSchemaMust(reflect.TypeFor[string]()),
20+
},
21+
Handler: func(ctx context.Context, toolCall tools.ToolCall) (*tools.ToolCallResult, error) {
22+
return &tools.ToolCallResult{Output: "hello world"}, nil
23+
},
24+
}
25+
26+
callFunc := callTool(context.Background(), tool)
27+
result, err := callFunc(map[string]any{})
28+
29+
require.NoError(t, err)
30+
assert.Equal(t, "hello world", result)
31+
assert.IsType(t, "", result) // Should be string type
32+
}
33+
34+
func TestCallTool_ArrayOutput(t *testing.T) {
35+
tool := tools.Tool{
36+
Function: &tools.FunctionDefinition{
37+
Name: "array_tool",
38+
OutputSchema: tools.ToOutputSchemaSchemaMust(reflect.TypeFor[[]string]()),
39+
},
40+
Handler: func(ctx context.Context, toolCall tools.ToolCall) (*tools.ToolCallResult, error) {
41+
data := []string{"file1.txt", "file2.txt"}
42+
jsonData, _ := json.Marshal(data)
43+
return &tools.ToolCallResult{Output: string(jsonData)}, nil
44+
},
45+
}
46+
47+
callFunc := callTool(context.Background(), tool)
48+
result, err := callFunc(map[string]any{})
49+
50+
require.NoError(t, err)
51+
assert.IsType(t, []any{}, result) // Should be parsed as array
52+
53+
// Convert result to []any and check contents
54+
resultArray := result.([]any)
55+
assert.Len(t, resultArray, 2)
56+
assert.Equal(t, "file1.txt", resultArray[0])
57+
assert.Equal(t, "file2.txt", resultArray[1])
58+
}
59+
60+
func TestCallTool_ObjectOutput(t *testing.T) {
61+
type FileInfo struct {
62+
Name string `json:"name"`
63+
Size int64 `json:"size"`
64+
}
65+
66+
tool := tools.Tool{
67+
Function: &tools.FunctionDefinition{
68+
Name: "object_tool",
69+
OutputSchema: tools.ToOutputSchemaSchemaMust(reflect.TypeFor[FileInfo]()),
70+
},
71+
Handler: func(ctx context.Context, toolCall tools.ToolCall) (*tools.ToolCallResult, error) {
72+
data := FileInfo{Name: "test.txt", Size: 1024}
73+
jsonData, _ := json.Marshal(data)
74+
return &tools.ToolCallResult{Output: string(jsonData)}, nil
75+
},
76+
}
77+
78+
callFunc := callTool(context.Background(), tool)
79+
result, err := callFunc(map[string]any{})
80+
81+
require.NoError(t, err)
82+
assert.IsType(t, map[string]any{}, result) // Should be parsed as object
83+
84+
// Convert result to map and check contents
85+
resultMap := result.(map[string]any)
86+
assert.Equal(t, "test.txt", resultMap["name"])
87+
assert.Equal(t, float64(1024), resultMap["size"]) // JSON numbers become float64
88+
}
89+
90+
func TestCallTool_NoSchema(t *testing.T) {
91+
tool := tools.Tool{
92+
Function: &tools.FunctionDefinition{
93+
Name: "no_schema_tool",
94+
OutputSchema: nil,
95+
},
96+
Handler: func(ctx context.Context, toolCall tools.ToolCall) (*tools.ToolCallResult, error) {
97+
return &tools.ToolCallResult{Output: "raw output"}, nil
98+
},
99+
}
100+
101+
callFunc := callTool(context.Background(), tool)
102+
result, err := callFunc(map[string]any{})
103+
104+
require.NoError(t, err)
105+
assert.Equal(t, "raw output", result)
106+
assert.IsType(t, "", result) // Should remain as string
107+
}
108+
109+
func TestCallTool_InvalidJSON(t *testing.T) {
110+
tool := tools.Tool{
111+
Function: &tools.FunctionDefinition{
112+
Name: "invalid_json_tool",
113+
OutputSchema: tools.ToOutputSchemaSchemaMust(reflect.TypeFor[[]string]()),
114+
},
115+
Handler: func(ctx context.Context, toolCall tools.ToolCall) (*tools.ToolCallResult, error) {
116+
return &tools.ToolCallResult{Output: "invalid json {"}, nil
117+
},
118+
}
119+
120+
callFunc := callTool(context.Background(), tool)
121+
result, err := callFunc(map[string]any{})
122+
123+
require.NoError(t, err)
124+
// Should fallback to string when JSON parsing fails
125+
assert.Equal(t, "invalid json {", result)
126+
assert.IsType(t, "", result)
127+
}

pkg/codemode/functions.go

Lines changed: 83 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,19 +5,100 @@ import (
55
"slices"
66
"strings"
77

8+
"github.com/google/jsonschema-go/jsonschema"
9+
810
"github.com/docker/cagent/pkg/tools"
911
)
1012

13+
// schemaToJSType converts a JSON schema to a JavaScript/TypeScript type string
14+
func schemaToJSType(schema any) string {
15+
if schema == nil {
16+
return "any"
17+
}
18+
19+
// Handle jsonschema.Schema type from the Google jsonschema library
20+
if s, ok := schema.(*jsonschema.Schema); ok {
21+
return schemaToJSTypeFromStruct(s)
22+
}
23+
24+
// Handle boolean schema (any type)
25+
if boolSchema, ok := schema.(bool); ok {
26+
if boolSchema {
27+
return "any"
28+
}
29+
return "never"
30+
}
31+
32+
// Handle object schema (map)
33+
schemaMap, ok := schema.(map[string]any)
34+
if !ok {
35+
return "any"
36+
}
37+
38+
schemaType, hasType := schemaMap["type"].(string)
39+
if !hasType {
40+
return "any"
41+
}
42+
43+
switch schemaType {
44+
case "string":
45+
return "string"
46+
case "number", "integer":
47+
return "number"
48+
case "boolean":
49+
return "boolean"
50+
case "array":
51+
if items, hasItems := schemaMap["items"]; hasItems {
52+
itemType := schemaToJSType(items)
53+
return itemType + "[]"
54+
}
55+
return "any[]"
56+
case "object":
57+
// For complex objects, return 'object' for simplicity in JS context
58+
return "object"
59+
default:
60+
return "any"
61+
}
62+
}
63+
64+
func schemaToJSTypeFromStruct(s *jsonschema.Schema) string {
65+
switch s.Type {
66+
case "string":
67+
return "string"
68+
case "number", "integer":
69+
return "number"
70+
case "boolean":
71+
return "boolean"
72+
case "array":
73+
if s.Items != nil {
74+
itemType := schemaToJSType(s.Items)
75+
return itemType + "[]"
76+
}
77+
return "any[]"
78+
case "object":
79+
// For complex objects, return 'object' for simplicity in JS context
80+
return "object"
81+
default:
82+
return "any"
83+
}
84+
}
85+
1186
func toolToJsDoc(tool tools.Tool) string {
1287
var doc strings.Builder
1388

89+
// Determine return type from output schema
90+
returnType := "any" // default fallback when no schema is available
91+
if tool.Function.OutputSchema != nil {
92+
returnType = schemaToJSType(tool.Function.OutputSchema)
93+
}
94+
1495
doc.WriteString("===== " + tool.Function.Name + " =====\n\n")
1596
doc.WriteString(strings.TrimSpace(tool.Function.Description))
1697
doc.WriteString("\n\n")
1798
if len(tool.Function.Parameters.Properties) == 0 {
18-
doc.WriteString(fmt.Sprintf("%s(): string\n", tool.Function.Name))
99+
doc.WriteString(fmt.Sprintf("%s(): %s\n", tool.Function.Name, returnType))
19100
} else {
20-
doc.WriteString(fmt.Sprintf("%s(args: ArgsObject): string\n", tool.Function.Name))
101+
doc.WriteString(fmt.Sprintf("%s(args: ArgsObject): %s\n", tool.Function.Name, returnType))
21102
doc.WriteString("\nwhere type ArgsObject = {\n")
22103
for paramName, param := range tool.Function.Parameters.Properties {
23104
pType := "Object"

0 commit comments

Comments
 (0)