Skip to content

Commit 012c317

Browse files
authored
Merge pull request #446 from dgageot/load-mcp-catalog-once
Load MCP catalog once
2 parents 092e2ea + 73ce65b commit 012c317

9 files changed

Lines changed: 123 additions & 37 deletions

File tree

pkg/gateway/catalog.go

Lines changed: 13 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -2,42 +2,36 @@ package gateway
22

33
import (
44
"context"
5+
"encoding/json"
56
"fmt"
6-
"io"
77
"net/http"
88
"strings"
99

10-
"github.com/goccy/go-yaml"
10+
"github.com/docker/cagent/pkg/sync"
1111
)
1212

1313
const DockerCatalogURL = "https://desktop.docker.com/mcp/catalog/v3/catalog.yaml"
1414

15-
func ParseServerRef(ref string) string {
16-
return strings.TrimPrefix(ref, "docker:")
17-
}
18-
19-
func RequiredEnvVars(ctx context.Context, serverName, catalogURL string) ([]Secret, error) {
20-
catalog, err := readCatalog(ctx, catalogURL)
15+
func RequiredEnvVars(ctx context.Context, serverName string) ([]Secret, error) {
16+
catalog, err := readCatalogOnce()
2117
if err != nil {
22-
return nil, err
18+
return nil, fmt.Errorf("failed to fetch MCP catalog: %w", err)
2319
}
2420

2521
server, ok := catalog[serverName]
2622
if !ok {
27-
return nil, fmt.Errorf("MCP server %q not found in catalog %q", serverName, catalogURL)
23+
return nil, fmt.Errorf("MCP server %q not found in MCP catalog", serverName)
2824
}
2925

3026
return server.Secrets, nil
3127
}
3228

33-
// TODO(dga): cache the catalog.
34-
func readCatalog(ctx context.Context, url string) (Catalog, error) {
35-
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, http.NoBody)
36-
if err != nil {
37-
return nil, err
38-
}
29+
// Read the MCP Catalog only once and cache the result.
30+
var readCatalogOnce = sync.OnceErr(func() (Catalog, error) {
31+
// Use the JSON version because it's 3x time faster to parse than YAML.
32+
url := strings.Replace(DockerCatalogURL, ".yaml", ".json", 1)
3933

40-
resp, err := http.DefaultClient.Do(req)
34+
resp, err := http.Get(url)
4135
if err != nil {
4236
return nil, err
4337
}
@@ -47,15 +41,10 @@ func readCatalog(ctx context.Context, url string) (Catalog, error) {
4741
return nil, fmt.Errorf("failed to fetch URL: %s, status: %s", url, resp.Status)
4842
}
4943

50-
buf, err := io.ReadAll(resp.Body)
51-
if err != nil {
52-
return nil, err
53-
}
54-
5544
var topLevel topLevel
56-
if err := yaml.Unmarshal(buf, &topLevel); err != nil {
45+
if err := json.NewDecoder(resp.Body).Decode(&topLevel); err != nil {
5746
return nil, err
5847
}
5948

6049
return topLevel.Catalog, nil
61-
}
50+
})

pkg/gateway/catalog_test.go

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
package gateway
2+
3+
import (
4+
"testing"
5+
6+
"github.com/stretchr/testify/assert"
7+
"github.com/stretchr/testify/require"
8+
)
9+
10+
func TestRequiredEnvVars(t *testing.T) {
11+
secrets, err := RequiredEnvVars(t.Context(), "github-official")
12+
require.NoError(t, err)
13+
14+
assert.Len(t, secrets, 1)
15+
assert.Equal(t, "GITHUB_PERSONAL_ACCESS_TOKEN", secrets[0].Env)
16+
assert.Equal(t, "github.personal_access_token", secrets[0].Name)
17+
}

pkg/gateway/servers.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
package gateway
2+
3+
import (
4+
"strings"
5+
)
6+
7+
func ParseServerRef(ref string) string {
8+
return strings.TrimPrefix(ref, "docker:")
9+
}

pkg/gateway/types.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,17 @@
11
package gateway
22

33
type topLevel struct {
4-
Catalog Catalog `json:"registry" yaml:"registry"`
4+
Catalog Catalog `json:"registry"`
55
}
66

77
type Catalog map[string]Server
88

99
type Server struct {
10-
Secrets []Secret `json:"secrets,omitempty" yaml:"secrets,omitempty"`
10+
Secrets []Secret `json:"secrets,omitempty"`
1111
}
1212

1313
type Secret struct {
14-
Name string `json:"name" yaml:"name"`
15-
Env string `json:"env" yaml:"env"`
16-
Example string `json:"example" yaml:"example"`
14+
Name string `json:"name"`
15+
Env string `json:"env"`
16+
Example string `json:"example"`
1717
}

pkg/secrets/gather.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ func GatherEnvVarsForTools(ctx context.Context, cfg *latest.Config) ([]string, e
7979
for _, ref := range gatherMCPServerReferences(cfg) {
8080
mcpServerName := gateway.ParseServerRef(ref)
8181

82-
secrets, err := gateway.RequiredEnvVars(ctx, mcpServerName, gateway.DockerCatalogURL)
82+
secrets, err := gateway.RequiredEnvVars(ctx, mcpServerName)
8383
if err != nil {
8484
return nil, fmt.Errorf("reading which secrets the MCP server needs: %w", err)
8585
}

pkg/sync/oncerr.go

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
package sync
2+
3+
import "sync"
4+
5+
func OnceErr[T any](fn func() (T, error)) func() (T, error) {
6+
var once sync.Once
7+
var result T
8+
var err error
9+
10+
return func() (T, error) {
11+
once.Do(func() {
12+
result, err = fn()
13+
})
14+
return result, err
15+
}
16+
}

pkg/sync/oncerr_test.go

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
package sync
2+
3+
import (
4+
"errors"
5+
"testing"
6+
7+
"github.com/stretchr/testify/require"
8+
)
9+
10+
func TestOnceErr(t *testing.T) {
11+
t.Parallel()
12+
13+
called := 0
14+
fn := func() (int, error) {
15+
called++
16+
return 42, nil
17+
}
18+
19+
memoizedFn := OnceErr(fn)
20+
21+
value, err := memoizedFn()
22+
require.NoError(t, err)
23+
require.Equal(t, 42, value)
24+
require.Equal(t, 1, called)
25+
26+
value, err = memoizedFn()
27+
require.NoError(t, err)
28+
require.Equal(t, 42, value)
29+
require.Equal(t, 1, called) // Didn't have to call the inner fn
30+
}
31+
32+
func TestOnceErr_Error(t *testing.T) {
33+
t.Parallel()
34+
35+
called := 0
36+
fn := func() (int, error) {
37+
called++
38+
return 1337, errors.New("test error")
39+
}
40+
41+
memoizedFn := OnceErr(fn)
42+
43+
value, err := memoizedFn()
44+
require.Error(t, err)
45+
require.Equal(t, 1337, value)
46+
require.Equal(t, 1, called)
47+
48+
value, err = memoizedFn()
49+
require.Error(t, err)
50+
require.Equal(t, 1337, value)
51+
require.Equal(t, 1, called) // Didn't have to call the inner fn
52+
}

pkg/teamloader/teamloader_test.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,9 +100,11 @@ func TestLoadExamples(t *testing.T) {
100100
// Collect the missing env vars.
101101
missingEnvs := map[string]bool{}
102102

103+
var runtimeConfig config.RuntimeConfig
104+
103105
for _, file := range collectExamples(t) {
104106
t.Run(file, func(t *testing.T) {
105-
_, err := Load(t.Context(), file, config.RuntimeConfig{})
107+
_, err := Load(t.Context(), file, runtimeConfig)
106108
if err != nil {
107109
envErr := &environment.RequiredEnvError{}
108110
require.ErrorAs(t, err, &envErr)
@@ -123,7 +125,7 @@ func TestLoadExamples(t *testing.T) {
123125
t.Run(file, func(t *testing.T) {
124126
t.Parallel()
125127

126-
teams, err := Load(t.Context(), file, config.RuntimeConfig{})
128+
teams, err := Load(t.Context(), file, runtimeConfig)
127129
require.NoError(t, err)
128130
require.NotEmpty(t, teams)
129131
})

pkg/tools/mcp/gateway.go

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,11 @@ func NewGatewayToolset(mcpServerName string, config any, toolFilter []string, en
3535
slog.Debug("Creating MCP Gateway toolset", "name", mcpServerName, "toolFilter", toolFilter)
3636

3737
return &GatewayToolset{
38-
mcpServerName: mcpServerName,
39-
config: config,
40-
toolFilter: toolFilter,
41-
envProvider: envProvider,
38+
mcpServerName: mcpServerName,
39+
config: config,
40+
toolFilter: toolFilter,
41+
envProvider: envProvider,
42+
4243
cleanUpConfig: func() error { return nil },
4344
cleanUpSecrets: func() error { return nil },
4445
}
@@ -50,7 +51,7 @@ func (t *GatewayToolset) Instructions() string {
5051

5152
func (t *GatewayToolset) configureOnce(ctx context.Context) error {
5253
// Check which secrets (env vars) are required by the MCP server.
53-
secrets, err := gateway.RequiredEnvVars(ctx, t.mcpServerName, gateway.DockerCatalogURL)
54+
secrets, err := gateway.RequiredEnvVars(ctx, t.mcpServerName)
5455
if err != nil {
5556
return fmt.Errorf("reading which secrets the MCP server needs: %w", err)
5657
}

0 commit comments

Comments
 (0)