
Commit b1a2a7b

config: Make max tokens a pointer
Since it's optional

Signed-off-by: Djordje Lukic <djordje.lukic@docker.com>

1 parent 3e9ccbe · commit b1a2a7b

10 files changed: 48 additions, 51 deletions


pkg/config/auto.go (4 additions, 3 deletions)

```diff
@@ -50,11 +50,12 @@ func AutoModelConfig(ctx context.Context, modelsGateway string, env environment.
 	}
 }
 
-func PreferredMaxTokens(provider string) int {
+func PreferredMaxTokens(provider string) *int64 {
+	var mt int64 = 32000
 	if provider == "dmr" {
-		return 16000
+		mt = 16000
 	}
-	return 64000
+	return &mt
 }
 
 // AutoEmbeddingModelConfigs returns the ordered list of embedding-capable models
```
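Note that besides the signature change, the non-DMR default drops from 64000 to 32000 here, matching the `defaultMaxTokens` constant introduced in `pkg/model/provider/options/options.go` below. Returning `&mt` is safe in Go: escape analysis promotes the local to the heap. A minimal standalone sketch of the new shape (package wiring omitted):

```go
package main

import "fmt"

// Mirrors the new PreferredMaxTokens from pkg/config/auto.go.
func preferredMaxTokens(provider string) *int64 {
	var mt int64 = 32000
	if provider == "dmr" {
		mt = 16000
	}
	// Taking the address of a local is safe: escape analysis moves mt
	// to the heap, so every call returns an independent pointer.
	return &mt
}

func main() {
	fmt.Println(*preferredMaxTokens("dmr"))       // 16000
	fmt.Println(*preferredMaxTokens("anthropic")) // 32000
}
```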

pkg/config/latest/types.go (1 addition, 1 deletion)

```diff
@@ -46,7 +46,7 @@ type ModelConfig struct {
 	Provider         string   `json:"provider,omitempty"`
 	Model            string   `json:"model,omitempty"`
 	Temperature      *float64 `json:"temperature,omitempty"`
-	MaxTokens        int      `json:"max_tokens,omitempty"`
+	MaxTokens        *int64   `json:"max_tokens,omitempty"`
 	TopP             *float64 `json:"top_p,omitempty"`
 	FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"`
 	PresencePenalty  *float64 `json:"presence_penalty,omitempty"`
```
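The pointer type is what makes `max_tokens` genuinely optional, which is the point of the commit: with a plain `int` plus `omitempty`, an explicit `0` and "unset" serialize identically, and an absent key unmarshals to the same value as an explicit `0`. A minimal sketch (hypothetical struct names) of the difference:

```go
package main

import (
	"encoding/json"
	"fmt"
)

// Hypothetical before/after mirrors of the ModelConfig field.
type intField struct {
	MaxTokens int `json:"max_tokens,omitempty"`
}

type ptrField struct {
	MaxTokens *int64 `json:"max_tokens,omitempty"`
}

func main() {
	// Plain int: the zero value is swallowed by omitempty.
	b, _ := json.Marshal(intField{MaxTokens: 0})
	fmt.Println(string(b)) // {}

	// Pointer: nil means "unset" and is omitted...
	p, _ := json.Marshal(ptrField{})
	fmt.Println(string(p)) // {}

	// ...while a pointer to 0 is a deliberate value and survives.
	zero := int64(0)
	p, _ = json.Marshal(ptrField{MaxTokens: &zero})
	fmt.Println(string(p)) // {"max_tokens":0}
}
```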

pkg/model/provider/anthropic/beta_client.go (1 addition, 8 deletions)

```diff
@@ -339,14 +339,7 @@ func (c *Client) Rerank(ctx context.Context, query string, documents []types.Doc
 		"additionalProperties": false,
 	}
 
-	// Use max_tokens from model config if specified, otherwise use a reasonable
-	// default (8192) that works for most reranking scenarios. Anthropic requires
-	// max_tokens to be set explicitly (unlike OpenAI which can rely on defaults).
-	maxTokens := int64(8192)
-	if c.ModelConfig.MaxTokens > 0 {
-		maxTokens = int64(c.ModelConfig.MaxTokens)
-	}
-
+	maxTokens := c.ModelOptions.MaxTokens()
 	params := anthropic.BetaMessageNewParams{
 		Model:     anthropic.Model(c.ModelConfig.Model),
 		MaxTokens: maxTokens,
```

pkg/model/provider/anthropic/client.go (2 additions, 6 deletions)

```diff
@@ -41,7 +41,7 @@ func (c *Client) adjustMaxTokensForThinking(maxTokens int64) (int64, error) {
 	minRequired := thinkingTokens + 1024 // configured thinking budget + minimum output buffer
 
 	if maxTokens <= thinkingTokens {
-		userSetMaxTokens := c.ModelConfig.MaxTokens > 0
+		userSetMaxTokens := c.ModelConfig.MaxTokens != nil
 		if userSetMaxTokens {
 			// User explicitly set max_tokens too low - return error
 			slog.Error("Anthropic: max_tokens must be greater than thinking_budget",
@@ -193,11 +193,7 @@ func (c *Client) CreateChatCompletionStream(
 		"message_count", len(messages),
 		"tool_count", len(requestTools))
 
-	maxTokens := int64(c.ModelConfig.MaxTokens)
-	if maxTokens == 0 {
-		maxTokens = 8192 // Default output budget when not specified
-	}
-
+	maxTokens := c.ModelOptions.MaxTokens()
 	maxTokens, err := c.adjustMaxTokensForThinking(maxTokens)
 	if err != nil {
 		return nil, err
```
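With the accessor always returning a concrete budget, `c.ModelConfig.MaxTokens != nil` becomes the only record of whether the user actually set a value, which is what the thinking-budget check needs to choose between erroring and adjusting. A minimal sketch of that check; the `userSet` flag stands in for the nil check, and the silent-adjust branch is an assumption, since the hunk only shows the error path:

```go
package main

import (
	"errors"
	"fmt"
)

// Sketch of adjustMaxTokensForThinking's budget check.
func adjustMaxTokens(maxTokens, thinkingTokens int64, userSet bool) (int64, error) {
	minRequired := thinkingTokens + 1024 // thinking budget + minimum output buffer
	if maxTokens <= thinkingTokens {
		if userSet {
			// User explicitly set max_tokens too low: fail loudly.
			return 0, errors.New("max_tokens must be greater than thinking_budget")
		}
		return minRequired, nil // assumed: raise the default budget instead
	}
	return maxTokens, nil
}

func main() {
	mt, _ := adjustMaxTokens(32000, 4096, false)
	fmt.Println(mt) // 32000

	_, err := adjustMaxTokens(2048, 4096, true)
	fmt.Println(err) // max_tokens must be greater than thinking_budget
}
```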

pkg/model/provider/clone.go (2 additions, 3 deletions)

```diff
@@ -23,9 +23,8 @@ func CloneWithOptions(ctx context.Context, base Provider, opts ...options.Opt) P
 	for _, opt := range mergedOpts {
 		tempOpts := &options.ModelOptions{}
 		opt(tempOpts)
-		if maxTokens := tempOpts.MaxTokens(); maxTokens != nil {
-			modelConfig.MaxTokens = *maxTokens
-		}
+		mt := tempOpts.MaxTokens()
+		modelConfig.MaxTokens = &mt
 	}
 
 	clone, err := New(ctx, &modelConfig, config.Env, mergedOpts...)
```
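Two things are worth noting here. First, `&mt` points at a fresh local on every iteration, so a later option never mutates a value an earlier iteration stored. Second, because `MaxTokens()` now always returns a number, every applied option stamps the config with a concrete value (the default if unset), whereas before the field was only written when explicitly set. A minimal sketch of the per-iteration pointer behavior:

```go
package main

import "fmt"

func main() {
	// Each iteration declares its own mt, so &mt never aliases a
	// previous iteration's pointer (the pattern used in CloneWithOptions).
	values := []int64{16000, 32000}
	var ptrs []*int64
	for _, v := range values {
		mt := v
		ptrs = append(ptrs, &mt)
	}
	fmt.Println(*ptrs[0], *ptrs[1]) // 16000 32000
}
```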

pkg/model/provider/dmr/client.go (9 additions, 9 deletions)

```diff
@@ -536,9 +536,9 @@ func (c *Client) CreateChatCompletionStream(ctx context.Context, messages []chat
 		params.ParallelToolCalls = openai.Bool(*c.ModelConfig.ParallelToolCalls)
 	}
 
-	if c.ModelConfig.MaxTokens > 0 {
-		params.MaxTokens = openai.Int(int64(c.ModelConfig.MaxTokens))
-		slog.Debug("DMR request configured with max tokens", "max_tokens", c.ModelConfig.MaxTokens)
+	if c.ModelConfig.MaxTokens != nil {
+		params.MaxTokens = openai.Int(*c.ModelConfig.MaxTokens)
+		slog.Debug("DMR request configured with max tokens", "max_tokens", *c.ModelConfig.MaxTokens)
 	}
 
 	if len(requestTools) > 0 {
@@ -982,9 +982,9 @@ type speculativeDecodingOpts struct {
 	acceptanceRate float64
 }
 
-func parseDMRProviderOpts(cfg *latest.ModelConfig) (contextSize int, runtimeFlags []string, specOpts *speculativeDecodingOpts) {
+func parseDMRProviderOpts(cfg *latest.ModelConfig) (contextSize *int64, runtimeFlags []string, specOpts *speculativeDecodingOpts) {
 	if cfg == nil {
-		return 0, nil, nil
+		return nil, nil, nil
 	}
 
 	// Context length is now sourced from the standard max_tokens field
@@ -1129,7 +1129,7 @@ func modelExists(ctx context.Context, model string) bool {
 	return true
 }
 
-func configureDockerModel(ctx context.Context, model string, contextSize int, runtimeFlags []string, specOpts *speculativeDecodingOpts) error {
+func configureDockerModel(ctx context.Context, model string, contextSize *int64, runtimeFlags []string, specOpts *speculativeDecodingOpts) error {
 	args := buildDockerModelConfigureArgs(model, contextSize, runtimeFlags, specOpts)
 
 	cmd := exec.CommandContext(ctx, "docker", args...)
@@ -1146,10 +1146,10 @@ func configureDockerModel(ctx context.Context, model string, contextSize int, ru
 
 // buildDockerModelConfigureArgs returns the argument vector passed to `docker` for model configuration.
 // It formats context size, speculative decoding options, and runtime flags consistently with the CLI contract.
-func buildDockerModelConfigureArgs(model string, contextSize int, runtimeFlags []string, specOpts *speculativeDecodingOpts) []string {
+func buildDockerModelConfigureArgs(model string, contextSize *int64, runtimeFlags []string, specOpts *speculativeDecodingOpts) []string {
 	args := []string{"model", "configure"}
-	if contextSize > 0 {
-		args = append(args, "--context-size="+strconv.Itoa(contextSize))
+	if contextSize != nil {
+		args = append(args, "--context-size="+strconv.FormatInt(*contextSize, 10))
 	}
 	if specOpts != nil {
 		if specOpts.draftModel != "" {
```
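`strconv.Itoa` only accepts `int`, hence the switch to `strconv.FormatInt` once the context size is an `*int64`. A standalone sketch of the argument building (speculative-decoding options omitted), matching the expectations in the tests below:

```go
package main

import (
	"fmt"
	"strconv"
)

// Simplified buildDockerModelConfigureArgs: a nil contextSize means
// "don't pass --context-size at all".
func buildArgs(model string, contextSize *int64, runtimeFlags []string) []string {
	args := []string{"model", "configure"}
	if contextSize != nil {
		args = append(args, "--context-size="+strconv.FormatInt(*contextSize, 10))
	}
	args = append(args, model)
	if len(runtimeFlags) > 0 {
		args = append(args, "--")
		args = append(args, runtimeFlags...)
	}
	return args
}

func main() {
	cs := int64(8192)
	fmt.Println(buildArgs("ai/qwen3:14B-Q6_K", &cs, []string{"--temp", "0.7"}))
	// [model configure --context-size=8192 ai/qwen3:14B-Q6_K -- --temp 0.7]

	fmt.Println(buildArgs("ai/qwen3:14B-Q6_K", nil, nil))
	// [model configure ai/qwen3:14B-Q6_K]
}
```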

pkg/model/provider/dmr/client_test.go (10 additions, 6 deletions)

```diff
@@ -32,7 +32,7 @@ func TestNewClientWithWrongType(t *testing.T) {
 }
 
 func TestBuildDockerConfigureArgs(t *testing.T) {
-	args := buildDockerModelConfigureArgs("ai/qwen3:14B-Q6_K", 8192, []string{"--temp", "0.7", "--top-p", "0.9"}, nil)
+	args := buildDockerModelConfigureArgs("ai/qwen3:14B-Q6_K", int64Ptr(8192), []string{"--temp", "0.7", "--top-p", "0.9"}, nil)
 
 	assert.Equal(t, []string{"model", "configure", "--context-size=8192", "ai/qwen3:14B-Q6_K", "--", "--temp", "0.7", "--top-p", "0.9"}, args)
 }
@@ -52,7 +52,7 @@ func TestIntegrateFlagsWithProviderOptsOrder(t *testing.T) {
 	cfg := &latest.ModelConfig{
 		Temperature: floatPtr(0.6),
 		TopP:        floatPtr(0.9),
-		MaxTokens:   4096,
+		MaxTokens:   int64Ptr(4096),
 		ProviderOpts: map[string]any{
 			"runtime_flags": []string{"--threads", "6"},
 		},
@@ -84,13 +84,17 @@ func floatPtr(f float64) *float64 {
 	return &f
 }
 
+func int64Ptr(i int64) *int64 {
+	return &i
+}
+
 func TestBuildDockerConfigureArgsWithSpeculativeDecoding(t *testing.T) {
 	specOpts := &speculativeDecodingOpts{
 		draftModel:     "ai/qwen3:1B",
 		numTokens:      5,
 		acceptanceRate: 0.8,
 	}
-	args := buildDockerModelConfigureArgs("ai/qwen3:14B-Q6_K", 8192, []string{"--temp", "0.7"}, specOpts)
+	args := buildDockerModelConfigureArgs("ai/qwen3:14B-Q6_K", int64Ptr(8192), []string{"--temp", "0.7"}, specOpts)
 
 	assert.Equal(t, []string{
 		"model", "configure",
@@ -110,7 +114,7 @@ func TestBuildDockerConfigureArgsWithPartialSpeculativeDecoding(t *testing.T) {
 		numTokens: 5,
 		// acceptanceRate not set (0 value)
 	}
-	args := buildDockerModelConfigureArgs("ai/qwen3:14B-Q6_K", 0, nil, specOpts)
+	args := buildDockerModelConfigureArgs("ai/qwen3:14B-Q6_K", nil, nil, specOpts)
 
 	assert.Equal(t, []string{
 		"model", "configure",
@@ -122,7 +126,7 @@
 
 func TestParseDMRProviderOptsWithSpeculativeDecoding(t *testing.T) {
 	cfg := &latest.ModelConfig{
-		MaxTokens: 4096,
+		MaxTokens: int64Ptr(4096),
 		ProviderOpts: map[string]any{
 			"speculative_draft_model": "ai/qwen3:1B",
 			"speculative_num_tokens":  "5",
@@ -143,7 +147,7 @@ func TestParseDMRProviderOptsWithSpeculativeDecoding(t *testing.T) {
 
 func TestParseDMRProviderOptsWithoutSpeculativeDecoding(t *testing.T) {
 	cfg := &latest.ModelConfig{
-		MaxTokens: 4096,
+		MaxTokens: int64Ptr(4096),
 		ProviderOpts: map[string]any{
 			"runtime_flags": []string{"--threads", "8"},
 		},
```
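The new `int64Ptr` helper sits alongside the existing `floatPtr`. Since Go 1.18 a single generic helper could cover both; a hypothetical consolidation (not part of this commit):

```go
package main

import "fmt"

// ptr could replace type-specific helpers like floatPtr and int64Ptr.
func ptr[T any](v T) *T {
	return &v
}

func main() {
	fmt.Println(*ptr(int64(4096))) // 4096
	fmt.Println(*ptr(0.9))         // 0.9
}
```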

pkg/model/provider/gemini/client.go (2 additions, 2 deletions)

```diff
@@ -287,8 +287,8 @@ func convertMessagesToGemini(messages []chat.Message) []*genai.Content {
 // buildConfig creates GenerateContentConfig from model config
 func (c *Client) buildConfig() *genai.GenerateContentConfig {
 	config := &genai.GenerateContentConfig{}
-	if c.ModelConfig.MaxTokens > 0 {
-		config.MaxOutputTokens = int32(c.ModelConfig.MaxTokens)
+	if c.ModelConfig.MaxTokens != nil {
+		config.MaxOutputTokens = int32(*c.ModelConfig.MaxTokens)
 	}
 	if c.ModelConfig.Temperature != nil {
 		config.Temperature = genai.Ptr(float32(*c.ModelConfig.Temperature))
```
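One caveat with the direct narrowing now that the source is an `int64`: `MaxOutputTokens` is `int32`, so a configured value above `math.MaxInt32` would wrap. A hypothetical clamp (not in this commit; the diff converts directly) would look like:

```go
package main

import (
	"fmt"
	"math"
)

// Hypothetical overflow-safe narrowing for an int32 output-token field.
func toMaxOutputTokens(maxTokens *int64) int32 {
	if maxTokens == nil {
		return 0 // leave the field unset; the API default applies
	}
	if *maxTokens > math.MaxInt32 {
		return math.MaxInt32
	}
	return int32(*maxTokens)
}

func main() {
	mt := int64(32000)
	fmt.Println(toMaxOutputTokens(&mt)) // 32000

	big := int64(math.MaxInt32) + 1
	fmt.Println(toMaxOutputTokens(&big)) // 2147483647
}
```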

pkg/model/provider/openai/client.go (6 additions, 6 deletions)

```diff
@@ -314,12 +314,12 @@ func (c *Client) CreateChatCompletionStream(
 		params.PresencePenalty = openai.Float(*c.ModelConfig.PresencePenalty)
 	}
 
-	if maxToken := c.ModelConfig.MaxTokens; maxToken > 0 {
+	if maxToken := c.ModelConfig.MaxTokens; maxToken != nil {
 		if !isResponsesOnlyModel(c.ModelConfig.Model) {
-			params.MaxTokens = openai.Int(int64(maxToken))
-			slog.Debug("OpenAI request configured with max tokens", "max_tokens", maxToken, "model", c.ModelConfig.Model)
+			params.MaxTokens = openai.Int(*maxToken)
+			slog.Debug("OpenAI request configured with max tokens", "max_tokens", *maxToken, "model", c.ModelConfig.Model)
 		} else {
-			params.MaxCompletionTokens = openai.Int(int64(maxToken))
+			params.MaxCompletionTokens = openai.Int(*maxToken)
 			slog.Debug("using max_completion_tokens instead of max_tokens for Responses-API models", "model", c.ModelConfig.Model)
 		}
 	}
@@ -428,8 +428,8 @@ func (c *Client) CreateResponseStream(
 		params.TopP = param.NewOpt(*c.ModelConfig.TopP)
 	}
 
-	if maxToken := c.ModelConfig.MaxTokens; maxToken > 0 {
-		params.MaxOutputTokens = param.NewOpt(int64(maxToken))
+	if maxToken := c.ModelConfig.MaxTokens; maxToken != nil {
+		params.MaxOutputTokens = param.NewOpt(*maxToken)
 		slog.Debug("OpenAI responses request configured with max output tokens", "max_output_tokens", maxToken)
 	}
 
```

pkg/model/provider/options/options.go (11 additions, 7 deletions)

```diff
@@ -4,11 +4,13 @@ import (
 	"github.com/docker/cagent/pkg/config/latest"
 )
 
+const defaultMaxTokens = 32000
+
 type ModelOptions struct {
 	gateway          string
 	structuredOutput *latest.StructuredOutput
 	generatingTitle  bool
-	maxTokens        *int
+	maxTokens        int64
 }
 
 func (c *ModelOptions) Gateway() string {
@@ -23,7 +25,10 @@ func (c *ModelOptions) GeneratingTitle() bool {
 	return c.generatingTitle
 }
 
-func (c *ModelOptions) MaxTokens() *int {
+func (c *ModelOptions) MaxTokens() int64 {
+	if c.maxTokens == 0 {
+		return defaultMaxTokens
+	}
 	return c.maxTokens
 }
 
@@ -47,9 +52,9 @@ func WithGeneratingTitle() Opt {
 	}
 }
 
-func WithMaxTokens(maxTokens int) Opt {
+func WithMaxTokens(maxTokens int64) Opt {
 	return func(cfg *ModelOptions) {
-		cfg.maxTokens = &maxTokens
+		cfg.maxTokens = maxTokens
 	}
 }
 
@@ -66,8 +71,7 @@ func FromModelOptions(m ModelOptions) []Opt {
 	if m.generatingTitle {
 		out = append(out, WithGeneratingTitle())
 	}
-	if m.maxTokens != nil {
-		out = append(out, WithMaxTokens(*m.maxTokens))
-	}
+	out = append(out, WithMaxTokens(m.maxTokens))
 
 	return out
 }
```
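`ModelOptions` now stores a plain `int64` and treats zero as the "unset" sentinel, with the accessor supplying the 32000 default. The trade-off: an explicit `WithMaxTokens(0)` is indistinguishable from never calling it. A minimal sketch of the accessor semantics:

```go
package main

import "fmt"

const defaultMaxTokens = 32000

// Mirrors the reworked ModelOptions accessor.
type modelOptions struct {
	maxTokens int64
}

func (c *modelOptions) MaxTokens() int64 {
	if c.maxTokens == 0 {
		return defaultMaxTokens
	}
	return c.maxTokens
}

func main() {
	var unset modelOptions
	fmt.Println(unset.MaxTokens()) // 32000 (default)

	set := modelOptions{maxTokens: 16000}
	fmt.Println(set.MaxTokens()) // 16000

	// Caveat: an explicit zero also falls back to the default.
	zero := modelOptions{maxTokens: 0}
	fmt.Println(zero.MaxTokens()) // 32000
}
```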
