Skip to content

Commit 025dcb3

Browse files
authored
Merge pull request #1751 from cogvel/feat/model-name-header
feat: add `X-Cagent-Model-Name` header to models gateway requests
2 parents b86ea81 + 108e1de commit 025dcb3

8 files changed

Lines changed: 116 additions & 0 deletions

File tree

pkg/config/latest/types.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,9 @@ func (a *AgentConfig) GetFallbackCooldown() time.Duration {
259259

260260
// ModelConfig represents the configuration for a model
261261
type ModelConfig struct {
262+
// Name is the manifest model name (map key), populated at runtime.
263+
// Not serialized — set by teamloader/model_switcher when resolving models.
264+
Name string `json:"-"`
262265
Provider string `json:"provider,omitempty"`
263266
Model string `json:"model,omitempty"`
264267
Temperature *float64 `json:"temperature,omitempty"`

pkg/httpclient/client.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,14 @@ func WithModel(model string) Opt {
7676
}
7777
}
7878

79+
func WithModelName(name string) Opt {
80+
return func(o *HTTPOptions) {
81+
if name != "" {
82+
o.Header.Set("X-Cagent-Model-Name", name)
83+
}
84+
}
85+
}
86+
7987
func WithQuery(query url.Values) Opt {
8088
return func(o *HTTPOptions) {
8189
o.Query = query

pkg/httpclient/client_test.go

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
package httpclient
2+
3+
import (
4+
"net/http"
5+
"net/http/httptest"
6+
"testing"
7+
8+
"github.com/stretchr/testify/assert"
9+
"github.com/stretchr/testify/require"
10+
)
11+
12+
func TestWithModelName(t *testing.T) {
13+
t.Parallel()
14+
15+
tests := []struct {
16+
name string
17+
modelName string
18+
wantSet bool
19+
}{
20+
{
21+
name: "sets header when name is provided",
22+
modelName: "my-fast-model",
23+
wantSet: true,
24+
},
25+
{
26+
name: "skips header when name is empty",
27+
modelName: "",
28+
wantSet: false,
29+
},
30+
}
31+
32+
for _, tt := range tests {
33+
t.Run(tt.name, func(t *testing.T) {
34+
t.Parallel()
35+
36+
var capturedHeaders http.Header
37+
srv := httptest.NewServer(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
38+
capturedHeaders = r.Header
39+
}))
40+
defer srv.Close()
41+
42+
client := NewHTTPClient(WithModelName(tt.modelName))
43+
req, err := http.NewRequest(http.MethodGet, srv.URL, http.NoBody)
44+
require.NoError(t, err)
45+
46+
resp, err := client.Do(req)
47+
require.NoError(t, err)
48+
defer func() { _ = resp.Body.Close() }()
49+
50+
if tt.wantSet {
51+
assert.Equal(t, tt.modelName, capturedHeaders.Get("X-Cagent-Model-Name"))
52+
} else {
53+
assert.Empty(t, capturedHeaders.Get("X-Cagent-Model-Name"))
54+
}
55+
})
56+
}
57+
}
58+
59+
func TestWithModel(t *testing.T) {
60+
t.Parallel()
61+
62+
var capturedHeaders http.Header
63+
srv := httptest.NewServer(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
64+
capturedHeaders = r.Header
65+
}))
66+
defer srv.Close()
67+
68+
client := NewHTTPClient(WithModel("gpt-4o"))
69+
req, err := http.NewRequest(http.MethodGet, srv.URL, http.NoBody)
70+
require.NoError(t, err)
71+
72+
resp, err := client.Do(req)
73+
require.NoError(t, err)
74+
defer func() { _ = resp.Body.Close() }()
75+
76+
assert.Equal(t, "gpt-4o", capturedHeaders.Get("X-Cagent-Model"))
77+
}
78+
79+
func TestWithProvider(t *testing.T) {
80+
t.Parallel()
81+
82+
var capturedHeaders http.Header
83+
srv := httptest.NewServer(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
84+
capturedHeaders = r.Header
85+
}))
86+
defer srv.Close()
87+
88+
client := NewHTTPClient(WithProvider("openai"))
89+
req, err := http.NewRequest(http.MethodGet, srv.URL, http.NoBody)
90+
require.NoError(t, err)
91+
92+
resp, err := client.Do(req)
93+
require.NoError(t, err)
94+
defer func() { _ = resp.Body.Close() }()
95+
96+
assert.Equal(t, "openai", capturedHeaders.Get("X-Cagent-Provider"))
97+
}

