Skip to content

Commit d9c04f0

Browse files
authored
Merge pull request #563 from dgageot/filter-tools
Continue simplifying tools
2 parents c4c8555 + 81a16bd commit d9c04f0

10 files changed

Lines changed: 241 additions & 154 deletions

File tree

pkg/teamloader/filter.go

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
package teamloader
2+
3+
import (
4+
"context"
5+
"log/slog"
6+
"slices"
7+
8+
"github.com/docker/cagent/pkg/tools"
9+
)
10+
11+
func WithToolsFilter(inner tools.ToolSet, toolNames ...string) tools.ToolSet {
12+
if len(toolNames) == 0 {
13+
return inner
14+
}
15+
16+
return &filterTools{
17+
ToolSet: inner,
18+
toolNames: toolNames,
19+
}
20+
}
21+
22+
type filterTools struct {
23+
tools.ToolSet
24+
toolNames []string
25+
}
26+
27+
func (f *filterTools) Tools(ctx context.Context) ([]tools.Tool, error) {
28+
allTools, err := f.ToolSet.Tools(ctx)
29+
if err != nil {
30+
return nil, err
31+
}
32+
33+
var filtered []tools.Tool
34+
for _, tool := range allTools {
35+
if !slices.Contains(f.toolNames, tool.Name) {
36+
slog.Debug("Filtering out tool", "tool", tool.Name)
37+
continue
38+
}
39+
40+
filtered = append(filtered, tool)
41+
}
42+
43+
return filtered, nil
44+
}

pkg/teamloader/filter_test.go

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
package teamloader
2+
3+
import (
4+
"context"
5+
"errors"
6+
"testing"
7+
8+
"github.com/stretchr/testify/assert"
9+
"github.com/stretchr/testify/require"
10+
11+
"github.com/docker/cagent/pkg/tools"
12+
)
13+
14+
type mockToolSet struct {
15+
tools.ToolSet
16+
toolsFunc func(ctx context.Context) ([]tools.Tool, error)
17+
}
18+
19+
func (m *mockToolSet) Tools(ctx context.Context) ([]tools.Tool, error) {
20+
if m.toolsFunc != nil {
21+
return m.toolsFunc(ctx)
22+
}
23+
return nil, nil
24+
}
25+
26+
func TestWithToolsFilter_NilToolNames(t *testing.T) {
27+
inner := &mockToolSet{}
28+
29+
wrapped := WithToolsFilter(inner)
30+
31+
assert.Same(t, inner, wrapped)
32+
}
33+
34+
func TestWithToolsFilter_EmptyNames(t *testing.T) {
35+
inner := &mockToolSet{}
36+
37+
wrapped := WithToolsFilter(inner, []string{}...)
38+
39+
assert.Same(t, inner, wrapped)
40+
}
41+
42+
func TestWithToolsFilter_PickOne(t *testing.T) {
43+
inner := &mockToolSet{
44+
toolsFunc: func(context.Context) ([]tools.Tool, error) {
45+
return []tools.Tool{{Name: "tool1"}, {Name: "tool2"}, {Name: "tool3"}}, nil
46+
},
47+
}
48+
49+
wrapped := WithToolsFilter(inner, "tool2")
50+
51+
result, err := wrapped.Tools(t.Context())
52+
require.NoError(t, err)
53+
require.Len(t, result, 1)
54+
assert.Equal(t, "tool2", result[0].Name)
55+
}
56+
57+
func TestWithToolsFilter_PickAll(t *testing.T) {
58+
inner := &mockToolSet{
59+
toolsFunc: func(context.Context) ([]tools.Tool, error) {
60+
return []tools.Tool{{Name: "tool1"}, {Name: "tool2"}, {Name: "tool3"}}, nil
61+
},
62+
}
63+
64+
wrapped := WithToolsFilter(inner, "tool1", "tool2", "tool3")
65+
66+
result, err := wrapped.Tools(t.Context())
67+
require.NoError(t, err)
68+
69+
require.Len(t, result, 3)
70+
assert.Equal(t, "tool1", result[0].Name)
71+
assert.Equal(t, "tool2", result[1].Name)
72+
assert.Equal(t, "tool3", result[2].Name)
73+
}
74+
75+
func TestWithToolsFilter_NoMatch(t *testing.T) {
76+
inner := &mockToolSet{
77+
toolsFunc: func(context.Context) ([]tools.Tool, error) {
78+
return []tools.Tool{{Name: "tool1"}, {Name: "tool2"}}, nil
79+
},
80+
}
81+
82+
wrapped := WithToolsFilter(inner, "tool3", "tool4")
83+
84+
result, err := wrapped.Tools(t.Context())
85+
require.NoError(t, err)
86+
assert.Empty(t, result)
87+
}
88+
89+
func TestWithToolsFilter_ErrorFromInner(t *testing.T) {
90+
expectedErr := errors.New("mock error")
91+
inner := &mockToolSet{
92+
toolsFunc: func(context.Context) ([]tools.Tool, error) {
93+
return nil, expectedErr
94+
},
95+
}
96+
97+
wrapped := WithToolsFilter(inner, "tool1")
98+
99+
result, err := wrapped.Tools(t.Context())
100+
assert.Nil(t, result)
101+
assert.ErrorIs(t, err, expectedErr)
102+
}
103+
104+
func TestWithToolsFilter_CaseSensitive(t *testing.T) {
105+
inner := &mockToolSet{
106+
toolsFunc: func(ctx context.Context) ([]tools.Tool, error) {
107+
return []tools.Tool{
108+
{Name: "Tool1"},
109+
{Name: "tool1"},
110+
{Name: "TOOL1"},
111+
}, nil
112+
},
113+
}
114+
115+
wrapped := WithToolsFilter(inner, "tool1")
116+
117+
result, err := wrapped.Tools(t.Context())
118+
require.NoError(t, err)
119+
require.Len(t, result, 1)
120+
assert.Equal(t, "tool1", result[0].Name)
121+
}

