Skip to content

Commit ed21eba

Browse files
authored
Merge pull request #617 from dgageot/better-gateway
Better handling of the gateway
2 parents 56e5993 + 573650a commit ed21eba

4 files changed

Lines changed: 144 additions & 2 deletions

File tree

pkg/httpclient/client.go

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
package httpclient
2+
3+
import (
4+
"fmt"
5+
"maps"
6+
"net/http"
7+
"runtime"
8+
9+
"github.com/docker/cagent/pkg/version"
10+
)
11+
12+
type HTTPOptions struct {
13+
Header http.Header
14+
}
15+
16+
type Opt func(*HTTPOptions)
17+
18+
func NewHTTPClient(opts ...Opt) *http.Client {
19+
httpOptions := HTTPOptions{
20+
Header: make(http.Header),
21+
}
22+
23+
for _, opt := range opts {
24+
opt(&httpOptions)
25+
}
26+
27+
// Enforce a consistent User-Agent header
28+
httpOptions.Header.Set("User-Agent", fmt.Sprintf("Cagent/%s (%s; %s)", version.Version, getNormalizedOS(), getNormalizedArchitecture()))
29+
30+
return &http.Client{
31+
Transport: &userAgentTransport{
32+
httpOptions: httpOptions,
33+
rt: http.DefaultTransport,
34+
},
35+
}
36+
}
37+
38+
func WithHeader(key, value string) Opt {
39+
return func(o *HTTPOptions) {
40+
o.Header.Set(key, value)
41+
}
42+
}
43+
44+
func WithProxiedBaseURL(value string) Opt {
45+
return func(o *HTTPOptions) {
46+
o.Header.Set("X-Cagent-Forward", value)
47+
48+
// Enforce consistent headers (Anthropic client sets similar header already)
49+
o.Header.Set("X-Cagent-Lang", "go")
50+
o.Header.Set("X-Cagent-OS", getNormalizedOS())
51+
o.Header.Set("X-Cagent-Arch", getNormalizedArchitecture())
52+
o.Header.Set("X-Cagent-Runtime", "cagent")
53+
o.Header.Set("X-Cagent-Runtime-Version", version.Version)
54+
}
55+
}
56+
57+
type userAgentTransport struct {
58+
httpOptions HTTPOptions
59+
rt http.RoundTripper
60+
}
61+
62+
func (u *userAgentTransport) RoundTrip(req *http.Request) (*http.Response, error) {
63+
r2 := req.Clone(req.Context())
64+
maps.Copy(r2.Header, u.httpOptions.Header)
65+
return u.rt.RoundTrip(r2)
66+
}
67+
68+
func getNormalizedOS() string {
69+
switch runtime.GOOS {
70+
case "ios":
71+
return "iOS"
72+
case "android":
73+
return "Android"
74+
case "darwin":
75+
return "MacOS"
76+
case "window":
77+
return "Windows"
78+
case "freebsd":
79+
return "FreeBSD"
80+
case "openbsd":
81+
return "OpenBSD"
82+
case "linux":
83+
return "Linux"
84+
default:
85+
return fmt.Sprintf("Other:%s", runtime.GOOS)
86+
}
87+
}
88+
89+
func getNormalizedArchitecture() string {
90+
switch runtime.GOARCH {
91+
case "386":
92+
return "x32"
93+
case "amd64":
94+
return "x64"
95+
case "arm":
96+
return "arm"
97+
case "arm64":
98+
return "arm64"
99+
default:
100+
return fmt.Sprintf("other:%s", runtime.GOARCH)
101+
}
102+
}

pkg/model/provider/anthropic/client.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import (
1414
"github.com/docker/cagent/pkg/chat"
1515
latest "github.com/docker/cagent/pkg/config/v2"
1616
"github.com/docker/cagent/pkg/environment"
17+
"github.com/docker/cagent/pkg/httpclient"
1718
"github.com/docker/cagent/pkg/model/provider/base"
1819
"github.com/docker/cagent/pkg/model/provider/options"
1920
"github.com/docker/cagent/pkg/tools"
@@ -81,6 +82,7 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, env environment.Pro
8182
slog.Debug("Anthropic API key found, creating client")
8283
requestOptions := []option.RequestOption{
8384
option.WithAPIKey(authToken),
85+
option.WithHTTPClient(httpclient.NewHTTPClient()),
8486
}
8587
if cfg.BaseURL != "" {
8688
requestOptions = append(requestOptions, option.WithBaseURL(cfg.BaseURL))
@@ -108,6 +110,9 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, env environment.Pro
108110
option.WithAuthToken(authToken),
109111
option.WithAPIKey(authToken),
110112
option.WithBaseURL(gateway),
113+
option.WithHTTPClient(httpclient.NewHTTPClient(
114+
httpclient.WithProxiedBaseURL(defaultsTo(cfg.BaseURL, "https://api.anthropic.com/")),
115+
)),
111116
), nil
112117
}
113118
}
@@ -654,3 +659,10 @@ func countAnthropicTokens(
654659
}
655660
return result.InputTokens, nil
656661
}
662+
663+
func defaultsTo(value, defaultValue string) string {
664+
if value != "" {
665+
return value
666+
}
667+
return defaultValue
668+
}

