Skip to content

Commit 1c92feb

Browse files
committed
Fix OAuth flow breaking MCP session initialization
When remote MCP servers require OAuth authentication, the OAuth flow was being triggered during the MCP session initialization handshake. This caused the session state to become corrupted because: 1. client.Connect() starts the MCP protocol handshake 2. Initial HTTP request returns 401 Unauthorized 3. oauthTransport intercepts and runs the full OAuth flow 4. After OAuth succeeds, the retry finds the MCP session in a broken state 5. Subsequent requests fail with "400 Bad Request: broken session" This fix adds retry logic that detects "broken session" errors during initialization and automatically retries once after OAuth completes, ensuring: - OAuth completes at the HTTP transport layer first - MCP session initialization happens with authentication already in place - No permanent session corruption from the OAuth interruption The retry is limited to OAuth-related "broken session" errors to avoid masking other legitimate connection failures. Fixes the issue where OAuth-protected MCP servers (like mcp.prisma.io) would fail to initialize despite successful user authentication.
1 parent 6bb603b commit 1c92feb

1 file changed

Lines changed: 75 additions & 32 deletions

File tree

pkg/tools/mcp/remote.go

Lines changed: 75 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@ import (
66
"iter"
77
"log/slog"
88
"net/http"
9+
"strings"
910
"sync"
11+
"time"
1012

1113
"github.com/modelcontextprotocol/go-sdk/mcp"
1214

@@ -84,47 +86,76 @@ func (c *remoteMCPClient) Initialize(ctx context.Context, request *mcp.Initializ
8486
// Create HTTP client with OAuth support
8587
httpClient := c.createHTTPClient()
8688

87-
var transport mcp.Transport
89+
// Attempt MCP initialization with retry logic for OAuth-related failures.
90+
// When a server requires OAuth, the first connection attempt may fail with a "broken session"
91+
// error because OAuth flow (even successful flow) interrupts the MCP handshake. Let's retry once after OAuth completes.
92+
// Example of such MCP Server that broke the session with OAuth flow: https://mcp.prisma.io/mcp
93+
const maxAttempts = 2
94+
var lastErr error
8895

89-
switch c.transportType {
90-
case "sse":
91-
transport = &mcp.SSEClientTransport{
92-
Endpoint: c.url,
93-
HTTPClient: httpClient,
96+
for attempt := 1; attempt <= maxAttempts; attempt++ {
97+
if attempt > 1 {
98+
slog.Debug("Retrying MCP initialization after OAuth flow", "attempt", attempt)
9499
}
95-
case "streamable", "streamable-http":
96-
transport = &mcp.StreamableClientTransport{
97-
Endpoint: c.url,
98-
HTTPClient: httpClient,
100+
101+
var transport mcp.Transport
102+
103+
switch c.transportType {
104+
case "sse":
105+
transport = &mcp.SSEClientTransport{
106+
Endpoint: c.url,
107+
HTTPClient: httpClient,
108+
}
109+
case "streamable", "streamable-http":
110+
transport = &mcp.StreamableClientTransport{
111+
Endpoint: c.url,
112+
HTTPClient: httpClient,
113+
}
114+
default:
115+
return nil, fmt.Errorf("unsupported transport type: %s", c.transportType)
99116
}
100-
default:
101-
return nil, fmt.Errorf("unsupported transport type: %s", c.transportType)
102-
}
103117

104-
// Create an MCP client with elicitation support
105-
impl := &mcp.Implementation{
106-
Name: "cagent",
107-
Version: "1.0.0",
108-
}
118+
// Create an MCP client with elicitation support
119+
impl := &mcp.Implementation{
120+
Name: "cagent",
121+
Version: "1.0.0",
122+
}
109123

110-
opts := &mcp.ClientOptions{
111-
ElicitationHandler: c.handleElicitationRequest,
112-
}
124+
opts := &mcp.ClientOptions{
125+
ElicitationHandler: c.handleElicitationRequest,
126+
}
113127

114-
client := mcp.NewClient(impl, opts)
128+
client := mcp.NewClient(impl, opts)
129+
130+
// Connect to the MCP server
131+
session, err := client.Connect(ctx, transport, nil)
132+
if err != nil {
133+
lastErr = err
134+
135+
// Check if this is a "broken session" error that might be OAuth-related
136+
if attempt < maxAttempts && isBrokenSessionError(err) {
137+
slog.Debug("MCP connection failed with broken session error, retrying after OAuth", "error", err)
138+
// Brief pause before retry to allow OAuth state to settle
139+
select {
140+
case <-ctx.Done():
141+
return nil, fmt.Errorf("failed to connect to MCP server: %w", ctx.Err())
142+
case <-time.After(100 * time.Millisecond):
143+
}
144+
continue
145+
}
146+
147+
return nil, fmt.Errorf("failed to connect to MCP server: %w", err)
148+
}
115149

116-
// Connect to the MCP server
117-
session, err := client.Connect(ctx, transport, nil)
118-
if err != nil {
119-
return nil, fmt.Errorf("failed to connect to MCP server: %w", err)
120-
}
150+
c.mu.Lock()
151+
c.session = session
152+
c.mu.Unlock()
121153

122-
c.mu.Lock()
123-
c.session = session
124-
c.mu.Unlock()
154+
slog.Debug("Remote MCP client connected successfully", "attempt", attempt)
155+
return session.InitializeResult(), nil
156+
}
125157

126-
slog.Debug("Remote MCP client connected successfully")
127-
return session.InitializeResult(), nil
158+
return nil, fmt.Errorf("failed to connect to MCP server after %d attempts: %w", maxAttempts, lastErr)
128159
}
129160

130161
// createHTTPClient creates an HTTP client with OAuth support
@@ -194,3 +225,15 @@ func (c *remoteMCPClient) requestUserConsent(ctx context.Context) (bool, error)
194225

195226
return result.Action == "accept", nil
196227
}
228+
229+
// isBrokenSessionError checks if an error is a "broken session" error from the MCP SDK
230+
// This error typically occurs when OAuth interrupts the MCP session handshake
231+
func isBrokenSessionError(err error) bool {
232+
if err == nil {
233+
return false
234+
}
235+
errMsg := strings.ToLower(err.Error())
236+
// The error message comes from mcp-go/mcp/streamable.go:1211
237+
// "broken session: 400 Bad Request"
238+
return strings.Contains(errMsg, "broken session")
239+
}

0 commit comments

Comments
 (0)