Skip to content

Commit 1499d3f

Browse files
authored
Merge pull request #776 from dgageot/prepare-default-agent
Prepare for default agent
2 parents c4de15f + 089b601 commit 1499d3f

5 files changed

Lines changed: 330 additions & 224 deletions

File tree

cmd/root/run.go

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -86,18 +86,21 @@ func (f *runExecFlags) runOrExec(ctx context.Context, out *cli.Printer, args []s
8686
var rt runtime.Runtime
8787
var sess *session.Session
8888
var err error
89-
if f.remoteAddress != "" {
90-
rt, sess, err = f.createRemoteRuntimeAndSession(ctx, args[0])
89+
switch {
90+
case f.remoteAddress != "":
91+
agentFileName = args[0]
92+
rt, sess, err = f.createRemoteRuntimeAndSession(ctx, agentFileName)
9193
if err != nil {
9294
return err
9395
}
94-
} else {
96+
97+
default:
9598
agentFileName, err = f.resolveAgentFile(ctx, out, args[0])
9699
if err != nil {
97100
return err
98101
}
99102

100-
t, err := f.loadAgents(ctx, agentFileName)
103+
t, err := f.loadAgentFrom(ctx, teamloader.NewFileSource(agentFileName))
101104
if err != nil {
102105
return err
103106
}
@@ -133,8 +136,8 @@ func (f *runExecFlags) resolveAgentFile(ctx context.Context, out *cli.Printer, a
133136
return agentfile.Resolve(ctx, out, agentFilename)
134137
}
135138

136-
func (f *runExecFlags) loadAgents(ctx context.Context, agentFilename string) (*team.Team, error) {
137-
t, err := teamloader.Load(ctx, agentFilename, f.runConfig, teamloader.WithModelOverrides(f.modelOverrides))
139+
func (f *runExecFlags) loadAgentFrom(ctx context.Context, source teamloader.AgentSource) (*team.Team, error) {
140+
t, err := teamloader.LoadFrom(ctx, source, f.runConfig, teamloader.WithModelOverrides(f.modelOverrides))
138141
if err != nil {
139142
return nil, err
140143
}

pkg/config/config.go

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,22 +19,26 @@ import (
1919
func LoadConfig(path string, fs filesystem.FS) (*latest.Config, error) {
2020
data, err := fs.ReadFile(path)
2121
if err != nil {
22-
return nil, fmt.Errorf("reading config file: %w", err)
22+
return nil, fmt.Errorf("reading config file %s: %w", path, err)
2323
}
2424

25+
return LoadConfigBytes(data)
26+
}
27+
28+
func LoadConfigBytes(data []byte) (*latest.Config, error) {
2529
var raw struct {
2630
Version string `yaml:"version,omitempty"`
2731
}
2832
if err := yaml.UnmarshalWithOptions(data, &raw); err != nil {
29-
return nil, fmt.Errorf("looking for version in config file %s\n%s", path, yaml.FormatError(err, true, true))
33+
return nil, fmt.Errorf("looking for version in config file\n%s", yaml.FormatError(err, true, true))
3034
}
3135
if raw.Version == "" {
3236
raw.Version = latest.Version
3337
}
3438

3539
oldConfig, err := parseCurrentVersion(data, raw.Version)
3640
if err != nil {
37-
return nil, fmt.Errorf("parsing config file %s\n%s", path, yaml.FormatError(err, true, true))
41+
return nil, fmt.Errorf("parsing config file\n%s", yaml.FormatError(err, true, true))
3842
}
3943

4044
config, err := migrateToLatestConfig(oldConfig)

pkg/teamloader/registry.go

Lines changed: 225 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,225 @@
1+
package teamloader
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"os"
7+
"path/filepath"
8+
"time"
9+
10+
"github.com/docker/cagent/pkg/config"
11+
latest "github.com/docker/cagent/pkg/config/v2"
12+
"github.com/docker/cagent/pkg/environment"
13+
"github.com/docker/cagent/pkg/gateway"
14+
"github.com/docker/cagent/pkg/js"
15+
"github.com/docker/cagent/pkg/memory/database/sqlite"
16+
"github.com/docker/cagent/pkg/path"
17+
"github.com/docker/cagent/pkg/tools"
18+
"github.com/docker/cagent/pkg/tools/builtin"
19+
"github.com/docker/cagent/pkg/tools/mcp"
20+
)
21+
22+
// ToolsetCreator is a function that creates a toolset based on the provided configuration
23+
type ToolsetCreator func(ctx context.Context, toolset latest.Toolset, parentDir string, envProvider environment.Provider, runtimeConfig config.RuntimeConfig) (tools.ToolSet, error)
24+
25+
// ToolsetRegistry manages the registration of toolset creators by type
26+
type ToolsetRegistry struct {
27+
creators map[string]ToolsetCreator
28+
}
29+
30+
// NewToolsetRegistry creates a new empty toolset registry
31+
func NewToolsetRegistry() *ToolsetRegistry {
32+
return &ToolsetRegistry{
33+
creators: make(map[string]ToolsetCreator),
34+
}
35+
}
36+
37+
// Register adds a new toolset creator for the given type
38+
func (r *ToolsetRegistry) Register(toolsetType string, creator ToolsetCreator) {
39+
r.creators[toolsetType] = creator
40+
}
41+
42+
// Get retrieves a toolset creator for the given type
43+
func (r *ToolsetRegistry) Get(toolsetType string) (ToolsetCreator, bool) {
44+
creator, ok := r.creators[toolsetType]
45+
return creator, ok
46+
}
47+
48+
// CreateTool creates a toolset using the registered creator for the given type
49+
func (r *ToolsetRegistry) CreateTool(ctx context.Context, toolset latest.Toolset, parentDir string, envProvider environment.Provider, runtimeConfig config.RuntimeConfig) (tools.ToolSet, error) {
50+
creator, ok := r.Get(toolset.Type)
51+
if !ok {
52+
return nil, fmt.Errorf("unknown toolset type: %s", toolset.Type)
53+
}
54+
return creator(ctx, toolset, parentDir, envProvider, runtimeConfig)
55+
}
56+
57+
func NewDefaultToolsetRegistry() *ToolsetRegistry {
58+
r := NewToolsetRegistry()
59+
// Register all built-in toolset creators
60+
r.Register("todo", createTodoTool)
61+
r.Register("memory", createMemoryTool)
62+
r.Register("think", createThinkTool)
63+
r.Register("shell", createShellTool)
64+
r.Register("script", createScriptTool)
65+
r.Register("filesystem", createFilesystemTool)
66+
r.Register("fetch", createFetchTool)
67+
r.Register("mcp", createMCPTool)
68+
r.Register("api", createAPITool)
69+
return r
70+
}
71+
72+
func createTodoTool(ctx context.Context, toolset latest.Toolset, parentDir string, envProvider environment.Provider, runtimeConfig config.RuntimeConfig) (tools.ToolSet, error) {
73+
if toolset.Shared {
74+
return builtin.NewSharedTodoTool(), nil
75+
}
76+
return builtin.NewTodoTool(), nil
77+
}
78+
79+
func createMemoryTool(ctx context.Context, toolset latest.Toolset, parentDir string, envProvider environment.Provider, runtimeConfig config.RuntimeConfig) (tools.ToolSet, error) {
80+
var memoryPath string
81+
if filepath.IsAbs(toolset.Path) {
82+
memoryPath = ""
83+
} else if wd, err := os.Getwd(); err == nil {
84+
memoryPath = wd
85+
} else {
86+
memoryPath = parentDir
87+
}
88+
89+
validatedMemoryPath, err := path.ValidatePathInDirectory(toolset.Path, memoryPath)
90+
if err != nil {
91+
return nil, fmt.Errorf("invalid memory database path: %w", err)
92+
}
93+
if err := os.MkdirAll(filepath.Dir(validatedMemoryPath), 0o700); err != nil {
94+
return nil, fmt.Errorf("failed to create memory database directory: %w", err)
95+
}
96+
97+
db, err := sqlite.NewMemoryDatabase(validatedMemoryPath)
98+
if err != nil {
99+
return nil, fmt.Errorf("failed to create memory database: %w", err)
100+
}
101+
102+
return builtin.NewMemoryTool(db), nil
103+
}
104+
105+
func createThinkTool(ctx context.Context, toolset latest.Toolset, parentDir string, envProvider environment.Provider, runtimeConfig config.RuntimeConfig) (tools.ToolSet, error) {
106+
return builtin.NewThinkTool(), nil
107+
}
108+
109+
func createShellTool(ctx context.Context, toolset latest.Toolset, parentDir string, envProvider environment.Provider, runtimeConfig config.RuntimeConfig) (tools.ToolSet, error) {
110+
env, err := environment.ExpandAll(ctx, environment.ToValues(toolset.Env), envProvider)
111+
if err != nil {
112+
return nil, fmt.Errorf("failed to expand the tool's environment variables: %w", err)
113+
}
114+
env = append(env, os.Environ()...)
115+
return builtin.NewShellTool(env), nil
116+
}
117+
118+
func createScriptTool(ctx context.Context, toolset latest.Toolset, parentDir string, envProvider environment.Provider, runtimeConfig config.RuntimeConfig) (tools.ToolSet, error) {
119+
if len(toolset.Shell) == 0 {
120+
return nil, fmt.Errorf("shell is required for script toolset")
121+
}
122+
123+
env, err := environment.ExpandAll(ctx, environment.ToValues(toolset.Env), envProvider)
124+
if err != nil {
125+
return nil, fmt.Errorf("failed to expand the tool's environment variables: %w", err)
126+
}
127+
env = append(env, os.Environ()...)
128+
return builtin.NewScriptShellTool(toolset.Shell, env), nil
129+
}
130+
131+
func createFilesystemTool(ctx context.Context, toolset latest.Toolset, parentDir string, envProvider environment.Provider, runtimeConfig config.RuntimeConfig) (tools.ToolSet, error) {
132+
wd := runtimeConfig.WorkingDir
133+
if wd == "" {
134+
var err error
135+
wd, err = os.Getwd()
136+
if err != nil {
137+
return nil, fmt.Errorf("failed to get working directory: %w", err)
138+
}
139+
}
140+
141+
var opts []builtin.FileSystemOpt
142+
143+
// Handle ignore_vcs configuration (default to true)
144+
ignoreVCS := true
145+
if toolset.IgnoreVCS != nil {
146+
ignoreVCS = *toolset.IgnoreVCS
147+
}
148+
opts = append(opts, builtin.WithIgnoreVCS(ignoreVCS))
149+
150+
// Handle post-edit commands
151+
if len(toolset.PostEdit) > 0 {
152+
postEditConfigs := make([]builtin.PostEditConfig, len(toolset.PostEdit))
153+
for i, pe := range toolset.PostEdit {
154+
postEditConfigs[i] = builtin.PostEditConfig{
155+
Path: pe.Path,
156+
Cmd: pe.Cmd,
157+
}
158+
}
159+
opts = append(opts, builtin.WithPostEditCommands(postEditConfigs))
160+
}
161+
162+
return builtin.NewFilesystemTool([]string{wd}, opts...), nil
163+
}
164+
165+
func createAPITool(ctx context.Context, toolset latest.Toolset, parentDir string, envProvider environment.Provider, runtimeConfig config.RuntimeConfig) (tools.ToolSet, error) {
166+
if toolset.APIConfig.Endpoint == "" {
167+
return nil, fmt.Errorf("api tool requires an endpoint in api_config")
168+
}
169+
170+
toolset.APIConfig.Headers = js.Expand(ctx, toolset.APIConfig.Headers, envProvider)
171+
172+
return builtin.NewAPITool(toolset.APIConfig), nil
173+
}
174+
175+
func createFetchTool(ctx context.Context, toolset latest.Toolset, parentDir string, envProvider environment.Provider, runtimeConfig config.RuntimeConfig) (tools.ToolSet, error) {
176+
var opts []builtin.FetchToolOption
177+
if toolset.Timeout > 0 {
178+
timeout := time.Duration(toolset.Timeout) * time.Second
179+
opts = append(opts, builtin.WithTimeout(timeout))
180+
}
181+
return builtin.NewFetchTool(opts...), nil
182+
}
183+
184+
func createMCPTool(ctx context.Context, toolset latest.Toolset, parentDir string, envProvider environment.Provider, runtimeConfig config.RuntimeConfig) (tools.ToolSet, error) {
185+
// MCP tool has three different modes: ref, command, and remote
186+
if toolset.Ref != "" {
187+
mcpServerName := gateway.ParseServerRef(toolset.Ref)
188+
serverSpec, err := gateway.ServerSpec(ctx, mcpServerName)
189+
if err != nil {
190+
return nil, fmt.Errorf("fetching MCP server spec for %q: %w", mcpServerName, err)
191+
}
192+
193+
// TODO(dga): until the MCP Gateway supports oauth with cagent, we fetch the remote url and directly connect to it.
194+
if serverSpec.Type == "remote" {
195+
return mcp.NewRemoteToolset(serverSpec.Remote.URL, serverSpec.Remote.TransportType, nil), nil
196+
}
197+
198+
return mcp.NewGatewayToolset(ctx, mcpServerName, toolset.Config, envProvider)
199+
}
200+
201+
if toolset.Command != "" {
202+
env, err := environment.ExpandAll(ctx, environment.ToValues(toolset.Env), envProvider)
203+
if err != nil {
204+
return nil, fmt.Errorf("failed to expand the tool's environment variables: %w", err)
205+
}
206+
env = append(env, os.Environ()...)
207+
return mcp.NewToolsetCommand(toolset.Command, toolset.Args, env), nil
208+
}
209+
210+
if toolset.Remote.URL != "" {
211+
headers := map[string]string{}
212+
for k, v := range toolset.Remote.Headers {
213+
expanded, err := environment.Expand(ctx, v, envProvider)
214+
if err != nil {
215+
return nil, fmt.Errorf("failed to expand header '%s': %w", k, err)
216+
}
217+
218+
headers[k] = expanded
219+
}
220+
221+
return mcp.NewRemoteToolset(toolset.Remote.URL, toolset.Remote.TransportType, headers), nil
222+
}
223+
224+
return nil, fmt.Errorf("mcp toolset requires either ref, command, or remote configuration")
225+
}

pkg/teamloader/sources.go

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
package teamloader
2+
3+
import (
4+
"fmt"
5+
"os"
6+
"path/filepath"
7+
)
8+
9+
type AgentSource interface {
10+
Name() string
11+
ParentDir() string
12+
Read() ([]byte, error)
13+
}
14+
15+
// fileSource is used to load an agent configuration from a YAML file.
16+
type fileSource struct {
17+
path string
18+
}
19+
20+
func NewFileSource(path string) AgentSource {
21+
return fileSource{
22+
path: path,
23+
}
24+
}
25+
26+
func (a fileSource) Name() string {
27+
return filepath.Base(a.path)
28+
}
29+
30+
func (a fileSource) ParentDir() string {
31+
return filepath.Dir(a.path)
32+
}
33+
34+
func (a fileSource) Read() ([]byte, error) {
35+
parentDir := a.ParentDir()
36+
fs, err := os.OpenRoot(parentDir)
37+
if err != nil {
38+
return nil, fmt.Errorf("opening filesystem %s: %w", parentDir, err)
39+
}
40+
41+
fileName := a.Name()
42+
data, err := fs.ReadFile(fileName)
43+
if err != nil {
44+
return nil, fmt.Errorf("reading config file %s: %w", fileName, err)
45+
}
46+
47+
return data, nil
48+
}
49+
50+
// bytesSource is used to load an agent configuration from a []byte.
51+
type bytesSource struct {
52+
workingDir string
53+
data []byte
54+
}
55+
56+
func NewBytesSource(workingDir string, data []byte) AgentSource {
57+
// TODO(dga): this is not perfect but ok for now
58+
if workingDir == "" {
59+
workingDir = "."
60+
}
61+
return bytesSource{
62+
workingDir: workingDir,
63+
data: data,
64+
}
65+
}
66+
67+
func (a bytesSource) Name() string {
68+
return ""
69+
}
70+
71+
func (a bytesSource) ParentDir() string {
72+
return a.workingDir
73+
}
74+
75+
func (a bytesSource) Read() ([]byte, error) {
76+
return a.data, nil
77+
}

0 commit comments

Comments
 (0)