pkg/model/provider/gemini/client.go

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import (
1414
"github.com/docker/cagent/pkg/chat"
1515
latest "github.com/docker/cagent/pkg/config/v2"
1616
"github.com/docker/cagent/pkg/environment"
17+
"github.com/docker/cagent/pkg/httpclient"
1718
"github.com/docker/cagent/pkg/model/provider/base"
1819
"github.com/docker/cagent/pkg/model/provider/options"
1920
"github.com/docker/cagent/pkg/tools"
@@ -49,8 +50,12 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, env environment.Pro
4950
}
5051

5152
client, err := genai.NewClient(ctx, &genai.ClientConfig{
52-
APIKey: apiKey,
53-
Backend: genai.BackendGeminiAPI,
53+
APIKey: apiKey,
54+
Backend: genai.BackendGeminiAPI,
55+
HTTPClient: httpclient.NewHTTPClient(),
56+
HTTPOptions: genai.HTTPOptions{
57+
BaseURL: cfg.BaseURL,
58+
},
5459
})
5560
if err != nil {
5661
return nil, err
@@ -77,6 +82,9 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, env environment.Pro
7782
return genai.NewClient(ctx, &genai.ClientConfig{
7883
APIKey: authToken,
7984
Backend: genai.BackendGeminiAPI,
85+
HTTPClient: httpclient.NewHTTPClient(
86+
httpclient.WithProxiedBaseURL(defaultsTo(cfg.BaseURL, "https://generativelanguage.googleapis.com/")),
87+
),
8088
HTTPOptions: genai.HTTPOptions{
8189
BaseURL: gateway,
8290
Headers: http.Header{
@@ -340,3 +348,10 @@ func (c *Client) CreateChatCompletionStream(
340348
iter := client.Models.GenerateContentStream(ctx, c.ModelConfig.Model, contents, config)
341349
return NewStreamAdapter(iter, c.ModelConfig.Model), nil
342350
}
351+
352+
func defaultsTo(value, defaultValue string) string {
353+
if value != "" {
354+
return value
355+
}
356+
return defaultValue
357+
}

pkg/model/provider/openai/client.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import (
1414
"github.com/docker/cagent/pkg/chat"
1515
latest "github.com/docker/cagent/pkg/config/v2"
1616
"github.com/docker/cagent/pkg/environment"
17+
"github.com/docker/cagent/pkg/httpclient"
1718
"github.com/docker/cagent/pkg/model/provider/base"
1819
"github.com/docker/cagent/pkg/model/provider/options"
1920
"github.com/docker/cagent/pkg/tools"
@@ -71,6 +72,8 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, env environment.Pro
7172
openaiConfig.BaseURL = cfg.BaseURL
7273
}
7374

75+
openaiConfig.HTTPClient = httpclient.NewHTTPClient()
76+
7477
// TODO: Move this logic to ProviderAliases as a config function
7578
if cfg.ProviderOpts != nil {
7679
switch cfg.Provider { //nolint:gocritic
@@ -106,6 +109,9 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, env environment.Pro
106109

107110
openaiConfig := openai.DefaultConfig(authToken)
108111
openaiConfig.BaseURL = gateway + "/v1"
112+
openaiConfig.HTTPClient = httpclient.NewHTTPClient(
113+
httpclient.WithProxiedBaseURL(defaultsTo(cfg.BaseURL, "https://api.openai.com/v1")),
114+
)
109115

110116
return openai.NewClientWithConfig(openaiConfig), nil
111117
}
@@ -389,3 +395,10 @@ type jsonSchema map[string]any
389395
func (j jsonSchema) MarshalJSON() ([]byte, error) {
390396
return json.Marshal(map[string]any(j))
391397
}
398+
399+
func defaultsTo(value, defaultValue string) string {
400+
if value != "" {
401+
return value
402+
}
403+
return defaultValue
404+
}

0 commit comments

Comments
 (0)