Skip to content

Commit ddf834c

Browse files
authored
Merge pull request #2133 from dgageot/board/look-at-the-mcp-gateway-code-and-find-pl-033c022b
gateway: harden and simplify MCP gateway code
2 parents e087d99 + eb5bb1c commit ddf834c

5 files changed

Lines changed: 95 additions & 37 deletions

File tree

pkg/gateway/catalog.go

Lines changed: 29 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,10 @@ const (
1919
DockerCatalogURL = "https://desktop.docker.com/mcp/catalog/v3/catalog.yaml"
2020
catalogCacheFileName = "mcp_catalog.json"
2121
fetchTimeout = 15 * time.Second
22-
)
2322

24-
// catalogJSON is the URL we actually fetch (JSON is ~3x faster to parse than YAML).
25-
var catalogJSON = strings.Replace(DockerCatalogURL, ".yaml", ".json", 1)
23+
// catalogJSON is the URL we actually fetch (JSON is ~3x faster to parse than YAML).
24+
catalogJSON = "https://desktop.docker.com/mcp/catalog/v3/catalog.json"
25+
)
2626

2727
func RequiredEnvVars(ctx context.Context, serverName string) ([]Secret, error) {
2828
server, err := ServerSpec(ctx, serverName)
@@ -40,8 +40,8 @@ func RequiredEnvVars(ctx context.Context, serverName string) ([]Secret, error) {
4040
return server.Secrets, nil
4141
}
4242

43-
func ServerSpec(ctx context.Context, serverName string) (Server, error) {
44-
catalog, err := loadCatalog(ctx)
43+
func ServerSpec(_ context.Context, serverName string) (Server, error) {
44+
catalog, err := catalogOnce()
4545
if err != nil {
4646
return Server{}, err
4747
}
@@ -54,6 +54,11 @@ func ServerSpec(ctx context.Context, serverName string) (Server, error) {
5454
return server, nil
5555
}
5656

57+
// ParseServerRef strips the optional "docker:" prefix from a server reference.
58+
func ParseServerRef(ref string) string {
59+
return strings.TrimPrefix(ref, "docker:")
60+
}
61+
5762
// cachedCatalog is the on-disk cache format.
5863
type cachedCatalog struct {
5964
Catalog Catalog `json:"catalog"`
@@ -69,12 +74,6 @@ var catalogOnce = sync.OnceValues(func() (Catalog, error) {
6974
return fetchAndCache(context.Background())
7075
})
7176

72-
// loadCatalog returns the catalog, fetching it at most once per process run.
73-
// On network failure it falls back to the disk cache.
74-
func loadCatalog(_ context.Context) (Catalog, error) {
75-
return catalogOnce()
76-
}
77-
7877
// fetchAndCache tries to fetch the catalog from the network (using ETag for
7978
// conditional requests) and falls back to the disk cache on any failure.
8079
func fetchAndCache(ctx context.Context) (Catalog, error) {
@@ -128,16 +127,24 @@ func saveToDisk(path string, catalog Catalog, etag string) {
128127
}
129128

130129
dir := filepath.Dir(path)
131-
if err := os.MkdirAll(dir, 0o755); err != nil {
132-
slog.Warn("Failed to create MCP catalog cache directory", "error", err)
133-
return
134-
}
135130

136131
// Write to a temp file and rename so readers never see a partial file.
132+
// Try creating the temp file first; only create the directory if needed.
137133
tmp, err := os.CreateTemp(dir, ".mcp_catalog_*.tmp")
138134
if err != nil {
139-
slog.Warn("Failed to create MCP catalog temp file", "error", err)
140-
return
135+
if !os.IsNotExist(err) {
136+
slog.Warn("Failed to create MCP catalog temp file", "error", err)
137+
return
138+
}
139+
if mkErr := os.MkdirAll(dir, 0o755); mkErr != nil {
140+
slog.Warn("Failed to create MCP catalog cache directory", "error", mkErr)
141+
return
142+
}
143+
tmp, err = os.CreateTemp(dir, ".mcp_catalog_*.tmp")
144+
if err != nil {
145+
slog.Warn("Failed to create MCP catalog temp file", "error", err)
146+
return
147+
}
141148
}
142149
tmpName := tmp.Name()
143150

@@ -159,6 +166,10 @@ func saveToDisk(path string, catalog Catalog, etag string) {
159166
}
160167
}
161168

169+
// catalogClient is a dedicated HTTP client for catalog fetches, isolated from
170+
// http.DefaultClient so that other parts of the process cannot interfere.
171+
var catalogClient = &http.Client{}
172+
162173
// fetchFromNetwork fetches the catalog, using the ETag for conditional requests.
163174
// It returns (nil, "", nil) when the server responds with 304 Not Modified.
164175
func fetchFromNetwork(ctx context.Context, etag string) (Catalog, string, error) {
@@ -174,7 +185,7 @@ func fetchFromNetwork(ctx context.Context, etag string) (Catalog, string, error)
174185
req.Header.Set("If-None-Match", etag)
175186
}
176187

177-
resp, err := http.DefaultClient.Do(req)
188+
resp, err := catalogClient.Do(req)
178189
if err != nil {
179190
return nil, "", err
180191
}

pkg/gateway/catalog_test.go

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,45 @@
11
package gateway
22

33
import (
4+
"os"
45
"testing"
56

67
"github.com/stretchr/testify/assert"
78
"github.com/stretchr/testify/require"
89
)
910

11+
// testCatalog is a self-contained catalog used by all tests, removing the
12+
// dependency on the live Docker MCP catalog and the network.
13+
var testCatalog = Catalog{
14+
"github-official": {
15+
Type: "server",
16+
Secrets: []Secret{
17+
{Name: "github.personal_access_token", Env: "GITHUB_PERSONAL_ACCESS_TOKEN"},
18+
},
19+
},
20+
"fetch": {
21+
Type: "server",
22+
},
23+
"apify": {
24+
Type: "remote",
25+
Secrets: []Secret{
26+
{Name: "apify.token", Env: "APIFY_TOKEN"},
27+
},
28+
Remote: Remote{
29+
URL: "https://mcp.apify.com",
30+
TransportType: "streamable-http",
31+
},
32+
},
33+
}
34+
35+
func TestMain(m *testing.M) {
36+
// Override the production catalogOnce so that tests never hit the network.
37+
catalogOnce = func() (Catalog, error) {
38+
return testCatalog, nil
39+
}
40+
os.Exit(m.Run())
41+
}
42+
1043
func TestRequiredEnvVars_local(t *testing.T) {
1144
secrets, err := RequiredEnvVars(t.Context(), "github-official")
1245
require.NoError(t, err)
@@ -38,3 +71,15 @@ func TestServerSpec_remote(t *testing.T) {
3871
assert.Equal(t, "https://mcp.apify.com", server.Remote.URL)
3972
assert.Equal(t, "streamable-http", server.Remote.TransportType)
4073
}
74+
75+
func TestServerSpec_notFound(t *testing.T) {
76+
_, err := ServerSpec(t.Context(), "nonexistent")
77+
require.Error(t, err)
78+
79+
assert.Contains(t, err.Error(), "not found in MCP catalog")
80+
}
81+
82+
func TestParseServerRef(t *testing.T) {
83+
assert.Equal(t, "github-official", ParseServerRef("docker:github-official"))
84+
assert.Equal(t, "github-official", ParseServerRef("github-official"))
85+
}

pkg/gateway/servers.go

Lines changed: 0 additions & 9 deletions
This file was deleted.

pkg/teamloader/registry.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -263,7 +263,7 @@ func createMCPTool(ctx context.Context, toolset latest.Toolset, _ string, runCon
263263
envProvider,
264264
)
265265

266-
return mcp.NewGatewayToolset(ctx, toolset.Name, mcpServerName, toolset.Config, envProvider, runConfig.WorkingDir)
266+
return mcp.NewGatewayToolset(ctx, toolset.Name, mcpServerName, serverSpec.Secrets, toolset.Config, envProvider, runConfig.WorkingDir)
267267

268268
// STDIO MCP Server from shell command
269269
case toolset.Command != "":

pkg/tools/mcp/gateway.go

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -22,15 +22,9 @@ type GatewayToolset struct {
2222

2323
var _ tools.ToolSet = (*GatewayToolset)(nil)
2424

25-
func NewGatewayToolset(ctx context.Context, name, mcpServerName string, config any, envProvider environment.Provider, cwd string) (*GatewayToolset, error) {
25+
func NewGatewayToolset(ctx context.Context, name, mcpServerName string, secrets []gateway.Secret, config any, envProvider environment.Provider, cwd string) (*GatewayToolset, error) {
2626
slog.Debug("Creating MCP Gateway toolset", "name", mcpServerName)
2727

28-
// Check which secrets (env vars) are required by the MCP server.
29-
secrets, err := gateway.RequiredEnvVars(ctx, mcpServerName)
30-
if err != nil {
31-
return nil, fmt.Errorf("reading which secrets the MCP server needs: %w", err)
32-
}
33-
3428
// Make sure all the required secrets are available in the environment.
3529
// TODO(dga): Ideally, the MCP gateway would use the same provider that we have.
3630
fileSecrets, err := writeSecretsToFile(ctx, mcpServerName, secrets, envProvider)
@@ -66,7 +60,14 @@ func NewGatewayToolset(ctx context.Context, name, mcpServerName string, config a
6660
}
6761

6862
func (t *GatewayToolset) Stop(ctx context.Context) error {
69-
return errors.Join(t.Toolset.Stop(ctx), t.cleanUp())
63+
stopErr := t.Toolset.Stop(ctx)
64+
65+
cleanUpErr := t.cleanUp()
66+
if cleanUpErr != nil {
67+
slog.Warn("Failed to clean up MCP Gateway temp files", "error", cleanUpErr)
68+
}
69+
70+
return errors.Join(stopErr, cleanUpErr)
7071
}
7172

7273
func writeSecretsToFile(ctx context.Context, mcpServerName string, secrets []gateway.Secret, envProvider environment.Provider) (string, error) {
@@ -77,6 +78,10 @@ func writeSecretsToFile(ctx context.Context, mcpServerName string, secrets []gat
7778
return "", errors.New("missing environment variable " + secret.Env + " required by MCP server " + mcpServerName)
7879
}
7980

81+
if strings.ContainsAny(v, "\n\r") {
82+
return "", fmt.Errorf("secret %s contains newline characters", secret.Env)
83+
}
84+
8085
secretValues = append(secretValues, fmt.Sprintf("%s=%s", secret.Name, v))
8186
}
8287

@@ -100,9 +105,15 @@ func writeTempFile(nameTemplate string, content []byte) (string, error) {
100105
if err != nil {
101106
return "", fmt.Errorf("creating temp file: %w", err)
102107
}
103-
defer f.Close()
104108

105109
if _, err := f.Write(content); err != nil {
110+
f.Close()
111+
os.Remove(f.Name())
112+
return "", err
113+
}
114+
115+
if err := f.Close(); err != nil {
116+
os.Remove(f.Name())
106117
return "", err
107118
}
108119

0 commit comments

Comments
 (0)