Skip to content

Commit cf92f8d

Browse files
authored
Merge pull request #314 from dgageot/simpler-dmr-code
Simpler DMR code
2 parents 353bf79 + 2c5ded2 commit cf92f8d

2 files changed

Lines changed: 26 additions & 59 deletions

File tree

pkg/model/provider/dmr/client.go

Lines changed: 15 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -212,8 +212,9 @@ func convertMessages(messages []chat.Message) []openai.ChatCompletionMessage {
212212
}
213213

214214
openaiMessage := openai.ChatCompletionMessage{
215-
Role: string(msg.Role),
216-
Name: msg.Name,
215+
Role: string(msg.Role),
216+
Name: msg.Name,
217+
ToolCallID: msg.ToolCallID,
217218
}
218219

219220
if len(msg.MultiContent) == 0 {
@@ -229,22 +230,15 @@ func convertMessages(messages []chat.Message) []openai.ChatCompletionMessage {
229230
}
230231
}
231232

232-
if len(msg.ToolCalls) > 0 {
233-
openaiMessage.ToolCalls = make([]openai.ToolCall, len(msg.ToolCalls))
234-
for j, toolCall := range msg.ToolCalls {
235-
openaiMessage.ToolCalls[j] = openai.ToolCall{
236-
ID: toolCall.ID,
237-
Type: openai.ToolType(toolCall.Type),
238-
Function: openai.FunctionCall{
239-
Name: toolCall.Function.Name,
240-
Arguments: toolCall.Function.Arguments,
241-
},
242-
}
243-
}
244-
}
245-
246-
if msg.ToolCallID != "" {
247-
openaiMessage.ToolCallID = msg.ToolCallID
233+
for _, call := range msg.ToolCalls {
234+
openaiMessage.ToolCalls = append(openaiMessage.ToolCalls, openai.ToolCall{
235+
ID: call.ID,
236+
Type: openai.ToolType(call.Type),
237+
Function: openai.FunctionCall{
238+
Name: call.Function.Name,
239+
Arguments: call.Function.Arguments,
240+
},
241+
})
248242
}
249243

250244
openaiMessages = append(openaiMessages, openaiMessage)
@@ -297,11 +291,7 @@ func convertMessages(messages []chat.Message) []openai.ChatCompletionMessage {
297291

298292
// CreateChatCompletionStream creates a streaming chat completion request
299293
// It returns a stream that can be iterated over to get completion chunks
300-
func (c *Client) CreateChatCompletionStream(
301-
ctx context.Context,
302-
messages []chat.Message,
303-
requestTools []tools.Tool,
304-
) (chat.MessageStream, error) {
294+
func (c *Client) CreateChatCompletionStream(ctx context.Context, messages []chat.Message, requestTools []tools.Tool) (chat.MessageStream, error) {
305295
slog.Debug("Creating DMR chat completion stream",
306296
"model", c.config.Model,
307297
"message_count", len(messages),
@@ -314,10 +304,7 @@ func (c *Client) CreateChatCompletionStream(
314304
return nil, errors.New("at least one message is required")
315305
}
316306

317-
trackUsage := true
318-
if c.config.TrackUsage != nil && *c.config.TrackUsage == false {
319-
trackUsage = false
320-
}
307+
trackUsage := c.config.TrackUsage == nil || *c.config.TrackUsage
321308

322309
request := openai.ChatCompletionRequest{
323310
Model: c.config.Model,
@@ -387,10 +374,7 @@ func (c *Client) CreateChatCompletionStream(
387374
return newStreamAdapter(stream, trackUsage), nil
388375
}
389376

390-
func (c *Client) CreateChatCompletion(
391-
ctx context.Context,
392-
messages []chat.Message,
393-
) (string, error) {
377+
func (c *Client) CreateChatCompletion(ctx context.Context, messages []chat.Message) (string, error) {
394378
slog.Debug("Creating DMR chat completion", "model", c.config.Model, "message_count", len(messages), "base_url", c.baseURL)
395379

396380
request := openai.ChatCompletionRequest{
@@ -412,15 +396,13 @@ func (c *Client) CreateChatCompletion(
412396
// In particular, it avoids shadowing built-in mapping methods like `keys()` by removing a literal "keys"
413397
// field from property schemas if present, and guarantees the outer structure is an object with a properties map.
414398
func sanitizeToolParameters(p tools.FunctionParameters) any {
415-
// Start with a safe container
416399
out := map[string]any{
417400
"type": "object",
418401
"properties": map[string]any{},
419402
}
420403
if p.Type != "" {
421404
out["type"] = p.Type
422405
}
423-
// Copy required if present
424406
if len(p.Required) > 0 {
425407
out["required"] = p.Required
426408
}

pkg/model/provider/dmr/client_test.go

Lines changed: 11 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
package dmr
22

33
import (
4-
"reflect"
54
"testing"
65

7-
latest "github.com/docker/cagent/pkg/config/v2"
86
"github.com/stretchr/testify/assert"
97
"github.com/stretchr/testify/require"
8+
9+
latest "github.com/docker/cagent/pkg/config/v2"
1010
)
1111

1212
func TestNewClientWithExplicitBaseURL(t *testing.T) {
@@ -33,27 +33,19 @@ func TestNewClientWithWrongType(t *testing.T) {
3333

3434
func TestBuildDockerConfigureArgs(t *testing.T) {
3535
args := buildDockerModelConfigureArgs("ai/qwen3:14B-Q6_K", 8192, []string{"--temp", "0.7", "--top-p", "0.9"})
36-
expected := []string{"model", "configure", "--context-size=8192", "ai/qwen3:14B-Q6_K", "--", "--temp", "0.7", "--top-p", "0.9"}
37-
if !reflect.DeepEqual(args, expected) {
38-
t.Fatalf("unexpected args.\nexpected: %#v\nactual: %#v", expected, args)
39-
}
36+
37+
assert.Equal(t, []string{"model", "configure", "--context-size=8192", "ai/qwen3:14B-Q6_K", "--", "--temp", "0.7", "--top-p", "0.9"}, args)
4038
}
4139

4240
func TestBuildRuntimeFlagsFromModelConfig_LlamaCpp(t *testing.T) {
43-
cfg := &latest.ModelConfig{
41+
flags := buildRuntimeFlagsFromModelConfig("llama.cpp", &latest.ModelConfig{
4442
Temperature: 0.6,
4543
TopP: 0.95,
4644
FrequencyPenalty: 0.2,
4745
PresencePenalty: 0.1,
48-
}
46+
})
4947

50-
flags := buildRuntimeFlagsFromModelConfig("llama.cpp", cfg)
51-
52-
// Order matters based on implementation
53-
expected := []string{"--temp", "0.6", "--top-p", "0.95", "--frequency-penalty", "0.2", "--presence-penalty", "0.1"}
54-
if !reflect.DeepEqual(flags, expected) {
55-
t.Fatalf("unexpected runtime flags.\nexpected: %#v\nactual: %#v", expected, flags)
56-
}
48+
assert.Equal(t, []string{"--temp", "0.6", "--top-p", "0.95", "--frequency-penalty", "0.2", "--presence-penalty", "0.1"}, flags)
5749
}
5850

5951
func TestIntegrateFlagsWithProviderOptsOrder(t *testing.T) {
@@ -71,10 +63,7 @@ func TestIntegrateFlagsWithProviderOptsOrder(t *testing.T) {
7163
merged := append(derived, []string{"--threads", "6"}...)
7264

7365
args := buildDockerModelConfigureArgs("ai/qwen3:14B-Q6_K", cfg.MaxTokens, merged)
74-
expected := []string{"model", "configure", "--context-size=4096", "ai/qwen3:14B-Q6_K", "--", "--temp", "0.6", "--top-p", "0.9", "--threads", "6"}
75-
if !reflect.DeepEqual(args, expected) {
76-
t.Fatalf("unexpected configure args.\nexpected: %#v\nactual: %#v", expected, args)
77-
}
66+
assert.Equal(t, []string{"model", "configure", "--context-size=4096", "ai/qwen3:14B-Q6_K", "--", "--temp", "0.6", "--top-p", "0.9", "--threads", "6"}, args)
7867
}
7968

8069
func TestMergeRuntimeFlagsPreferUser_WarnsAndPrefersUser(t *testing.T) {
@@ -85,12 +74,8 @@ func TestMergeRuntimeFlagsPreferUser_WarnsAndPrefersUser(t *testing.T) {
8574
merged, warnings := mergeRuntimeFlagsPreferUser(derived, user)
8675

8776
// Expect 1 warning for --temp overriding
88-
if len(warnings) != 1 {
89-
t.Fatalf("expected 1 warning1, got %d: %#v", len(warnings), warnings)
90-
}
77+
require.Len(t, warnings, 1)
78+
9179
// Derived conflicting flags should be dropped, user ones kept and appended
92-
expected := []string{"--top-p", "0.8", "--temp", "0.7", "--threads", "8"}
93-
if !reflect.DeepEqual(merged, expected) {
94-
t.Fatalf("unexpected merged flags.\nexpected: %#v\nactual: %#v", expected, merged)
95-
}
80+
assert.Equal(t, []string{"--top-p", "0.8", "--temp", "0.7", "--threads", "8"}, merged)
9681
}

0 commit comments

Comments (0)