pkg/teamloader/teamloader.go

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,10 @@ func getToolsForAgent(ctx context.Context, a *latest.AgentConfig, parentDir stri
235235
return nil, err
236236
}
237237

238-
t = append(t, WithInstructions(tool, a.Instruction))
238+
wrapped := WithToolsFilter(tool, toolset.Tools...)
239+
wrapped = WithInstructions(wrapped, a.Instruction)
240+
241+
t = append(t, wrapped)
239242
}
240243

241244
if !a.CodeModeTools && !runtimeConfig.GlobalCodeMode {
@@ -312,7 +315,7 @@ func createTool(ctx context.Context, toolset latest.Toolset, parentDir string, e
312315
}
313316
}
314317

315-
opts := []builtin.FileSystemOpt{builtin.WithAllowedTools(toolset.Tools)}
318+
var opts []builtin.FileSystemOpt
316319
if len(toolset.PostEdit) > 0 {
317320
postEditConfigs := make([]builtin.PostEditConfig, len(toolset.PostEdit))
318321
for i, pe := range toolset.PostEdit {
@@ -343,13 +346,13 @@ func createTool(ctx context.Context, toolset latest.Toolset, parentDir string, e
343346

344347
// TODO(dga): until the MCP Gateway supports oauth with cagent, we fetch the remote url and directly connect to it.
345348
if serverSpec.Type == "remote" {
346-
return mcp.NewRemoteToolset(serverSpec.Remote.URL, serverSpec.Remote.TransportType, nil, toolset.Tools, runtimeConfig.RedirectURI)
349+
return mcp.NewRemoteToolset(serverSpec.Remote.URL, serverSpec.Remote.TransportType, nil, runtimeConfig.RedirectURI), nil
347350
}
348351

349-
return mcp.NewGatewayToolset(mcpServerName, toolset.Config, toolset.Tools, envProvider), nil
352+
return mcp.NewGatewayToolset(ctx, mcpServerName, toolset.Config, envProvider)
350353

351354
case toolset.Type == "mcp" && toolset.Command != "":
352-
return mcp.NewToolsetCommand(toolset.Command, toolset.Args, env, toolset.Tools), nil
355+
return mcp.NewToolsetCommand(toolset.Command, toolset.Args, env), nil
353356

354357
case toolset.Type == "mcp" && toolset.Remote.URL != "":
355358
headers := map[string]string{}
@@ -362,7 +365,7 @@ func createTool(ctx context.Context, toolset latest.Toolset, parentDir string, e
362365
headers[k] = expanded
363366
}
364367

365-
return mcp.NewRemoteToolset(toolset.Remote.URL, toolset.Remote.TransportType, headers, toolset.Tools, runtimeConfig.RedirectURI)
368+
return mcp.NewRemoteToolset(toolset.Remote.URL, toolset.Remote.TransportType, headers, runtimeConfig.RedirectURI), nil
366369

367370
default:
368371
return nil, fmt.Errorf("unknown toolset type: %s", toolset.Type)

pkg/tools/builtin/filesystem.go

Lines changed: 2 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@ import (
1010
"os/exec"
1111
"path/filepath"
1212
"regexp"
13-
"slices"
1413
"strings"
1514
"time"
1615

@@ -27,20 +26,13 @@ type FilesystemTool struct {
2726
tools.ElicitationTool
2827

2928
allowedDirectories []string
30-
allowedTools []string
3129
postEditCommands []PostEditConfig
3230
}
3331

3432
var _ tools.ToolSet = (*FilesystemTool)(nil)
3533

3634
type FileSystemOpt func(*FilesystemTool)
3735

38-
func WithAllowedTools(allowedTools []string) FileSystemOpt {
39-
return func(t *FilesystemTool) {
40-
t.allowedTools = allowedTools
41-
}
42-
}
43-
4436
func WithPostEditCommands(postEditCommands []PostEditConfig) FileSystemOpt {
4537
return func(t *FilesystemTool) {
4638
t.postEditCommands = postEditCommands
@@ -150,7 +142,7 @@ type EditFileArgs struct {
150142
}
151143

152144
func (t *FilesystemTool) Tools(context.Context) ([]tools.Tool, error) {
153-
tls := []tools.Tool{
145+
return []tools.Tool{
154146
{
155147
Name: "create_directory",
156148
Category: "filesystem",
@@ -337,20 +329,7 @@ func (t *FilesystemTool) Tools(context.Context) ([]tools.Tool, error) {
337329
Title: "Write File",
338330
},
339331
},
340-
}
341-
342-
if len(t.allowedTools) == 0 {
343-
return tls, nil
344-
}
345-
346-
var allowedTools []tools.Tool
347-
for _, tool := range tls {
348-
if slices.Contains(t.allowedTools, tool.Name) {
349-
allowedTools = append(allowedTools, tool)
350-
}
351-
}
352-
353-
return allowedTools, nil
332+
}, nil
354333
}
355334

356335
// executePostEditCommands executes any matching post-edit commands for the given file path

pkg/tools/builtin/filesystem_test.go

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1137,17 +1137,6 @@ func TestFilesystemTool_AddAllowedDirectory(t *testing.T) {
11371137
})
11381138
}
11391139

1140-
func TestFilesystemTool_FilterTools(t *testing.T) {
1141-
allowedDirs := []string{"/tmp"}
1142-
tool := NewFilesystemTool(allowedDirs, WithAllowedTools([]string{"list_allowed_directories"}))
1143-
1144-
allTools, err := tool.Tools(t.Context())
1145-
require.NoError(t, err)
1146-
require.Len(t, allTools, 1)
1147-
require.Equal(t, "list_allowed_directories", allTools[0].Name)
1148-
require.NotNil(t, allTools[0].Handler)
1149-
}
1150-
11511140
func TestMatchExcludePattern(t *testing.T) {
11521141
tests := []struct {
11531142
name string

0 commit comments

Comments
 (0)