Skip to content

Commit 22f8914

Browse files
authored
Merge pull request #1081 from jeanlaurent/mcp-ignore-headers
Apply custom HTTP headers to remote MCP server requests
2 parents e343cd7 + 20af7f7 commit 22f8914

3 files changed

Lines changed: 237 additions & 12 deletions

File tree

pkg/tools/mcp/mcp.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ func NewRemoteToolset(name, url, transport string, headers map[string]string) *T
5656

5757
return &Toolset{
5858
name: name,
59-
mcpClient: newRemoteClient(url, transport, headers, NewInMemoryTokenStore(), false),
59+
mcpClient: newRemoteClient(url, transport, headers, NewInMemoryTokenStore()),
6060
logID: url,
6161
}
6262
}

pkg/tools/mcp/remote.go

Lines changed: 42 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@ type remoteMCPClient struct {
2525
mu sync.RWMutex
2626
}
2727

28-
func newRemoteClient(url, transportType string, headers map[string]string, tokenStore OAuthTokenStore, managed bool) *remoteMCPClient {
29-
slog.Debug("Creating remote MCP client", "url", url, "transport", transportType, "headers", headers, "managed", managed)
28+
func newRemoteClient(url, transportType string, headers map[string]string, tokenStore OAuthTokenStore) *remoteMCPClient {
29+
slog.Debug("Creating remote MCP client", "url", url, "transport", transportType, "headers", headers)
3030

3131
if tokenStore == nil {
3232
tokenStore = NewInMemoryTokenStore()
@@ -37,7 +37,7 @@ func newRemoteClient(url, transportType string, headers map[string]string, token
3737
transportType: transportType,
3838
headers: headers,
3939
tokenStore: tokenStore,
40-
managed: managed,
40+
managed: false,
4141
}
4242
}
4343

@@ -123,16 +123,47 @@ func (c *remoteMCPClient) Initialize(ctx context.Context, _ *mcp.InitializeReque
123123
return session.InitializeResult(), nil
124124
}
125125

126-
// createHTTPClient creates an HTTP client with OAuth support
126+
// headerTransport is a RoundTripper that adds custom headers to all requests
127+
type headerTransport struct {
128+
base http.RoundTripper
129+
headers map[string]string
130+
}
131+
132+
func (t *headerTransport) RoundTrip(req *http.Request) (*http.Response, error) {
133+
// Clone the request to avoid modifying the original
134+
req = req.Clone(req.Context())
135+
136+
// Add custom headers
137+
for key, value := range t.headers {
138+
req.Header.Set(key, value)
139+
}
140+
141+
return t.base.RoundTrip(req)
142+
}
143+
144+
// createHTTPClient creates an HTTP client with custom headers and OAuth support
127145
func (c *remoteMCPClient) createHTTPClient() *http.Client {
146+
transport := http.DefaultTransport
147+
148+
// Add custom headers first
149+
if len(c.headers) > 0 {
150+
transport = &headerTransport{
151+
base: transport,
152+
headers: c.headers,
153+
}
154+
}
155+
156+
// Then wrap with OAuth support
157+
transport = &oauthTransport{
158+
base: transport,
159+
client: c,
160+
tokenStore: c.tokenStore,
161+
baseURL: c.url,
162+
managed: c.managed,
163+
}
164+
128165
return &http.Client{
129-
Transport: &oauthTransport{
130-
base: http.DefaultTransport,
131-
client: c,
132-
tokenStore: c.tokenStore,
133-
baseURL: c.url,
134-
managed: c.managed,
135-
},
166+
Transport: transport,
136167
}
137168
}
138169

pkg/tools/mcp/remote_test.go

Lines changed: 194 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,194 @@
1+
package mcp
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"net/http"
7+
"net/http/httptest"
8+
"testing"
9+
"time"
10+
11+
"github.com/stretchr/testify/assert"
12+
"github.com/stretchr/testify/require"
13+
)
14+
15+
// TestRemoteClientCustomHeaders verifies that custom headers passed to the remote
16+
// MCP client are actually applied to HTTP requests sent to the MCP server.
17+
func TestRemoteClientCustomHeaders(t *testing.T) {
18+
var capturedRequest *http.Request
19+
requestCaptured := make(chan bool, 1)
20+
21+
// Create a test SSE server that captures the request
22+
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
23+
capturedRequest = r
24+
25+
// Send a minimal SSE response to satisfy the client
26+
w.Header().Set("Content-Type", "text/event-stream")
27+
w.WriteHeader(http.StatusOK)
28+
fmt.Fprintf(w, "event: endpoint\ndata: {\"uri\":\"/message\"}\n\n")
29+
30+
select {
31+
case requestCaptured <- true:
32+
default:
33+
}
34+
35+
// Keep the connection open briefly
36+
time.Sleep(100 * time.Millisecond)
37+
}))
38+
defer server.Close()
39+
40+
// Create remote client WITH custom headers
41+
expectedHeaders := map[string]string{
42+
"X-Test-Header": "test-value",
43+
"X-API-Key": "secret-key-12345",
44+
"Authorization": "Bearer custom-token",
45+
}
46+
47+
client := newRemoteClient(server.URL, "sse", expectedHeaders, NewInMemoryTokenStore())
48+
49+
ctx, cancel := context.WithTimeout(t.Context(), 2*time.Second)
50+
defer cancel()
51+
52+
// Try to initialize (which will make the HTTP request)
53+
// We don't care if it succeeds or fails, we just need it to make the request
54+
_, _ = client.Initialize(ctx, nil)
55+
56+
// Wait for the request to be captured
57+
select {
58+
case <-requestCaptured:
59+
// Verify that custom headers were applied
60+
for key, expectedValue := range expectedHeaders {
61+
actualValue := capturedRequest.Header.Get(key)
62+
assert.Equal(t, expectedValue, actualValue,
63+
"Expected header %s to have value %q, but got %q",
64+
key, expectedValue, actualValue)
65+
}
66+
case <-time.After(1 * time.Second):
67+
t.Fatal("Server did not receive request within timeout")
68+
}
69+
}
70+
71+
// TestRemoteClientHeadersWithStreamable verifies that custom headers work with streamable transport
72+
func TestRemoteClientHeadersWithStreamable(t *testing.T) {
73+
var capturedRequest *http.Request
74+
requestCaptured := make(chan bool, 1)
75+
76+
// Create a test server for streamable transport
77+
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
78+
capturedRequest = r
79+
80+
// Send a minimal response
81+
w.Header().Set("Content-Type", "application/json")
82+
w.WriteHeader(http.StatusOK)
83+
fmt.Fprintf(w, `{"jsonrpc":"2.0","result":{"protocolVersion":"1.0.0","capabilities":{},"serverInfo":{"name":"test","version":"1.0.0"}},"id":1}`)
84+
85+
select {
86+
case requestCaptured <- true:
87+
default:
88+
}
89+
}))
90+
defer server.Close()
91+
92+
// Create remote client WITH custom headers using streamable transport
93+
expectedHeaders := map[string]string{
94+
"X-Custom-Auth": "custom-auth-value",
95+
}
96+
97+
client := newRemoteClient(server.URL, "streamable", expectedHeaders, NewInMemoryTokenStore())
98+
99+
ctx, cancel := context.WithTimeout(t.Context(), 2*time.Second)
100+
defer cancel()
101+
102+
// Try to initialize
103+
_, _ = client.Initialize(ctx, nil)
104+
105+
// Wait for the request to be captured
106+
select {
107+
case <-requestCaptured:
108+
// Verify that custom headers were applied
109+
actualValue := capturedRequest.Header.Get("X-Custom-Auth")
110+
assert.Equal(t, "custom-auth-value", actualValue,
111+
"Expected header X-Custom-Auth to have value %q, but got %q",
112+
"custom-auth-value", actualValue)
113+
case <-time.After(1 * time.Second):
114+
t.Fatal("Server did not receive request within timeout")
115+
}
116+
}
117+
118+
// TestRemoteClientNoHeaders verifies that the client works correctly even with no headers
119+
func TestRemoteClientNoHeaders(t *testing.T) {
120+
var capturedRequest *http.Request
121+
requestCaptured := make(chan bool, 1)
122+
123+
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
124+
capturedRequest = r
125+
126+
w.Header().Set("Content-Type", "text/event-stream")
127+
w.WriteHeader(http.StatusOK)
128+
fmt.Fprintf(w, "event: endpoint\ndata: {\"uri\":\"/message\"}\n\n")
129+
130+
select {
131+
case requestCaptured <- true:
132+
default:
133+
}
134+
135+
time.Sleep(100 * time.Millisecond)
136+
}))
137+
defer server.Close()
138+
139+
// Create remote client without custom headers (nil)
140+
client := newRemoteClient(server.URL, "sse", nil, NewInMemoryTokenStore())
141+
142+
ctx, cancel := context.WithTimeout(t.Context(), 2*time.Second)
143+
defer cancel()
144+
145+
_, _ = client.Initialize(ctx, nil)
146+
147+
// Wait for request
148+
select {
149+
case <-requestCaptured:
150+
// Just verify we got the request - no custom headers should be present
151+
require.NotNil(t, capturedRequest, "Request should have been captured")
152+
case <-time.After(1 * time.Second):
153+
t.Fatal("Server did not receive request within timeout")
154+
}
155+
}
156+
157+
// TestRemoteClientEmptyHeaders verifies that the client works correctly with an empty map
158+
func TestRemoteClientEmptyHeaders(t *testing.T) {
159+
var capturedRequest *http.Request
160+
requestCaptured := make(chan bool, 1)
161+
162+
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
163+
capturedRequest = r
164+
165+
w.Header().Set("Content-Type", "text/event-stream")
166+
w.WriteHeader(http.StatusOK)
167+
fmt.Fprintf(w, "event: endpoint\ndata: {\"uri\":\"/message\"}\n\n")
168+
169+
select {
170+
case requestCaptured <- true:
171+
default:
172+
}
173+
174+
time.Sleep(100 * time.Millisecond)
175+
}))
176+
defer server.Close()
177+
178+
// Create remote client with empty headers map
179+
client := newRemoteClient(server.URL, "sse", map[string]string{}, NewInMemoryTokenStore())
180+
181+
ctx, cancel := context.WithTimeout(t.Context(), 2*time.Second)
182+
defer cancel()
183+
184+
_, _ = client.Initialize(ctx, nil)
185+
186+
// Wait for request
187+
select {
188+
case <-requestCaptured:
189+
// Just verify we got the request
190+
require.NotNil(t, capturedRequest, "Request should have been captured")
191+
case <-time.After(1 * time.Second):
192+
t.Fatal("Server did not receive request within timeout")
193+
}
194+
}

0 commit comments

Comments
 (0)