Skip to content

Commit f335c41

Browse files
committed
Update the Anthropic default model to Claude Sonnet 4.5, and allow users to define their own default model in their global config
Signed-off-by: Christopher Petito <chrisjpetito@gmail.com>
1 parent cdd5b3b commit f335c41

10 files changed

Lines changed: 445 additions & 13 deletions

File tree

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ agents:
161161
models:
162162
claude:
163163
provider: anthropic
164-
model: claude-sonnet-4-0
164+
model: claude-sonnet-4-5
165165
max_tokens: 64000
166166
```
167167

@@ -425,7 +425,7 @@ these three providers in order based on the first api key it finds in your
425425
environment.
426426

427427
```sh
428-
export ANTHROPIC_API_KEY=your_api_key_here # first choice. default model claude-sonnet-4-0
428+
export ANTHROPIC_API_KEY=your_api_key_here # first choice. default model claude-sonnet-4-5
429429
export OPENAI_API_KEY=your_api_key_here # if anthropic key not set. default model gpt-5-mini
430430
export GOOGLE_API_KEY=your_api_key_here # if anthropic and openai keys are not set. default model gemini-2.5-flash
431431
```

cmd/root/flags.go

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,14 @@ import (
1010
"github.com/spf13/cobra"
1111

1212
"github.com/docker/cagent/pkg/config"
13+
"github.com/docker/cagent/pkg/config/latest"
1314
"github.com/docker/cagent/pkg/userconfig"
1415
)
1516

1617
const (
1718
flagModelsGateway = "models-gateway"
1819
envModelsGateway = "CAGENT_MODELS_GATEWAY"
20+
envDefaultModel = "CAGENT_DEFAULT_MODEL"
1921
)
2022

2123
func addRuntimeConfigFlags(cmd *cobra.Command, runConfig *config.RuntimeConfig) {
@@ -63,17 +65,29 @@ func addGatewayFlags(cmd *cobra.Command, runConfig *config.RuntimeConfig) {
6365

6466
persistentPreRunE := cmd.PersistentPreRunE
6567
cmd.PersistentPreRunE = func(_ *cobra.Command, args []string) error {
68+
userCfg, err := loadUserConfig()
69+
if err != nil {
70+
slog.Warn("Failed to load user config", "error", err)
71+
userCfg = &userconfig.Config{}
72+
}
73+
6674
// Precedence: CLI flag > environment variable > user config
6775
if runConfig.ModelsGateway == "" {
6876
if gateway := os.Getenv(envModelsGateway); gateway != "" {
6977
runConfig.ModelsGateway = gateway
70-
} else if userCfg, err := loadUserConfig(); err == nil && userCfg.ModelsGateway != "" {
78+
} else if userCfg.ModelsGateway != "" {
7179
runConfig.ModelsGateway = userCfg.ModelsGateway
7280
}
7381
}
74-
7582
runConfig.ModelsGateway = canonize(runConfig.ModelsGateway)
7683

84+
// Precedence for default model: environment variable > user config
85+
if model := os.Getenv(envDefaultModel); model != "" {
86+
runConfig.DefaultModel = parseModelShorthand(model)
87+
} else if userCfg.DefaultModel != nil {
88+
runConfig.DefaultModel = &userCfg.DefaultModel.ModelConfig
89+
}
90+
7791
if err := setupWorkingDirectory(runConfig.WorkingDir); err != nil {
7892
return err
7993
}
@@ -88,3 +102,14 @@ func addGatewayFlags(cmd *cobra.Command, runConfig *config.RuntimeConfig) {
88102
return nil
89103
}
90104
}
105+
106+
// parseModelShorthand parses "provider/model" into a ModelConfig
107+
func parseModelShorthand(s string) *latest.ModelConfig {
108+
if idx := strings.Index(s, "/"); idx > 0 && idx < len(s)-1 {
109+
return &latest.ModelConfig{
110+
Provider: s[:idx],
111+
Model: s[idx+1:],
112+
}
113+
}
114+
return nil
115+
}

cmd/root/flags_test.go

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"github.com/stretchr/testify/require"
99

1010
"github.com/docker/cagent/pkg/config"
11+
"github.com/docker/cagent/pkg/config/latest"
1112
"github.com/docker/cagent/pkg/userconfig"
1213
)
1314

@@ -152,3 +153,82 @@ func TestCanonize(t *testing.T) {
152153
})
153154
}
154155
}
156+
157+
func TestDefaultModelLogic(t *testing.T) {
158+
tests := []struct {
159+
name string
160+
env string
161+
userConfig *userconfig.Config
162+
expectedProvider string
163+
expectedModel string
164+
}{
165+
{
166+
name: "env",
167+
env: "openai/gpt-4o",
168+
expectedProvider: "openai",
169+
expectedModel: "gpt-4o",
170+
},
171+
{
172+
name: "user_config",
173+
userConfig: &userconfig.Config{
174+
DefaultModel: &latest.FlexibleModelConfig{
175+
ModelConfig: latest.ModelConfig{Provider: "google", Model: "gemini-2.5-flash"},
176+
},
177+
},
178+
expectedProvider: "google",
179+
expectedModel: "gemini-2.5-flash",
180+
},
181+
{
182+
name: "env_overrides_user_config",
183+
env: "openai/gpt-4o",
184+
userConfig: &userconfig.Config{
185+
DefaultModel: &latest.FlexibleModelConfig{
186+
ModelConfig: latest.ModelConfig{Provider: "google", Model: "gemini-2.5-flash"},
187+
},
188+
},
189+
expectedProvider: "openai",
190+
expectedModel: "gpt-4o",
191+
},
192+
{
193+
name: "empty_when_not_set",
194+
expectedProvider: "",
195+
expectedModel: "",
196+
},
197+
}
198+
199+
for _, tt := range tests {
200+
t.Run(tt.name, func(t *testing.T) {
201+
t.Setenv("CAGENT_DEFAULT_MODEL", tt.env)
202+
203+
// Mock user config loader
204+
original := loadUserConfig
205+
loadUserConfig = func() (*userconfig.Config, error) {
206+
if tt.userConfig != nil {
207+
return tt.userConfig, nil
208+
}
209+
return &userconfig.Config{}, nil
210+
}
211+
t.Cleanup(func() { loadUserConfig = original })
212+
213+
cmd := &cobra.Command{
214+
RunE: func(*cobra.Command, []string) error {
215+
return nil
216+
},
217+
}
218+
runConfig := config.RuntimeConfig{}
219+
addGatewayFlags(cmd, &runConfig)
220+
221+
cmd.SetArgs(nil)
222+
err := cmd.Execute()
223+
224+
require.NoError(t, err)
225+
if tt.expectedProvider == "" && tt.expectedModel == "" {
226+
assert.Nil(t, runConfig.DefaultModel)
227+
} else {
228+
require.NotNil(t, runConfig.DefaultModel)
229+
assert.Equal(t, tt.expectedProvider, runConfig.DefaultModel.Provider)
230+
assert.Equal(t, tt.expectedModel, runConfig.DefaultModel.Model)
231+
}
232+
})
233+
}
234+
}

pkg/config/auto.go

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ To fix this, you can:
5252

5353
var DefaultModels = map[string]string{
5454
"openai": "gpt-5-mini",
55-
"anthropic": "claude-sonnet-4-0",
55+
"anthropic": "claude-sonnet-4-5",
5656
"google": "gemini-2.5-flash",
5757
"dmr": "ai/qwen3:latest",
5858
"mistral": "mistral-small-latest",
@@ -82,7 +82,16 @@ func AvailableProviders(ctx context.Context, modelsGateway string, env environme
8282
return providers
8383
}
8484

85-
func AutoModelConfig(ctx context.Context, modelsGateway string, env environment.Provider) latest.ModelConfig {
85+
func AutoModelConfig(ctx context.Context, modelsGateway string, env environment.Provider, defaultModel *latest.ModelConfig) latest.ModelConfig {
86+
// If user specified a default model config, use it (with defaults for unset fields)
87+
if defaultModel != nil && defaultModel.Provider != "" && defaultModel.Model != "" {
88+
result := *defaultModel
89+
if result.MaxTokens == nil {
90+
result.MaxTokens = PreferredMaxTokens(result.Provider)
91+
}
92+
return result
93+
}
94+
8695
availableProviders := AvailableProviders(ctx, modelsGateway, env)
8796
firstAvailable := availableProviders[0]
8897

pkg/config/auto_test.go

Lines changed: 105 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ import (
55
"testing"
66

77
"github.com/stretchr/testify/assert"
8+
9+
"github.com/docker/cagent/pkg/config/latest"
810
)
911

1012
type mockEnvProvider struct {
@@ -175,7 +177,7 @@ func TestAutoModelConfig(t *testing.T) {
175177
"ANTHROPIC_API_KEY": "test-key",
176178
},
177179
expectedProvider: "anthropic",
178-
expectedModel: "claude-sonnet-4-0",
180+
expectedModel: "claude-sonnet-4-5",
179181
expectedMaxTokens: 32000,
180182
},
181183
{
@@ -217,7 +219,7 @@ func TestAutoModelConfig(t *testing.T) {
217219
envVars: map[string]string{},
218220
gateway: "gateway:8080",
219221
expectedProvider: "anthropic",
220-
expectedModel: "claude-sonnet-4-0",
222+
expectedModel: "claude-sonnet-4-5",
221223
expectedMaxTokens: 32000,
222224
},
223225
}
@@ -226,7 +228,7 @@ func TestAutoModelConfig(t *testing.T) {
226228
t.Run(tt.name, func(t *testing.T) {
227229
t.Parallel()
228230

229-
modelConfig := AutoModelConfig(t.Context(), tt.gateway, &mockEnvProvider{envVars: tt.envVars})
231+
modelConfig := AutoModelConfig(t.Context(), tt.gateway, &mockEnvProvider{envVars: tt.envVars}, nil)
230232

231233
assert.Equal(t, tt.expectedProvider, modelConfig.Provider)
232234
assert.Equal(t, tt.expectedModel, modelConfig.Model)
@@ -295,7 +297,7 @@ func TestDefaultModels(t *testing.T) {
295297

296298
// Test specific model values
297299
assert.Equal(t, "gpt-5-mini", DefaultModels["openai"])
298-
assert.Equal(t, "claude-sonnet-4-0", DefaultModels["anthropic"])
300+
assert.Equal(t, "claude-sonnet-4-5", DefaultModels["anthropic"])
299301
assert.Equal(t, "gemini-2.5-flash", DefaultModels["google"])
300302
assert.Equal(t, "ai/qwen3:latest", DefaultModels["dmr"])
301303
assert.Equal(t, "mistral-small-latest", DefaultModels["mistral"])
@@ -326,7 +328,7 @@ func TestAutoModelConfig_IntegrationWithDefaultModels(t *testing.T) {
326328
envVars["MISTRAL_API_KEY"] = "test-key"
327329
}
328330

329-
modelConfig := AutoModelConfig(t.Context(), "", &mockEnvProvider{envVars: envVars})
331+
modelConfig := AutoModelConfig(t.Context(), "", &mockEnvProvider{envVars: envVars}, nil)
330332

331333
// Verify the returned model matches the DefaultModels entry
332334
expectedModel := DefaultModels[provider]
@@ -339,7 +341,7 @@ func TestAutoModelConfig_IntegrationWithDefaultModels(t *testing.T) {
339341
t.Run("dmr", func(t *testing.T) {
340342
t.Parallel()
341343

342-
modelConfig := AutoModelConfig(t.Context(), "", &mockEnvProvider{envVars: map[string]string{}})
344+
modelConfig := AutoModelConfig(t.Context(), "", &mockEnvProvider{envVars: map[string]string{}}, nil)
343345

344346
assert.Equal(t, "dmr", modelConfig.Provider)
345347
assert.Equal(t, DefaultModels["dmr"], modelConfig.Model)
@@ -399,3 +401,100 @@ func TestAvailableProviders_PrecedenceOrder(t *testing.T) {
399401
providers = AvailableProviders(t.Context(), "", env)
400402
assert.Equal(t, "dmr", providers[0])
401403
}
404+
405+
func TestAutoModelConfig_UserDefaultModel(t *testing.T) {
406+
t.Parallel()
407+
408+
tests := []struct {
409+
name string
410+
defaultModel *latest.ModelConfig
411+
envVars map[string]string
412+
expectedProvider string
413+
expectedModel string
414+
expectedMaxTokens int64
415+
}{
416+
{
417+
name: "user default model overrides auto detection",
418+
defaultModel: &latest.ModelConfig{Provider: "openai", Model: "gpt-4o"},
419+
envVars: map[string]string{"ANTHROPIC_API_KEY": "test-key"},
420+
expectedProvider: "openai",
421+
expectedModel: "gpt-4o",
422+
expectedMaxTokens: 32000,
423+
},
424+
{
425+
name: "user default model with dmr provider",
426+
defaultModel: &latest.ModelConfig{Provider: "dmr", Model: "ai/llama3.2"},
427+
envVars: map[string]string{"OPENAI_API_KEY": "test-key"},
428+
expectedProvider: "dmr",
429+
expectedModel: "ai/llama3.2",
430+
expectedMaxTokens: 16000,
431+
},
432+
{
433+
name: "user default model with anthropic provider",
434+
defaultModel: &latest.ModelConfig{Provider: "anthropic", Model: "claude-sonnet-4-0"},
435+
envVars: map[string]string{},
436+
expectedProvider: "anthropic",
437+
expectedModel: "claude-sonnet-4-0",
438+
expectedMaxTokens: 32000,
439+
},
440+
{
441+
name: "nil default model falls back to auto detection",
442+
defaultModel: nil,
443+
envVars: map[string]string{"GOOGLE_API_KEY": "test-key"},
444+
expectedProvider: "google",
445+
expectedModel: "gemini-2.5-flash",
446+
expectedMaxTokens: 32000,
447+
},
448+
{
449+
name: "empty provider falls back to auto detection",
450+
defaultModel: &latest.ModelConfig{Provider: "", Model: "model-only"},
451+
envVars: map[string]string{"MISTRAL_API_KEY": "test-key"},
452+
expectedProvider: "mistral",
453+
expectedModel: "mistral-small-latest",
454+
expectedMaxTokens: 32000,
455+
},
456+
{
457+
name: "empty model falls back to auto detection",
458+
defaultModel: &latest.ModelConfig{Provider: "openai", Model: ""},
459+
envVars: map[string]string{"ANTHROPIC_API_KEY": "test-key"},
460+
expectedProvider: "anthropic",
461+
expectedModel: "claude-sonnet-4-5",
462+
expectedMaxTokens: 32000,
463+
},
464+
}
465+
466+
for _, tt := range tests {
467+
t.Run(tt.name, func(t *testing.T) {
468+
t.Parallel()
469+
470+
modelConfig := AutoModelConfig(t.Context(), "", &mockEnvProvider{envVars: tt.envVars}, tt.defaultModel)
471+
472+
assert.Equal(t, tt.expectedProvider, modelConfig.Provider)
473+
assert.Equal(t, tt.expectedModel, modelConfig.Model)
474+
assert.Equal(t, tt.expectedMaxTokens, *modelConfig.MaxTokens)
475+
})
476+
}
477+
}
478+
479+
func TestAutoModelConfig_UserDefaultModelWithOptions(t *testing.T) {
480+
t.Parallel()
481+
482+
// Test that user-provided options like max_tokens, thinking_budget are preserved
483+
customMaxTokens := int64(64000)
484+
thinkingBudget := &latest.ThinkingBudget{Tokens: 10000}
485+
486+
defaultModel := &latest.ModelConfig{
487+
Provider: "anthropic",
488+
Model: "claude-sonnet-4-5",
489+
MaxTokens: &customMaxTokens,
490+
ThinkingBudget: thinkingBudget,
491+
}
492+
493+
modelConfig := AutoModelConfig(t.Context(), "", &mockEnvProvider{envVars: map[string]string{}}, defaultModel)
494+
495+
assert.Equal(t, "anthropic", modelConfig.Provider)
496+
assert.Equal(t, "claude-sonnet-4-5", modelConfig.Model)
497+
assert.Equal(t, int64(64000), *modelConfig.MaxTokens)
498+
assert.NotNil(t, modelConfig.ThinkingBudget)
499+
assert.Equal(t, 10000, modelConfig.ThinkingBudget.Tokens)
500+
}

0 commit comments

Comments (0)