|
6 | 6 | "iter" |
7 | 7 | "log/slog" |
8 | 8 | "net/http" |
| 9 | + "strings" |
9 | 10 | "sync" |
| 11 | + "time" |
10 | 12 |
|
11 | 13 | "github.com/modelcontextprotocol/go-sdk/mcp" |
12 | 14 |
|
@@ -84,47 +86,76 @@ func (c *remoteMCPClient) Initialize(ctx context.Context, _ *mcp.InitializeReque |
84 | 86 | // Create HTTP client with OAuth support |
85 | 87 | httpClient := c.createHTTPClient() |
86 | 88 |
|
87 | | - var transport mcp.Transport |
| 89 | + // Attempt MCP initialization with retry logic for OAuth-related failures. |
| 90 | + // When a server requires OAuth, the first connection attempt may fail with a "broken session" |
| 91 | + // error because OAuth flow (even successful flow) interrupts the MCP handshake. Let's retry once after OAuth completes. |
| 92 | + // Example of such MCP Server that broke the session with OAuth flow: https://mcp.prisma.io/mcp |
| 93 | + const maxAttempts = 2 |
| 94 | + var lastErr error |
88 | 95 |
|
89 | | - switch c.transportType { |
90 | | - case "sse": |
91 | | - transport = &mcp.SSEClientTransport{ |
92 | | - Endpoint: c.url, |
93 | | - HTTPClient: httpClient, |
| 96 | + for attempt := 1; attempt <= maxAttempts; attempt++ { |
| 97 | + if attempt > 1 { |
| 98 | + slog.Debug("Retrying MCP initialization after OAuth flow", "attempt", attempt) |
94 | 99 | } |
95 | | - case "streamable", "streamable-http": |
96 | | - transport = &mcp.StreamableClientTransport{ |
97 | | - Endpoint: c.url, |
98 | | - HTTPClient: httpClient, |
| 100 | + |
| 101 | + var transport mcp.Transport |
| 102 | + |
| 103 | + switch c.transportType { |
| 104 | + case "sse": |
| 105 | + transport = &mcp.SSEClientTransport{ |
| 106 | + Endpoint: c.url, |
| 107 | + HTTPClient: httpClient, |
| 108 | + } |
| 109 | + case "streamable", "streamable-http": |
| 110 | + transport = &mcp.StreamableClientTransport{ |
| 111 | + Endpoint: c.url, |
| 112 | + HTTPClient: httpClient, |
| 113 | + } |
| 114 | + default: |
| 115 | + return nil, fmt.Errorf("unsupported transport type: %s", c.transportType) |
99 | 116 | } |
100 | | - default: |
101 | | - return nil, fmt.Errorf("unsupported transport type: %s", c.transportType) |
102 | | - } |
103 | 117 |
|
104 | | - // Create an MCP client with elicitation support |
105 | | - impl := &mcp.Implementation{ |
106 | | - Name: "cagent", |
107 | | - Version: "1.0.0", |
108 | | - } |
| 118 | + // Create an MCP client with elicitation support |
| 119 | + impl := &mcp.Implementation{ |
| 120 | + Name: "cagent", |
| 121 | + Version: "1.0.0", |
| 122 | + } |
109 | 123 |
|
110 | | - opts := &mcp.ClientOptions{ |
111 | | - ElicitationHandler: c.handleElicitationRequest, |
112 | | - } |
| 124 | + opts := &mcp.ClientOptions{ |
| 125 | + ElicitationHandler: c.handleElicitationRequest, |
| 126 | + } |
113 | 127 |
|
114 | | - client := mcp.NewClient(impl, opts) |
| 128 | + client := mcp.NewClient(impl, opts) |
| 129 | + |
| 130 | + // Connect to the MCP server |
| 131 | + session, err := client.Connect(ctx, transport, nil) |
| 132 | + if err != nil { |
| 133 | + lastErr = err |
| 134 | + |
| 135 | + // Check if this is a "broken session" error that might be OAuth-related |
| 136 | + if attempt < maxAttempts && isBrokenSessionError(err) { |
| 137 | + slog.Debug("MCP connection failed with broken session error, retrying after OAuth", "error", err) |
| 138 | + // Brief pause before retry to allow OAuth state to settle |
| 139 | + select { |
| 140 | + case <-ctx.Done(): |
| 141 | + return nil, fmt.Errorf("failed to connect to MCP server: %w", ctx.Err()) |
| 142 | + case <-time.After(100 * time.Millisecond): |
| 143 | + } |
| 144 | + continue |
| 145 | + } |
| 146 | + |
| 147 | + return nil, fmt.Errorf("failed to connect to MCP server: %w", err) |
| 148 | + } |
115 | 149 |
|
116 | | - // Connect to the MCP server |
117 | | - session, err := client.Connect(ctx, transport, nil) |
118 | | - if err != nil { |
119 | | - return nil, fmt.Errorf("failed to connect to MCP server: %w", err) |
120 | | - } |
| 150 | + c.mu.Lock() |
| 151 | + c.session = session |
| 152 | + c.mu.Unlock() |
121 | 153 |
|
122 | | - c.mu.Lock() |
123 | | - c.session = session |
124 | | - c.mu.Unlock() |
| 154 | + slog.Debug("Remote MCP client connected successfully", "attempt", attempt) |
| 155 | + return session.InitializeResult(), nil |
| 156 | + } |
125 | 157 |
|
126 | | - slog.Debug("Remote MCP client connected successfully") |
127 | | - return session.InitializeResult(), nil |
| 158 | + return nil, fmt.Errorf("failed to connect to MCP server after %d attempts: %w", maxAttempts, lastErr) |
128 | 159 | } |
129 | 160 |
|
130 | 161 | // createHTTPClient creates an HTTP client with OAuth support |
@@ -194,3 +225,15 @@ func (c *remoteMCPClient) requestUserConsent(ctx context.Context) (bool, error) |
194 | 225 |
|
195 | 226 | return result.Action == "accept", nil |
196 | 227 | } |
| 228 | + |
| 229 | +// isBrokenSessionError checks if an error is a "broken session" error from the MCP SDK |
| 230 | +// This error typically occurs when OAuth interrupts the MCP session handshake |
| 231 | +func isBrokenSessionError(err error) bool { |
| 232 | + if err == nil { |
| 233 | + return false |
| 234 | + } |
| 235 | + errMsg := strings.ToLower(err.Error()) |
| 236 | + // The error message comes from mcp-go/mcp/streamable.go:1211 |
| 237 | + // "broken session: 400 Bad Request" |
| 238 | + return strings.Contains(errMsg, "broken session") |
| 239 | +} |
0 commit comments