Skip to content

Commit 8a370b3

Browse files
authored
Merge pull request #231 from trungutt/handle-multiple-mcp-oauth
More specificity in toolset start stop calls
2 parents 40a1b47 + 0036a92 commit 8a370b3

3 files changed

Lines changed: 65 additions & 26 deletions

File tree

pkg/agent/agent.go

Lines changed: 45 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -5,20 +5,35 @@ import (
55
"fmt"
66
"log/slog"
77
"math/rand"
8-
"sync/atomic"
8+
"sync"
99

1010
"github.com/docker/cagent/pkg/memorymanager"
1111
"github.com/docker/cagent/pkg/model/provider"
1212
"github.com/docker/cagent/pkg/tools"
1313
)
1414

15+
// ToolSetStartupError wraps toolset startup failures with context
16+
type ToolSetStartupError struct {
17+
Err error
18+
Index int
19+
}
20+
21+
func (e *ToolSetStartupError) Error() string {
22+
return fmt.Sprintf("failed to start toolset: %v", e.Err)
23+
}
24+
25+
func (e *ToolSetStartupError) Unwrap() error {
26+
return e.Err
27+
}
28+
1529
// Agent represents an AI agent
1630
type Agent struct {
1731
name string
1832
description string
1933
instruction string
2034
toolsets []tools.ToolSet
21-
toolsetsStarted atomic.Bool
35+
startedToolsets map[tools.ToolSet]bool
36+
toolsetsMutex sync.RWMutex
2237
models []provider.Provider
2338
subAgents []*Agent
2439
parents []*Agent
@@ -31,8 +46,9 @@ type Agent struct {
3146
// New creates a new agent
3247
func New(name, prompt string, opts ...Opt) *Agent {
3348
agent := &Agent{
34-
name: name,
35-
instruction: prompt,
49+
name: name,
50+
instruction: prompt,
51+
startedToolsets: make(map[tools.ToolSet]bool),
3652
}
3753

3854
for _, opt := range opts {
@@ -129,31 +145,46 @@ func (a *Agent) ToolSets() []tools.ToolSet {
129145
}
130146

131147
func (a *Agent) ensureToolSetsAreStarted(ctx context.Context) error {
132-
if a.toolsetsStarted.Load() {
133-
return nil
134-
}
148+
a.toolsetsMutex.Lock()
149+
defer a.toolsetsMutex.Unlock()
150+
151+
for i, toolSet := range a.toolsets {
152+
// Skip if toolset is already started
153+
if a.startedToolsets[toolSet] {
154+
continue
155+
}
135156

136-
for _, toolSet := range a.toolsets {
137157
if err := toolSet.Start(ctx); err != nil {
138-
return fmt.Errorf("failed to start toolset: %w", err)
158+
return &ToolSetStartupError{
159+
Err: err,
160+
Index: i,
161+
}
139162
}
163+
164+
// Mark toolset as started
165+
a.startedToolsets[toolSet] = true
140166
}
141167

142-
a.toolsetsStarted.Store(true)
143168
return nil
144169
}
145170

146171
func (a *Agent) StopToolSets() error {
147-
if !a.toolsetsStarted.Load() {
148-
return nil
149-
}
172+
a.toolsetsMutex.Lock()
173+
defer a.toolsetsMutex.Unlock()
150174

151175
for _, toolSet := range a.toolsets {
176+
// Only stop toolsets that are marked as started
177+
if !a.startedToolsets[toolSet] {
178+
continue
179+
}
180+
152181
if err := toolSet.Stop(); err != nil {
153182
return fmt.Errorf("failed to stop toolset: %w", err)
154183
}
184+
185+
// Mark toolset as stopped
186+
a.startedToolsets[toolSet] = false
155187
}
156188

157-
a.toolsetsStarted.Store(false)
158189
return nil
159190
}

pkg/runtime/runtime.go

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -126,12 +126,18 @@ func (r *runtime) registerDefaultTools() {
126126
func (r *runtime) getAgentToolsWithOAuthHandling(ctx context.Context, a *agent.Agent) ([]tools.Tool, error) {
127127
agentTools, err := a.Tools(ctx)
128128
if err != nil {
129-
// If this is an OAuth authorization error, wrap it with server info
130-
if client.IsOAuthAuthorizationRequiredError(err) {
131-
// Try to find which toolset caused the OAuth error by checking each one
132-
for _, toolSet := range a.ToolSets() {
133-
if serverInfoProvider, ok := toolSet.(oauth.ServerInfoProvider); ok {
134-
return nil, oauth.WrapOAuthError(err, serverInfoProvider)
129+
// Check if this is a ToolSetStartupError
130+
var toolSetStartupErr *agent.ToolSetStartupError
131+
if errors.As(err, &toolSetStartupErr) {
132+
// Check if the inner error is an OAuth authorization error
133+
if client.IsOAuthAuthorizationRequiredError(toolSetStartupErr.Err) {
134+
// Use the index from ToolSetStartupError to get the specific toolset
135+
toolsets := a.ToolSets()
136+
if toolSetStartupErr.Index >= 0 && toolSetStartupErr.Index < len(toolsets) {
137+
toolSet := toolsets[toolSetStartupErr.Index]
138+
if serverInfoProvider, ok := toolSet.(oauth.ServerInfoProvider); ok {
139+
return nil, oauth.WrapOAuthError(toolSetStartupErr.Err, serverInfoProvider)
140+
}
135141
}
136142
}
137143
}

pkg/tools/mcp/toolset.go

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -74,24 +74,26 @@ func (t *Toolset) Tools(ctx context.Context) ([]tools.Tool, error) {
7474

7575
// Start starts the toolset
7676
func (t *Toolset) Start(ctx context.Context) error {
77-
slog.Debug("Starting MCP toolset")
77+
serverURL, _ := t.c.GetServerInfo()
78+
slog.Debug("Starting MCP toolset", "server", serverURL)
7879
err := t.c.Start(ctx)
7980
if err != nil {
80-
slog.Error("Failed to start MCP toolset", "error", err)
81+
slog.Error("Failed to start MCP toolset", "server", serverURL, "error", err)
8182
return err
8283
}
83-
slog.Debug("Started MCP toolset successfully")
84+
slog.Debug("Started MCP toolset successfully", "server", serverURL)
8485
return nil
8586
}
8687

8788
// Stop stops the toolset
8889
func (t *Toolset) Stop() error {
89-
slog.Debug("Stopping MCP toolset")
90+
serverURL, _ := t.c.GetServerInfo()
91+
slog.Debug("Stopping MCP toolset", "server", serverURL)
9092
err := t.c.Stop()
9193
if err != nil {
92-
slog.Error("Failed to stop MCP toolset", "error", err)
94+
slog.Error("Failed to stop MCP toolset", "server", serverURL, "error", err)
9395
return err
9496
}
95-
slog.Debug("Stopped MCP toolset successfully")
97+
slog.Debug("Stopped MCP toolset successfully", "server", serverURL)
9698
return nil
9799
}

0 commit comments

Comments
 (0)