Skip to content

Commit f93dca4

Browse files
authored
Merge pull request #490 from trungutt/fix-oauth-mcp-initialization-retry
Fix OAuth flow breaking MCP session initialization
2 parents 9e34cf1 + 1c92feb commit f93dca4

1 file changed

Lines changed: 75 additions & 32 deletions

File tree

pkg/tools/mcp/remote.go

Lines changed: 75 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@ import (
66
"iter"
77
"log/slog"
88
"net/http"
9+
"strings"
910
"sync"
11+
"time"
1012

1113
"github.com/modelcontextprotocol/go-sdk/mcp"
1214

@@ -84,47 +86,76 @@ func (c *remoteMCPClient) Initialize(ctx context.Context, _ *mcp.InitializeReque
8486
// Create HTTP client with OAuth support
8587
httpClient := c.createHTTPClient()
8688

87-
var transport mcp.Transport
89+
// Attempt MCP initialization with retry logic for OAuth-related failures.
90+
// When a server requires OAuth, the first connection attempt may fail with a "broken session"
91+
// error because OAuth flow (even successful flow) interrupts the MCP handshake. Let's retry once after OAuth completes.
92+
// Example of such MCP Server that broke the session with OAuth flow: https://mcp.prisma.io/mcp
93+
const maxAttempts = 2
94+
var lastErr error
8895

89-
switch c.transportType {
90-
case "sse":
91-
transport = &mcp.SSEClientTransport{
92-
Endpoint: c.url,
93-
HTTPClient: httpClient,
96+
for attempt := 1; attempt <= maxAttempts; attempt++ {
97+
if attempt > 1 {
98+
slog.Debug("Retrying MCP initialization after OAuth flow", "attempt", attempt)
9499
}
95-
case "streamable", "streamable-http":
96-
transport = &mcp.StreamableClientTransport{
97-
Endpoint: c.url,
98-
HTTPClient: httpClient,
100+
101+
var transport mcp.Transport
102+
103+
switch c.transportType {
104+
case "sse":
105+
transport = &mcp.SSEClientTransport{
106+
Endpoint: c.url,
107+
HTTPClient: httpClient,
108+
}
109+
case "streamable", "streamable-http":
110+
transport = &mcp.StreamableClientTransport{
111+
Endpoint: c.url,
112+
HTTPClient: httpClient,
113+
}
114+
default:
115+
return nil, fmt.Errorf("unsupported transport type: %s", c.transportType)
99116
}
100-
default:
101-
return nil, fmt.Errorf("unsupported transport type: %s", c.transportType)
102-
}
103117

104-
// Create an MCP client with elicitation support
105-
impl := &mcp.Implementation{
106-
Name: "cagent",
107-
Version: "1.0.0",
108-
}
118+
// Create an MCP client with elicitation support
119+
impl := &mcp.Implementation{
120+
Name: "cagent",
121+
Version: "1.0.0",
122+
}
109123

110-
opts := &mcp.ClientOptions{
111-
ElicitationHandler: c.handleElicitationRequest,
112-
}
124+
opts := &mcp.ClientOptions{
125+
ElicitationHandler: c.handleElicitationRequest,
126+
}
113127

114-
client := mcp.NewClient(impl, opts)
128+
client := mcp.NewClient(impl, opts)
129+
130+
// Connect to the MCP server
131+
session, err := client.Connect(ctx, transport, nil)
132+
if err != nil {
133+
lastErr = err
134+
135+
// Check if this is a "broken session" error that might be OAuth-related
136+
if attempt < maxAttempts && isBrokenSessionError(err) {
137+
slog.Debug("MCP connection failed with broken session error, retrying after OAuth", "error", err)
138+
// Brief pause before retry to allow OAuth state to settle
139+
select {
140+
case <-ctx.Done():
141+
return nil, fmt.Errorf("failed to connect to MCP server: %w", ctx.Err())
142+
case <-time.After(100 * time.Millisecond):
143+
}
144+
continue
145+
}
146+
147+
return nil, fmt.Errorf("failed to connect to MCP server: %w", err)
148+
}
115149

116-
// Connect to the MCP server
117-
session, err := client.Connect(ctx, transport, nil)
118-
if err != nil {
119-
return nil, fmt.Errorf("failed to connect to MCP server: %w", err)
120-
}
150+
c.mu.Lock()
151+
c.session = session
152+
c.mu.Unlock()
121153

122-
c.mu.Lock()
123-
c.session = session
124-
c.mu.Unlock()
154+
slog.Debug("Remote MCP client connected successfully", "attempt", attempt)
155+
return session.InitializeResult(), nil
156+
}
125157

126-
slog.Debug("Remote MCP client connected successfully")
127-
return session.InitializeResult(), nil
158+
return nil, fmt.Errorf("failed to connect to MCP server after %d attempts: %w", maxAttempts, lastErr)
128159
}
129160

130161
// createHTTPClient creates an HTTP client with OAuth support
@@ -194,3 +225,15 @@ func (c *remoteMCPClient) requestUserConsent(ctx context.Context) (bool, error)
194225

195226
return result.Action == "accept", nil
196227
}
228+
229+
// isBrokenSessionError checks if an error is a "broken session" error from the MCP SDK
230+
// This error typically occurs when OAuth interrupts the MCP session handshake
231+
func isBrokenSessionError(err error) bool {
232+
if err == nil {
233+
return false
234+
}
235+
errMsg := strings.ToLower(err.Error())
236+
// The error message comes from mcp-go/mcp/streamable.go:1211
237+
// "broken session: 400 Bad Request"
238+
return strings.Contains(errMsg, "broken session")
239+
}

0 commit comments

Comments
 (0)