pkg/model/provider/anthropic/client.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,7 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, env environment.Pro
184184
httpclient.WithProxiedBaseURL(cmp.Or(cfg.BaseURL, "https://api.anthropic.com/")),
185185
httpclient.WithProvider(cfg.Provider),
186186
httpclient.WithModel(cfg.Model),
187+
httpclient.WithModelName(cfg.Name),
187188
httpclient.WithQuery(url.Query()),
188189
}
189190
if globalOptions.GeneratingTitle() {

pkg/model/provider/gemini/client.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,7 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, env environment.Pro
130130
httpclient.WithProxiedBaseURL(cmp.Or(cfg.BaseURL, "https://generativelanguage.googleapis.com/")),
131131
httpclient.WithProvider(cfg.Provider),
132132
httpclient.WithModel(cfg.Model),
133+
httpclient.WithModelName(cfg.Name),
133134
httpclient.WithQuery(url.Query()),
134135
}
135136
if globalOptions.GeneratingTitle() {

pkg/model/provider/openai/client.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,7 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, env environment.Pro
118118
httpclient.WithProxiedBaseURL(cmp.Or(cfg.BaseURL, "https://api.openai.com/v1")),
119119
httpclient.WithProvider(cfg.Provider),
120120
httpclient.WithModel(cfg.Model),
121+
httpclient.WithModelName(cfg.Name),
121122
httpclient.WithQuery(url.Query()),
122123
}
123124
if globalOptions.GeneratingTitle() {

pkg/runtime/model_switcher.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ func (r *LocalRuntime) SetAgentModel(ctx context.Context, agentName, modelRef st
8484

8585
// Check if modelRef is a named model from config
8686
if modelConfig, exists := r.modelSwitcherCfg.Models[modelRef]; exists {
87+
modelConfig.Name = modelRef
8788
// Check if this is an alloy model (no provider, comma-separated models)
8889
if isAlloyModelConfig(modelConfig) {
8990
providers, err := r.createProvidersFromAlloyConfig(ctx, modelConfig)
@@ -175,6 +176,7 @@ func (r *LocalRuntime) createProvidersFromInlineAlloy(ctx context.Context, model
175176

176177
// Check if this part exists as a named model in config
177178
if modelCfg, exists := r.modelSwitcherCfg.Models[part]; exists {
179+
modelCfg.Name = part
178180
prov, err := r.createProviderFromConfig(ctx, &modelCfg)
179181
if err != nil {
180182
return nil, fmt.Errorf("failed to create provider for %q: %w", part, err)
@@ -219,6 +221,7 @@ func (r *LocalRuntime) createProvidersFromAlloyConfig(ctx context.Context, alloy
219221

220222
// Check if this model reference exists in the config
221223
if modelCfg, exists := r.modelSwitcherCfg.Models[modelRef]; exists {
224+
modelCfg.Name = modelRef
222225
prov, err := r.createProviderFromConfig(ctx, &modelCfg)
223226
if err != nil {
224227
return nil, fmt.Errorf("failed to create provider for %q: %w", modelRef, err)

pkg/teamloader/teamloader.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -301,6 +301,7 @@ func getModelsForAgent(ctx context.Context, cfg *latest.Config, a *latest.AgentC
301301
return nil, false, fmt.Errorf("model '%s' not found in configuration", name)
302302
}
303303
}
304+
modelCfg.Name = name
304305

305306
// Check if thinking_budget was explicitly configured BEFORE provider defaults are applied.
306307
// This is used to initialize session thinking state - thinking is only enabled by default
@@ -371,6 +372,7 @@ func getFallbackModelsForAgent(ctx context.Context, cfg *latest.Config, a *lates
371372
Model: modelName,
372373
}
373374
}
375+
modelCfg.Name = name
374376

375377
// Use max_tokens from config if specified, otherwise look up from models.dev
376378
maxTokens := &defaultMaxTokens

0 commit comments

Comments
 (0)