Skip to content

Commit ac79960

Browse files
authored
Merge pull request #679 from dgageot/fix-676
Fix model cloning
2 parents f29bc29 + 6f8510d commit ac79960

10 files changed

Lines changed: 65 additions & 83 deletions

File tree

pkg/model/provider/anthropic/client.go

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ type Client struct {
3131
// models:provider_opts:interleaved_thinking: true
3232
func (c *Client) interleavedThinkingEnabled() bool {
3333
// Default to false if not provided
34-
if c == nil || c.ModelConfig == nil || len(c.ModelConfig.ProviderOpts) == 0 {
34+
if c == nil || len(c.ModelConfig.ProviderOpts) == 0 {
3535
return false
3636
}
3737
v, ok := c.ModelConfig.ProviderOpts["interleaved_thinking"]
@@ -121,14 +121,15 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, env environment.Pro
121121

122122
slog.Debug("Anthropic client created successfully", "model", cfg.Model)
123123

124-
if globalOptions.StructuredOutput != nil {
125-
return &Client{}, errors.New("anthropic does not support native structured_output")
124+
if globalOptions.StructuredOutput() != nil {
125+
return nil, errors.New("anthropic does not support native structured_output")
126126
}
127127

128128
return &Client{
129129
Config: base.Config{
130-
ModelConfig: cfg,
130+
ModelConfig: *cfg,
131131
ModelOptions: globalOptions,
132+
Env: env,
132133
},
133134
clientFn: clientFn,
134135
}, nil

pkg/model/provider/base/base.go

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,30 +2,23 @@ package base
22

33
import (
44
latest "github.com/docker/cagent/pkg/config/v2"
5+
"github.com/docker/cagent/pkg/environment"
56
"github.com/docker/cagent/pkg/model/provider/options"
67
)
78

89
// Config is a common base configuration shared by all provider clients.
910
// It can be embedded in provider-specific Client structs to avoid code duplication.
1011
type Config struct {
11-
ModelConfig *latest.ModelConfig
12+
ModelConfig latest.ModelConfig
1213
ModelOptions options.ModelOptions
14+
Env environment.Provider
1315
}
1416

1517
// ID returns the provider and model ID in the format "provider/model"
1618
func (c *Config) ID() string {
1719
return c.ModelConfig.Provider + "/" + c.ModelConfig.Model
1820
}
1921

20-
// MaxTokens returns the maximum tokens configured for this provider's model
21-
func (c *Config) MaxTokens() int {
22-
if c.ModelConfig == nil {
23-
return 0
24-
}
25-
return c.ModelConfig.MaxTokens
26-
}
27-
28-
// Options returns the effective model options used by this provider's model
29-
func (c *Config) Options() options.ModelOptions {
30-
return c.ModelOptions
22+
func (c *Config) BaseConfig() Config {
23+
return *c
3124
}

pkg/model/provider/clone.go

Lines changed: 7 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -3,40 +3,25 @@ package provider
33
import (
44
"context"
55
"log/slog"
6-
"strings"
76

8-
latest "github.com/docker/cagent/pkg/config/v2"
9-
"github.com/docker/cagent/pkg/environment"
107
"github.com/docker/cagent/pkg/model/provider/options"
118
)
129

1310
// CloneWithOptions returns a new Provider instance using the same provider/model
1411
// as the base provider, applying the provided options. If cloning fails, the
1512
// original base provider is returned.
16-
func CloneWithOptions(ctx context.Context, base Provider, env environment.Provider, opts ...options.Opt) Provider {
17-
if base == nil {
18-
return nil
19-
}
20-
21-
id := strings.TrimSpace(base.ID())
22-
parts := strings.SplitN(id, "/", 2)
23-
if len(parts) != 2 {
24-
return base
25-
}
26-
27-
cfg := &latest.ModelConfig{Provider: parts[0], Model: parts[1]}
28-
if env == nil {
29-
env = environment.NewDefaultProvider()
30-
}
13+
func CloneWithOptions(ctx context.Context, base Provider, opts ...options.Opt) Provider {
14+
config := base.BaseConfig()
3115

3216
// Preserve existing options, then apply overrides. Later opts take precedence.
33-
baseOpts := options.FromModelOptions(base.Options())
17+
baseOpts := options.FromModelOptions(config.ModelOptions)
3418
mergedOpts := append(baseOpts, opts...)
3519

36-
cloned, err := New(ctx, cfg, env, mergedOpts...)
20+
clone, err := New(ctx, &config.ModelConfig, config.Env, mergedOpts...)
3721
if err != nil {
38-
slog.Debug("Failed to clone provider; using base provider", "error", err, "id", id)
22+
slog.Debug("Failed to clone provider; using base provider", "error", err, "id", base.ID())
3923
return base
4024
}
41-
return cloned
25+
26+
return clone
4227
}

pkg/model/provider/dmr/client.go

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, opts ...options.Opt
107107

108108
return &Client{
109109
Config: base.Config{
110-
ModelConfig: cfg,
110+
ModelConfig: *cfg,
111111
ModelOptions: globalOptions,
112112
},
113113
client: openai.NewClientWithConfig(clientConfig),
@@ -384,15 +384,16 @@ func (c *Client) CreateChatCompletionStream(ctx context.Context, messages []chat
384384
} else {
385385
slog.Error("Failed to marshal DMR request to JSON", "error", err)
386386
}
387-
if c.ModelOptions.StructuredOutput != nil {
388-
slog.Debug("Adding structured output to DMR request", "structured_output", c.ModelOptions.StructuredOutput)
387+
if structuredOutput := c.ModelOptions.StructuredOutput(); structuredOutput != nil {
388+
slog.Debug("Adding structured output to DMR request", "structured_output", structuredOutput)
389+
389390
request.ResponseFormat = &openai.ChatCompletionResponseFormat{
390391
Type: openai.ChatCompletionResponseFormatTypeJSONSchema,
391392
JSONSchema: &openai.ChatCompletionResponseFormatJSONSchema{
392-
Name: c.ModelOptions.StructuredOutput.Name,
393-
Description: c.ModelOptions.StructuredOutput.Description,
394-
Schema: jsonSchema(c.ModelOptions.StructuredOutput.Schema),
395-
Strict: c.ModelOptions.StructuredOutput.Strict,
393+
Name: structuredOutput.Name,
394+
Description: structuredOutput.Description,
395+
Schema: jsonSchema(structuredOutput.Schema),
396+
Strict: structuredOutput.Strict,
396397
},
397398
}
398399
}

pkg/model/provider/gemini/client.go

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -37,13 +37,13 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, env environment.Pro
3737
return nil, errors.New("model type must be 'google'")
3838
}
3939

40-
var modelOptions options.ModelOptions
40+
var globalOptions options.ModelOptions
4141
for _, opt := range opts {
42-
opt(&modelOptions)
42+
opt(&globalOptions)
4343
}
4444

4545
var clientFn func(context.Context) (*genai.Client, error)
46-
if gateway := modelOptions.Gateway(); gateway == "" {
46+
if gateway := globalOptions.Gateway(); gateway == "" {
4747
apiKey := env.Get(ctx, "GOOGLE_API_KEY")
4848
if apiKey == "" {
4949
return nil, errors.New("GOOGLE_API_KEY environment variable is required")
@@ -101,8 +101,9 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, env environment.Pro
101101

102102
return &Client{
103103
Config: base.Config{
104-
ModelConfig: cfg,
105-
ModelOptions: modelOptions,
104+
ModelConfig: *cfg,
105+
ModelOptions: globalOptions,
106+
Env: env,
106107
},
107108
clientFn: clientFn,
108109
}, nil
@@ -213,10 +214,6 @@ func convertMessagesToGemini(messages []chat.Message) []*genai.Content {
213214

214215
// buildConfig creates GenerateContentConfig from model config
215216
func (c *Client) buildConfig() *genai.GenerateContentConfig {
216-
if c.ModelConfig == nil {
217-
return nil
218-
}
219-
220217
config := &genai.GenerateContentConfig{
221218
Temperature: genai.Ptr(float32(c.ModelConfig.Temperature)),
222219
TopP: genai.Ptr(float32(c.ModelConfig.TopP)),
@@ -251,9 +248,9 @@ func (c *Client) buildConfig() *genai.GenerateContentConfig {
251248
}
252249
}
253250

254-
if c.ModelOptions.StructuredOutput != nil {
251+
if structuredOutput := c.ModelOptions.StructuredOutput(); structuredOutput != nil {
255252
config.ResponseMIMEType = "application/json"
256-
config.ResponseJsonSchema = c.ModelOptions.StructuredOutput.Schema
253+
config.ResponseJsonSchema = structuredOutput.Schema
257254
}
258255

259256
return config

pkg/model/provider/openai/client.go

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -123,8 +123,9 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, env environment.Pro
123123

124124
return &Client{
125125
Config: base.Config{
126-
ModelConfig: cfg,
126+
ModelConfig: *cfg,
127127
ModelOptions: globalOptions,
128+
Env: env,
128129
},
129130
clientFn: clientFn,
130131
}, nil
@@ -235,12 +236,12 @@ func (c *Client) CreateChatCompletionStream(
235236
},
236237
}
237238

238-
if c.MaxTokens() > 0 {
239+
if maxToken := c.ModelConfig.MaxTokens; maxToken > 0 {
239240
if !isResponsesOnlyModel(c.ModelConfig.Model) {
240-
request.MaxTokens = c.MaxTokens()
241-
slog.Debug("OpenAI request configured with max tokens", "max_tokens", c.MaxTokens())
241+
request.MaxTokens = maxToken
242+
slog.Debug("OpenAI request configured with max tokens", "max_tokens", maxToken, "model", c.ModelConfig.Model)
242243
} else {
243-
request.MaxCompletionTokens = c.MaxTokens()
244+
request.MaxCompletionTokens = maxToken
244245
slog.Debug("using max_completion_tokens instead of max_tokens for Responses-API models", "model", c.ModelConfig.Model)
245246
}
246247
}
@@ -273,7 +274,7 @@ func (c *Client) CreateChatCompletionStream(
273274

274275
// Apply thinking budget: set reasoning_effort parameter
275276
if c.ModelConfig.ThinkingBudget != nil {
276-
effort, err := getOpenAIReasoningEffort(c.ModelConfig)
277+
effort, err := getOpenAIReasoningEffort(&c.ModelConfig)
277278
if err != nil {
278279
slog.Error("OpenAI request using thinking_budget failed", "error", err)
279280
return nil, err
@@ -283,17 +284,18 @@ func (c *Client) CreateChatCompletionStream(
283284
}
284285

285286
// Apply structured output configuration
286-
if c.ModelOptions.StructuredOutput != nil {
287+
if structuredOutput := c.ModelOptions.StructuredOutput(); structuredOutput != nil {
288+
slog.Debug("OpenAI request using structured output", "name", structuredOutput.Name, "strict", structuredOutput.Strict)
289+
287290
request.ResponseFormat = &openai.ChatCompletionResponseFormat{
288291
Type: openai.ChatCompletionResponseFormatTypeJSONSchema,
289292
JSONSchema: &openai.ChatCompletionResponseFormatJSONSchema{
290-
Name: c.ModelOptions.StructuredOutput.Name,
291-
Description: c.ModelOptions.StructuredOutput.Description,
292-
Schema: jsonSchema(c.ModelOptions.StructuredOutput.Schema),
293-
Strict: c.ModelOptions.StructuredOutput.Strict,
293+
Name: structuredOutput.Name,
294+
Description: structuredOutput.Description,
295+
Schema: jsonSchema(structuredOutput.Schema),
296+
Strict: structuredOutput.Strict,
294297
},
295298
}
296-
slog.Debug("OpenAI request using structured output", "name", c.ModelOptions.StructuredOutput.Name, "strict", c.ModelOptions.StructuredOutput.Strict)
297299
}
298300

299301
// Log the request in JSON format for debugging

pkg/model/provider/options/options.go

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,17 @@ import (
66

77
type ModelOptions struct {
88
gateway string
9-
StructuredOutput *latest.StructuredOutput
9+
structuredOutput *latest.StructuredOutput
1010
}
1111

1212
func (c *ModelOptions) Gateway() string {
1313
return c.gateway
1414
}
1515

16+
func (c *ModelOptions) StructuredOutput() *latest.StructuredOutput {
17+
return c.structuredOutput
18+
}
19+
1620
type Opt func(*ModelOptions)
1721

1822
func WithGateway(gateway string) Opt {
@@ -21,9 +25,9 @@ func WithGateway(gateway string) Opt {
2125
}
2226
}
2327

24-
func WithStructuredOutput(output *latest.StructuredOutput) Opt {
28+
func WithStructuredOutput(structuredOutput *latest.StructuredOutput) Opt {
2529
return func(cfg *ModelOptions) {
26-
cfg.StructuredOutput = output
30+
cfg.structuredOutput = structuredOutput
2731
}
2832
}
2933

@@ -34,8 +38,8 @@ func FromModelOptions(m ModelOptions) []Opt {
3438
if g := m.Gateway(); g != "" {
3539
out = append(out, WithGateway(g))
3640
}
37-
if m.StructuredOutput != nil {
38-
out = append(out, WithStructuredOutput(m.StructuredOutput))
41+
if m.structuredOutput != nil {
42+
out = append(out, WithStructuredOutput(m.structuredOutput))
3943
}
4044
return out
4145
}

pkg/model/provider/provider.go

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
latest "github.com/docker/cagent/pkg/config/v2"
1010
"github.com/docker/cagent/pkg/environment"
1111
"github.com/docker/cagent/pkg/model/provider/anthropic"
12+
"github.com/docker/cagent/pkg/model/provider/base"
1213
"github.com/docker/cagent/pkg/model/provider/dmr"
1314
"github.com/docker/cagent/pkg/model/provider/gemini"
1415
"github.com/docker/cagent/pkg/model/provider/openai"
@@ -52,10 +53,8 @@ type Provider interface {
5253
messages []chat.Message,
5354
tools []tools.Tool,
5455
) (chat.MessageStream, error)
55-
// Options returns the effective model options used by this provider
56-
Options() options.ModelOptions
57-
// MaxTokens returns the maximum tokens configured for this provider
58-
MaxTokens() int
56+
// BaseConfig returns the base configuration of this provider
57+
BaseConfig() base.Config
5958
}
6059

6160
func New(ctx context.Context, cfg *latest.ModelConfig, env environment.Provider, opts ...options.Opt) (Provider, error) {

pkg/runtime/runtime.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1014,7 +1014,7 @@ func (r *LocalRuntime) generateSessionTitle(ctx context.Context, sess *session.S
10141014
systemPrompt := "You are a helpful AI assistant that generates concise, descriptive titles for conversations. You will be given a conversation history and asked to create a title that captures the main topic."
10151015
userPrompt := fmt.Sprintf("Based on the following message a user sent to an AI assistant, generate a short, descriptive title (maximum 50 characters) that captures the main topic or purpose of the conversation. Return ONLY the title text, nothing else.\n\nUser message:%s\n\n", conversationHistory.String())
10161016

1017-
titleModel := provider.CloneWithOptions(ctx, r.CurrentAgent().Model(), nil, options.WithStructuredOutput(nil))
1017+
titleModel := provider.CloneWithOptions(ctx, r.CurrentAgent().Model(), options.WithStructuredOutput(nil))
10181018
newTeam := team.New(
10191019
team.WithID("title-generator"),
10201020
team.WithAgents(agent.New("root", systemPrompt, agent.WithModel(titleModel))),
@@ -1080,7 +1080,7 @@ func (r *LocalRuntime) Summarize(ctx context.Context, sess *session.Session, eve
10801080
// Create a new session for summary generation
10811081
systemPrompt := "You are a helpful AI assistant that creates comprehensive summaries of conversations. You will be given a conversation history and asked to create a concise yet thorough summary that captures the key points, decisions made, and outcomes."
10821082
userPrompt := fmt.Sprintf("Based on the following conversation between a user and an AI assistant, create a comprehensive summary that captures:\n- The main topics discussed\n- Key information exchanged\n- Decisions made or conclusions reached\n- Important outcomes or results\n\nProvide a well-structured summary (2-4 paragraphs) that someone could read to understand what happened in this conversation. Return ONLY the summary text, nothing else.\n\nConversation history:%s\n\nGenerate a summary for this conversation:", conversationHistory.String())
1083-
newModel := provider.CloneWithOptions(ctx, r.CurrentAgent().Model(), nil, options.WithStructuredOutput(nil))
1083+
newModel := provider.CloneWithOptions(ctx, r.CurrentAgent().Model(), options.WithStructuredOutput(nil))
10841084
newTeam := team.New(
10851085
team.WithID("summary-generator"),
10861086
team.WithAgents(agent.New("root", systemPrompt, agent.WithModel(newModel))),

pkg/runtime/runtime_test.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ import (
1313

1414
"github.com/docker/cagent/pkg/agent"
1515
"github.com/docker/cagent/pkg/chat"
16-
"github.com/docker/cagent/pkg/model/provider/options"
16+
"github.com/docker/cagent/pkg/model/provider/base"
1717
"github.com/docker/cagent/pkg/modelsdev"
1818
"github.com/docker/cagent/pkg/session"
1919
"github.com/docker/cagent/pkg/team"
@@ -143,7 +143,7 @@ func (m *mockProvider) CreateChatCompletionStream(context.Context, []chat.Messag
143143
return m.stream, nil
144144
}
145145

146-
func (m *mockProvider) Options() options.ModelOptions { return options.ModelOptions{} }
146+
func (m *mockProvider) BaseConfig() base.Config { return base.Config{} }
147147

148148
func (m *mockProvider) MaxTokens() int { return 0 }
149149

@@ -157,7 +157,7 @@ func (m *mockProviderWithError) CreateChatCompletionStream(context.Context, []ch
157157
return nil, fmt.Errorf("simulated error creating chat completion stream")
158158
}
159159

160-
func (m *mockProviderWithError) Options() options.ModelOptions { return options.ModelOptions{} }
160+
func (m *mockProviderWithError) BaseConfig() base.Config { return base.Config{} }
161161

162162
func (m *mockProviderWithError) MaxTokens() int { return 0 }
163163

@@ -453,7 +453,7 @@ func (p *queueProvider) CreateChatCompletionStream(context.Context, []chat.Messa
453453
return s, nil
454454
}
455455

456-
func (p *queueProvider) Options() options.ModelOptions { return options.ModelOptions{} }
456+
func (p *queueProvider) BaseConfig() base.Config { return base.Config{} }
457457

458458
func (p *queueProvider) MaxTokens() int { return 0 }
459459

0 commit comments

Comments
 (0)