Skip to content

Commit e89fada

Browse files
committed
Extract registry code
Signed-off-by: David Gageot <david.gageot@docker.com>
1 parent c4de15f commit e89fada

2 files changed

Lines changed: 225 additions & 210 deletions

File tree

pkg/teamloader/registry.go

Lines changed: 225 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,225 @@
1+
package teamloader
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"os"
7+
"path/filepath"
8+
"time"
9+
10+
"github.com/docker/cagent/pkg/config"
11+
latest "github.com/docker/cagent/pkg/config/v2"
12+
"github.com/docker/cagent/pkg/environment"
13+
"github.com/docker/cagent/pkg/gateway"
14+
"github.com/docker/cagent/pkg/js"
15+
"github.com/docker/cagent/pkg/memory/database/sqlite"
16+
"github.com/docker/cagent/pkg/path"
17+
"github.com/docker/cagent/pkg/tools"
18+
"github.com/docker/cagent/pkg/tools/builtin"
19+
"github.com/docker/cagent/pkg/tools/mcp"
20+
)
21+
22+
// ToolsetCreator is a function that creates a toolset based on the provided configuration
23+
type ToolsetCreator func(ctx context.Context, toolset latest.Toolset, parentDir string, envProvider environment.Provider, runtimeConfig config.RuntimeConfig) (tools.ToolSet, error)
24+
25+
// ToolsetRegistry manages the registration of toolset creators by type
26+
type ToolsetRegistry struct {
27+
creators map[string]ToolsetCreator
28+
}
29+
30+
// NewToolsetRegistry creates a new empty toolset registry
31+
func NewToolsetRegistry() *ToolsetRegistry {
32+
return &ToolsetRegistry{
33+
creators: make(map[string]ToolsetCreator),
34+
}
35+
}
36+
37+
// Register adds a new toolset creator for the given type
38+
func (r *ToolsetRegistry) Register(toolsetType string, creator ToolsetCreator) {
39+
r.creators[toolsetType] = creator
40+
}
41+
42+
// Get retrieves a toolset creator for the given type
43+
func (r *ToolsetRegistry) Get(toolsetType string) (ToolsetCreator, bool) {
44+
creator, ok := r.creators[toolsetType]
45+
return creator, ok
46+
}
47+
48+
// CreateTool creates a toolset using the registered creator for the given type
49+
func (r *ToolsetRegistry) CreateTool(ctx context.Context, toolset latest.Toolset, parentDir string, envProvider environment.Provider, runtimeConfig config.RuntimeConfig) (tools.ToolSet, error) {
50+
creator, ok := r.Get(toolset.Type)
51+
if !ok {
52+
return nil, fmt.Errorf("unknown toolset type: %s", toolset.Type)
53+
}
54+
return creator(ctx, toolset, parentDir, envProvider, runtimeConfig)
55+
}
56+
57+
func NewDefaultToolsetRegistry() *ToolsetRegistry {
58+
r := NewToolsetRegistry()
59+
// Register all built-in toolset creators
60+
r.Register("todo", createTodoTool)
61+
r.Register("memory", createMemoryTool)
62+
r.Register("think", createThinkTool)
63+
r.Register("shell", createShellTool)
64+
r.Register("script", createScriptTool)
65+
r.Register("filesystem", createFilesystemTool)
66+
r.Register("fetch", createFetchTool)
67+
r.Register("mcp", createMCPTool)
68+
r.Register("api", createAPITool)
69+
return r
70+
}
71+
72+
func createTodoTool(ctx context.Context, toolset latest.Toolset, parentDir string, envProvider environment.Provider, runtimeConfig config.RuntimeConfig) (tools.ToolSet, error) {
73+
if toolset.Shared {
74+
return builtin.NewSharedTodoTool(), nil
75+
}
76+
return builtin.NewTodoTool(), nil
77+
}
78+
79+
func createMemoryTool(ctx context.Context, toolset latest.Toolset, parentDir string, envProvider environment.Provider, runtimeConfig config.RuntimeConfig) (tools.ToolSet, error) {
80+
var memoryPath string
81+
if filepath.IsAbs(toolset.Path) {
82+
memoryPath = ""
83+
} else if wd, err := os.Getwd(); err == nil {
84+
memoryPath = wd
85+
} else {
86+
memoryPath = parentDir
87+
}
88+
89+
validatedMemoryPath, err := path.ValidatePathInDirectory(toolset.Path, memoryPath)
90+
if err != nil {
91+
return nil, fmt.Errorf("invalid memory database path: %w", err)
92+
}
93+
if err := os.MkdirAll(filepath.Dir(validatedMemoryPath), 0o700); err != nil {
94+
return nil, fmt.Errorf("failed to create memory database directory: %w", err)
95+
}
96+
97+
db, err := sqlite.NewMemoryDatabase(validatedMemoryPath)
98+
if err != nil {
99+
return nil, fmt.Errorf("failed to create memory database: %w", err)
100+
}
101+
102+
return builtin.NewMemoryTool(db), nil
103+
}
104+
105+
func createThinkTool(ctx context.Context, toolset latest.Toolset, parentDir string, envProvider environment.Provider, runtimeConfig config.RuntimeConfig) (tools.ToolSet, error) {
106+
return builtin.NewThinkTool(), nil
107+
}
108+
109+
func createShellTool(ctx context.Context, toolset latest.Toolset, parentDir string, envProvider environment.Provider, runtimeConfig config.RuntimeConfig) (tools.ToolSet, error) {
110+
env, err := environment.ExpandAll(ctx, environment.ToValues(toolset.Env), envProvider)
111+
if err != nil {
112+
return nil, fmt.Errorf("failed to expand the tool's environment variables: %w", err)
113+
}
114+
env = append(env, os.Environ()...)
115+
return builtin.NewShellTool(env), nil
116+
}
117+
118+
func createScriptTool(ctx context.Context, toolset latest.Toolset, parentDir string, envProvider environment.Provider, runtimeConfig config.RuntimeConfig) (tools.ToolSet, error) {
119+
if len(toolset.Shell) == 0 {
120+
return nil, fmt.Errorf("shell is required for script toolset")
121+
}
122+
123+
env, err := environment.ExpandAll(ctx, environment.ToValues(toolset.Env), envProvider)
124+
if err != nil {
125+
return nil, fmt.Errorf("failed to expand the tool's environment variables: %w", err)
126+
}
127+
env = append(env, os.Environ()...)
128+
return builtin.NewScriptShellTool(toolset.Shell, env), nil
129+
}
130+
131+
func createFilesystemTool(ctx context.Context, toolset latest.Toolset, parentDir string, envProvider environment.Provider, runtimeConfig config.RuntimeConfig) (tools.ToolSet, error) {
132+
wd := runtimeConfig.WorkingDir
133+
if wd == "" {
134+
var err error
135+
wd, err = os.Getwd()
136+
if err != nil {
137+
return nil, fmt.Errorf("failed to get working directory: %w", err)
138+
}
139+
}
140+
141+
var opts []builtin.FileSystemOpt
142+
143+
// Handle ignore_vcs configuration (default to true)
144+
ignoreVCS := true
145+
if toolset.IgnoreVCS != nil {
146+
ignoreVCS = *toolset.IgnoreVCS
147+
}
148+
opts = append(opts, builtin.WithIgnoreVCS(ignoreVCS))
149+
150+
// Handle post-edit commands
151+
if len(toolset.PostEdit) > 0 {
152+
postEditConfigs := make([]builtin.PostEditConfig, len(toolset.PostEdit))
153+
for i, pe := range toolset.PostEdit {
154+
postEditConfigs[i] = builtin.PostEditConfig{
155+
Path: pe.Path,
156+
Cmd: pe.Cmd,
157+
}
158+
}
159+
opts = append(opts, builtin.WithPostEditCommands(postEditConfigs))
160+
}
161+
162+
return builtin.NewFilesystemTool([]string{wd}, opts...), nil
163+
}
164+
165+
func createAPITool(ctx context.Context, toolset latest.Toolset, parentDir string, envProvider environment.Provider, runtimeConfig config.RuntimeConfig) (tools.ToolSet, error) {
166+
if toolset.APIConfig.Endpoint == "" {
167+
return nil, fmt.Errorf("api tool requires an endpoint in api_config")
168+
}
169+
170+
toolset.APIConfig.Headers = js.Expand(ctx, toolset.APIConfig.Headers, envProvider)
171+
172+
return builtin.NewAPITool(toolset.APIConfig), nil
173+
}
174+
175+
func createFetchTool(ctx context.Context, toolset latest.Toolset, parentDir string, envProvider environment.Provider, runtimeConfig config.RuntimeConfig) (tools.ToolSet, error) {
176+
var opts []builtin.FetchToolOption
177+
if toolset.Timeout > 0 {
178+
timeout := time.Duration(toolset.Timeout) * time.Second
179+
opts = append(opts, builtin.WithTimeout(timeout))
180+
}
181+
return builtin.NewFetchTool(opts...), nil
182+
}
183+
184+
func createMCPTool(ctx context.Context, toolset latest.Toolset, parentDir string, envProvider environment.Provider, runtimeConfig config.RuntimeConfig) (tools.ToolSet, error) {
185+
// MCP tool has three different modes: ref, command, and remote
186+
if toolset.Ref != "" {
187+
mcpServerName := gateway.ParseServerRef(toolset.Ref)
188+
serverSpec, err := gateway.ServerSpec(ctx, mcpServerName)
189+
if err != nil {
190+
return nil, fmt.Errorf("fetching MCP server spec for %q: %w", mcpServerName, err)
191+
}
192+
193+
// TODO(dga): until the MCP Gateway supports oauth with cagent, we fetch the remote url and directly connect to it.
194+
if serverSpec.Type == "remote" {
195+
return mcp.NewRemoteToolset(serverSpec.Remote.URL, serverSpec.Remote.TransportType, nil), nil
196+
}
197+
198+
return mcp.NewGatewayToolset(ctx, mcpServerName, toolset.Config, envProvider)
199+
}
200+
201+
if toolset.Command != "" {
202+
env, err := environment.ExpandAll(ctx, environment.ToValues(toolset.Env), envProvider)
203+
if err != nil {
204+
return nil, fmt.Errorf("failed to expand the tool's environment variables: %w", err)
205+
}
206+
env = append(env, os.Environ()...)
207+
return mcp.NewToolsetCommand(toolset.Command, toolset.Args, env), nil
208+
}
209+
210+
if toolset.Remote.URL != "" {
211+
headers := map[string]string{}
212+
for k, v := range toolset.Remote.Headers {
213+
expanded, err := environment.Expand(ctx, v, envProvider)
214+
if err != nil {
215+
return nil, fmt.Errorf("failed to expand header '%s': %w", k, err)
216+
}
217+
218+
headers[k] = expanded
219+
}
220+
221+
return mcp.NewRemoteToolset(toolset.Remote.URL, toolset.Remote.TransportType, headers), nil
222+
}
223+
224+
return nil, fmt.Errorf("mcp toolset requires either ref, command, or remote configuration")
225+
}

0 commit comments

Comments
 (0)