Skip to content

Commit 72a6047

Browse files
authored
Merge pull request #437 from Deepam02/feature/model-override-flag
Feature/model override flag
2 parents 7d68941 + 955fd19 commit 72a6047

6 files changed

Lines changed: 275 additions & 1 deletion

File tree

cmd/root/exec.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ func NewExecCmd() *cobra.Command {
1414
cmd.PersistentFlags().StringVar(&workingDir, "working-dir", "", "Set the working directory for the session (applies to tools and relative paths)")
1515
cmd.PersistentFlags().BoolVar(&autoApprove, "yolo", false, "Automatically approve all tool calls without prompting")
1616
cmd.PersistentFlags().StringVar(&attachmentPath, "attach", "", "Attach an image file to the message")
17+
cmd.PersistentFlags().StringArrayVar(&modelOverrides, "model", nil, "Override agent model: [agent=]provider/model (repeatable)")
1718
cmd.PersistentFlags().BoolVar(&dryRun, "dry-run", false, "Initialize the agent without executing anything")
1819
_ = cmd.PersistentFlags().MarkHidden("dry-run")
1920

cmd/root/run.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ var (
4343
remoteAddress string
4444
dryRun bool
4545
commandName string
46+
modelOverrides []string
4647
)
4748

4849
const commandListSentinel = "__LIST__"
@@ -68,6 +69,7 @@ func NewRunCmd() *cobra.Command {
6869
cmd.PersistentFlags().BoolVar(&useTUI, "tui", true, "Run the agent with a Terminal User Interface (TUI)")
6970
cmd.PersistentFlags().StringVar(&remoteAddress, "remote", "", "Use remote runtime with specified address (only supported with TUI)")
7071
cmd.PersistentFlags().StringVarP(&commandName, "command", "c", "", "Run a named command from the agent's commands section")
72+
cmd.PersistentFlags().StringArrayVar(&modelOverrides, "model", nil, "Override agent model: [agent=]provider/model (repeatable)")
7173
if f := cmd.PersistentFlags().Lookup("command"); f != nil {
7274
// Allow `-c` without value to list available commands
7375
f.NoOptDefVal = commandListSentinel
@@ -194,7 +196,7 @@ func doRunCommand(ctx context.Context, args []string, exec bool) error {
194196
runConfig.RedirectURI = "http://localhost:8083/oauth-callback"
195197
}
196198

197-
agents, err = teamloader.Load(ctx, agentFilename, runConfig)
199+
agents, err = teamloader.LoadWithOverrides(ctx, agentFilename, runConfig, modelOverrides)
198200
if err != nil {
199201
return err
200202
}

docs/USAGE.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,11 @@ $ cagent run config.yaml --yolo # Auto-accept all the tool calls
4343
$ cagent run config.yaml "First message" # Start the conversation with the agent with a first message
4444
$ cagent run config.yaml -c df # Run with a named command from YAML
4545

46+
# Model Override Examples
47+
$ cagent run config.yaml --model anthropic/claude-sonnet-4-0 # Override all agents to use Claude
48+
$ cagent run config.yaml --model "agent1=openai/gpt-4o" # Override specific agent
49+
$ cagent run config.yaml --model "agent1=openai/gpt-4o,agent2=anthropic/claude-sonnet-4-0" # Multiple overrides
50+
4651
# One off without TUI
4752
$ cagent exec config.yaml # Run the agent once, with default instructions
4853
$ cagent exec config.yaml "First message" # Run the agent once with instructions

pkg/config/overrides.go

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
package config
2+
3+
import (
4+
"fmt"
5+
"strings"
6+
7+
v2 "github.com/docker/cagent/pkg/config/v2"
8+
)
9+
10+
// ApplyModelOverrides applies CLI model overrides to the configuration
11+
func ApplyModelOverrides(cfg *v2.Config, overrides []string) error {
12+
for _, override := range overrides {
13+
if err := applySingleOverride(cfg, override); err != nil {
14+
return err
15+
}
16+
}
17+
18+
// After applying overrides, ensure new models are added to cfg.Models
19+
return ensureModelsExist(cfg)
20+
}
21+
22+
// applySingleOverride processes a single model override string
23+
func applySingleOverride(cfg *v2.Config, override string) error {
24+
override = strings.TrimSpace(override)
25+
if override == "" {
26+
return nil // Skip empty overrides
27+
}
28+
29+
// Handle comma-separated format: "agent1=model1,agent2=model2"
30+
if strings.Contains(override, ",") {
31+
for part := range strings.SplitSeq(override, ",") {
32+
if err := applySingleOverride(cfg, part); err != nil {
33+
return err
34+
}
35+
}
36+
return nil
37+
}
38+
39+
// Check if this is an agent-specific override (contains '=')
40+
agentName, modelSpec, ok := strings.Cut(override, "=")
41+
if ok {
42+
agentName = strings.TrimSpace(agentName)
43+
if agentName == "" {
44+
return fmt.Errorf("empty agent name in override: %s", override)
45+
}
46+
47+
modelSpec = strings.TrimSpace(modelSpec)
48+
if modelSpec == "" {
49+
return fmt.Errorf("empty model specification in override: %s", override)
50+
}
51+
52+
// Apply to specific agent
53+
agentConfig, exists := cfg.Agents[agentName]
54+
if !exists {
55+
return fmt.Errorf("unknown agent '%s'", agentName)
56+
}
57+
58+
agentConfig.Model = modelSpec
59+
cfg.Agents[agentName] = agentConfig
60+
} else {
61+
// Global override: apply to all agents
62+
modelSpec := strings.TrimSpace(override)
63+
if modelSpec == "" {
64+
return fmt.Errorf("empty model specification")
65+
}
66+
67+
for name := range cfg.Agents {
68+
agentConfig := cfg.Agents[name]
69+
agentConfig.Model = modelSpec
70+
cfg.Agents[name] = agentConfig
71+
}
72+
}
73+
74+
return nil
75+
}
76+
77+
// ensureModelsExist ensures that all models referenced by agents exist in cfg.Models
78+
// This handles inline model specs that may have been added via CLI overrides
79+
func ensureModelsExist(cfg *v2.Config) error {
80+
if cfg.Models == nil {
81+
cfg.Models = map[string]v2.ModelConfig{}
82+
}
83+
84+
for agentName := range cfg.Agents {
85+
agentConfig := cfg.Agents[agentName]
86+
87+
modelNames := strings.SplitSeq(agentConfig.Model, ",")
88+
for modelName := range modelNames {
89+
if _, exists := cfg.Models[modelName]; exists {
90+
continue
91+
}
92+
93+
providerName, model, ok := strings.Cut(modelName, "/")
94+
if !ok {
95+
return fmt.Errorf("agent '%s' references non-existent model '%s'", agentName, modelName)
96+
}
97+
98+
cfg.Models[modelName] = v2.ModelConfig{
99+
Provider: providerName,
100+
Model: model,
101+
}
102+
}
103+
}
104+
105+
return nil
106+
}

pkg/teamloader/teamloader.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,10 @@ func checkRequiredEnvVars(ctx context.Context, cfg *latest.Config, env environme
9797
}
9898

9999
func Load(ctx context.Context, path string, runtimeConfig config.RuntimeConfig) (*team.Team, error) {
100+
return LoadWithOverrides(ctx, path, runtimeConfig, nil)
101+
}
102+
103+
func LoadWithOverrides(ctx context.Context, path string, runtimeConfig config.RuntimeConfig, modelOverrides []string) (*team.Team, error) {
100104
fileName := filepath.Base(path)
101105
parentDir := filepath.Dir(path)
102106

@@ -123,6 +127,11 @@ func Load(ctx context.Context, path string, runtimeConfig config.RuntimeConfig)
123127
return nil, err
124128
}
125129

130+
// Apply model overrides from CLI flags before checking required env vars
131+
if err := config.ApplyModelOverrides(cfg, modelOverrides); err != nil {
132+
return nil, err
133+
}
134+
126135
// Early check for required env vars before loading models and tools.
127136
if err := checkRequiredEnvVars(ctx, cfg, env, runtimeConfig); err != nil {
128137
return nil, err

pkg/teamloader/teamloader_test.go

Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import (
1111
"github.com/stretchr/testify/require"
1212

1313
"github.com/docker/cagent/pkg/config"
14+
latest "github.com/docker/cagent/pkg/config/v2"
1415
"github.com/docker/cagent/pkg/environment"
1516
)
1617

@@ -160,3 +161,153 @@ func collectExamples(t *testing.T) []string {
160161

161162
return files
162163
}
164+
165+
func TestApplyModelOverrides(t *testing.T) {
166+
tests := []struct {
167+
name string
168+
agents map[string]latest.AgentConfig
169+
overrides []string
170+
expected map[string]string // agent name -> expected model
171+
expectError bool
172+
errorMsg string
173+
}{
174+
{
175+
name: "global override",
176+
agents: map[string]latest.AgentConfig{
177+
"root": {Model: "openai/gpt-4"},
178+
"other": {Model: "anthropic/claude-3"},
179+
},
180+
overrides: []string{"google/gemini-pro"},
181+
expected: map[string]string{
182+
"root": "google/gemini-pro",
183+
"other": "google/gemini-pro",
184+
},
185+
},
186+
{
187+
name: "single per-agent override",
188+
agents: map[string]latest.AgentConfig{
189+
"root": {Model: "openai/gpt-4"},
190+
"other": {Model: "anthropic/claude-3"},
191+
},
192+
overrides: []string{"other=google/gemini-pro"},
193+
expected: map[string]string{
194+
"root": "openai/gpt-4",
195+
"other": "google/gemini-pro",
196+
},
197+
},
198+
{
199+
name: "multiple separate flags",
200+
agents: map[string]latest.AgentConfig{
201+
"root": {Model: "openai/gpt-4"},
202+
"other": {Model: "anthropic/claude-3"},
203+
},
204+
overrides: []string{"root=openai/gpt-5", "other=anthropic/claude-sonnet-4-0"},
205+
expected: map[string]string{
206+
"root": "openai/gpt-5",
207+
"other": "anthropic/claude-sonnet-4-0",
208+
},
209+
},
210+
{
211+
name: "comma-separated format",
212+
agents: map[string]latest.AgentConfig{
213+
"root": {Model: "openai/gpt-4"},
214+
"other": {Model: "anthropic/claude-3"},
215+
"third": {Model: "google/gemini-pro"},
216+
},
217+
overrides: []string{"root=openai/gpt-5,other=anthropic/claude-sonnet-4-0"},
218+
expected: map[string]string{
219+
"root": "openai/gpt-5",
220+
"other": "anthropic/claude-sonnet-4-0",
221+
"third": "google/gemini-pro",
222+
},
223+
},
224+
{
225+
name: "mixed formats",
226+
agents: map[string]latest.AgentConfig{
227+
"root": {Model: "openai/gpt-4"},
228+
"other": {Model: "anthropic/claude-3"},
229+
"third": {Model: "google/gemini-pro"},
230+
"reviewer": {Model: "openai/gpt-3.5-turbo"},
231+
},
232+
overrides: []string{"root=openai/gpt-5,other=anthropic/claude-4", "reviewer=google/gemini-1.5-pro"},
233+
expected: map[string]string{
234+
"root": "openai/gpt-5",
235+
"other": "anthropic/claude-4",
236+
"third": "google/gemini-pro",
237+
"reviewer": "google/gemini-1.5-pro",
238+
},
239+
},
240+
{
241+
name: "last override wins",
242+
agents: map[string]latest.AgentConfig{
243+
"root": {Model: "openai/gpt-4"},
244+
},
245+
overrides: []string{"root=openai/gpt-5", "root=anthropic/claude-4"},
246+
expected: map[string]string{
247+
"root": "anthropic/claude-4",
248+
},
249+
},
250+
{
251+
name: "unknown agent error",
252+
agents: map[string]latest.AgentConfig{
253+
"root": {Model: "openai/gpt-4"},
254+
},
255+
overrides: []string{"nonexistent=openai/gpt-5"},
256+
expectError: true,
257+
errorMsg: "unknown agent 'nonexistent'",
258+
},
259+
{
260+
name: "empty model spec error",
261+
agents: map[string]latest.AgentConfig{
262+
"root": {Model: "openai/gpt-4"},
263+
},
264+
overrides: []string{"root="},
265+
expectError: true,
266+
errorMsg: "empty model specification in override: root=",
267+
},
268+
{
269+
name: "empty global model spec is skipped",
270+
agents: map[string]latest.AgentConfig{
271+
"root": {Model: "openai/gpt-4"},
272+
},
273+
overrides: []string{""},
274+
expected: map[string]string{
275+
"root": "openai/gpt-4",
276+
},
277+
},
278+
{
279+
name: "whitespace handling",
280+
agents: map[string]latest.AgentConfig{
281+
"root": {Model: "openai/gpt-4"},
282+
"other": {Model: "anthropic/claude-3"},
283+
},
284+
overrides: []string{" root = openai/gpt-5 , other = anthropic/claude-4 "},
285+
expected: map[string]string{
286+
"root": "openai/gpt-5",
287+
"other": "anthropic/claude-4",
288+
},
289+
},
290+
}
291+
292+
for _, tt := range tests {
293+
t.Run(tt.name, func(t *testing.T) {
294+
t.Parallel()
295+
296+
cfg := &latest.Config{
297+
Agents: tt.agents,
298+
Models: make(map[string]latest.ModelConfig),
299+
}
300+
301+
err := config.ApplyModelOverrides(cfg, tt.overrides)
302+
303+
if tt.expectError {
304+
require.ErrorContains(t, err, tt.errorMsg)
305+
} else {
306+
require.NoError(t, err)
307+
for agentName, expectedModel := range tt.expected {
308+
assert.Equal(t, expectedModel, cfg.Agents[agentName].Model)
309+
}
310+
}
311+
})
312+
}
313+
}

0 commit comments

Comments
 (0)