Skip to content

Commit 353bf79

Browse files
authored
Merge pull request #312 from dgageot/fix-281
Connect to DMR on the docker socket when possible
2 parents e2e498f + 9fc4905 commit 353bf79

6 files changed

Lines changed: 75 additions & 187 deletions

File tree

README.md

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -183,7 +183,7 @@ models:
183183
184184
You'll find a curated list of agents examples, spread into 3 categories, [Basic](https://github.com/docker/cagent/tree/main/examples#basic-configurations), [Advanced](https://github.com/docker/cagent/tree/main/examples#advanced-configurations) and [multi-agents](https://github.com/docker/cagent/tree/main/examples#multi-agent-configurations) in the `/examples/` directory.
185185

186-
### DMR provider options
186+
### DMR (Docker Model Runner) provider options
187187

188188
When using the `dmr` provider, you can use the `provider_opts` key for DMR runtime-specific (e.g. llama.cpp) options:
189189

@@ -197,7 +197,7 @@ models:
197197
runtime_flags: ["--ngl=33", "--repeat-penalty=1.2", ...] # or comma/space-separated string
198198
```
199199

200-
The default base_url `cagent` will use for dmr providers is `http://localhost:12434/engines/llama.cpp/v1`. DMR itself might need to be enabled via [Docker Desktop's settings](https://docs.docker.com/ai/model-runner/get-started/#enable-dmr-in-docker-desktop) on MacOS and Windows, and via command line on [Docker CE on Linux](https://docs.docker.com/ai/model-runner/get-started/#enable-dmr-in-docker-engine).
200+
The default base_url `cagent` will use for DMR providers is `http://localhost:12434/engines/llama.cpp/v1`. DMR itself might need to be enabled via [Docker Desktop's settings](https://docs.docker.com/ai/model-runner/get-started/#enable-dmr-in-docker-desktop) on MacOS and Windows, and via command line on [Docker CE on Linux](https://docs.docker.com/ai/model-runner/get-started/#enable-dmr-in-docker-engine).
201201

202202
## Quickly generate agents and agent teams with `cagent new`
203203

examples/pirate.yaml

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -8,3 +8,4 @@ agents:
88
model: openai/gpt-4o
99
# model: anthropic/claude-3-5-sonnet-latest
1010
# model: gemini/gemini-2.5-flash
11+
# model: dmr/ai/llama3.2

pkg/model/provider/dmr/client.go

Lines changed: 59 additions & 64 deletions
Original file line number | Diff line number | Diff line change
@@ -7,6 +7,8 @@ import (
77
"errors"
88
"fmt"
99
"log/slog"
10+
"net"
11+
"net/http"
1012
"os"
1113
"os/exec"
1214
"strconv"
@@ -29,7 +31,7 @@ type Client struct {
2931
}
3032

3133
// NewClient creates a new DMR client from the provided configuration
32-
func NewClient(_ context.Context, cfg *latest.ModelConfig, opts ...options.Opt) (*Client, error) {
34+
func NewClient(ctx context.Context, cfg *latest.ModelConfig, opts ...options.Opt) (*Client, error) {
3335
if cfg == nil {
3436
slog.Error("DMR client creation failed", "error", "model configuration is required")
3537
return nil, errors.New("model configuration is required")
@@ -45,49 +47,63 @@ func NewClient(_ context.Context, cfg *latest.ModelConfig, opts ...options.Opt)
4547
opt(&globalOptions)
4648
}
4749

48-
// Resolve base_url for DMR models. If not provided, configure with the docker model plugin, else fallback.
49-
baseURL := cfg.BaseURL
50-
if baseURL == "" {
51-
endpoint, engine, err := getDockerModelEndpointAndEngine()
52-
if err != nil {
53-
slog.Debug("docker model status query failed", "error", err)
54-
}
50+
endpoint, engine, err := getDockerModelEndpointAndEngine(ctx)
51+
if err != nil {
52+
slog.Debug("docker model status query failed", "error", err)
53+
}
5554

56-
// Build runtime flags from ModelConfig and engine
57-
contextSize, providerRuntimeFlags := parseDMRProviderOpts(cfg)
58-
configFlags := buildRuntimeFlagsFromModelConfig(engine, cfg)
59-
finalFlags, warnings := mergeRuntimeFlagsPreferUser(configFlags, providerRuntimeFlags)
60-
for _, w := range warnings {
61-
slog.Warn(w)
62-
}
63-
slog.Debug("DMR provider_opts parsed", "model", cfg.Model, "context_size", contextSize, "runtime_flags", finalFlags, "engine", engine)
64-
if err := configureDockerModel(cfg.Model, contextSize, finalFlags); err != nil {
65-
slog.Debug("docker model configure skipped or failed", "error", err)
66-
}
55+
clientConfig := openai.DefaultConfig("")
6756

68-
if endpoint != "" {
69-
baseURL = endpoint
70-
slog.Debug("Using docker model endpoint for DMR base_url", "base_url", baseURL)
71-
} else {
72-
baseURL = "http://localhost:12434/engines/llama.cpp/v1"
73-
slog.Debug("Using default DMR base_url", "base_url", baseURL)
57+
switch {
58+
case cfg.BaseURL != "":
59+
clientConfig.BaseURL = cfg.BaseURL
60+
case os.Getenv("MODEL_RUNNER_HOST") != "":
61+
clientConfig.BaseURL = os.Getenv("MODEL_RUNNER_HOST")
62+
case inContainer():
63+
// This won't work with Docker CE but we have no way to detect that from inside the container.
64+
clientConfig.BaseURL = "http://model-runner.docker.internal/engines/v1/"
65+
case endpoint == "http://model-runner.docker.internal/engines/v1/":
66+
// Docker Desktop
67+
clientConfig.BaseURL = "http://_/exp/vDD4.40/engines/v1"
68+
clientConfig.HTTPClient = &http.Client{
69+
Transport: &http.Transport{
70+
DialContext: func(ctx context.Context, _, _ string) (net.Conn, error) {
71+
var d net.Dialer
72+
return d.DialContext(ctx, "unix", "/var/run/docker.sock")
73+
},
74+
},
7475
}
76+
default:
77+
// Docker CE
78+
clientConfig.BaseURL = endpoint
7579
}
7680

77-
slog.Debug("Creating DMR client config", "base_url", baseURL)
78-
clientConfig := openai.DefaultConfig("")
79-
clientConfig.BaseURL = baseURL
81+
// Build runtime flags from ModelConfig and engine
82+
contextSize, providerRuntimeFlags := parseDMRProviderOpts(cfg)
83+
configFlags := buildRuntimeFlagsFromModelConfig(engine, cfg)
84+
finalFlags, warnings := mergeRuntimeFlagsPreferUser(configFlags, providerRuntimeFlags)
85+
for _, w := range warnings {
86+
slog.Warn(w)
87+
}
88+
slog.Debug("DMR provider_opts parsed", "model", cfg.Model, "context_size", contextSize, "runtime_flags", finalFlags, "engine", engine)
89+
if err := configureDockerModel(ctx, cfg.Model, contextSize, finalFlags); err != nil {
90+
slog.Debug("docker model configure skipped or failed", "error", err)
91+
}
8092

81-
client := openai.NewClientWithConfig(clientConfig)
82-
slog.Debug("DMR client created successfully", "model", cfg.Model, "base_url", baseURL)
93+
slog.Debug("DMR client created successfully", "model", cfg.Model, "base_url", clientConfig.BaseURL)
8394

8495
return &Client{
85-
client: client,
96+
client: openai.NewClientWithConfig(clientConfig),
8697
config: cfg,
87-
baseURL: baseURL,
98+
baseURL: clientConfig.BaseURL,
8899
}, nil
89100
}
90101

102+
func inContainer() bool {
103+
finfo, err := os.Stat("/.dockerenv")
104+
return err == nil && finfo.Mode().IsRegular()
105+
}
106+
91107
func convertMultiContent(multiContent []chat.MessagePart) []openai.ChatMessagePart {
92108
openaiMultiContent := make([]openai.ChatMessagePart, len(multiContent))
93109
for i, part := range multiContent {
@@ -290,7 +306,8 @@ func (c *Client) CreateChatCompletionStream(
290306
"model", c.config.Model,
291307
"message_count", len(messages),
292308
"tool_count", len(requestTools),
293-
"base_url", c.baseURL)
309+
"base_url", c.baseURL,
310+
)
294311

295312
if len(messages) == 0 {
296313
slog.Error("DMR stream creation failed", "error", "at least one message is required")
@@ -366,7 +383,7 @@ func (c *Client) CreateChatCompletionStream(
366383
return nil, err
367384
}
368385

369-
slog.Debug("DMR chat completion stream created successfully", "model", c.config.Model)
386+
slog.Debug("DMR chat completion stream created successfully", "model", c.config.Model, "base_url", c.baseURL)
370387
return newStreamAdapter(stream, trackUsage), nil
371388
}
372389

@@ -383,7 +400,7 @@ func (c *Client) CreateChatCompletion(
383400

384401
response, err := c.client.CreateChatCompletion(ctx, request)
385402
if err != nil {
386-
slog.Error("DMR chat completion failed", "error", err, "model", c.config.Model, "base_url", c.baseURL)
403+
slog.Error("DMR chat completion failed", "error", err, "model", c.config.Model)
387404
return "", err
388405
}
389406

@@ -464,10 +481,10 @@ func parseDMRProviderOpts(cfg *latest.ModelConfig) (contextSize int, runtimeFlag
464481
return contextSize, runtimeFlags
465482
}
466483

467-
func configureDockerModel(model string, contextSize int, runtimeFlags []string) error {
484+
func configureDockerModel(ctx context.Context, model string, contextSize int, runtimeFlags []string) error {
468485
args := buildDockerModelConfigureArgs(model, contextSize, runtimeFlags)
469486

470-
cmd := exec.Command("docker", args...)
487+
cmd := exec.CommandContext(ctx, "docker", args...)
471488
slog.Debug("Running docker model configure", "model", model, "args", args)
472489
var stdout, stderr bytes.Buffer
473490
cmd.Stdout = &stdout
@@ -494,14 +511,15 @@ func buildDockerModelConfigureArgs(model string, contextSize int, runtimeFlags [
494511
return args
495512
}
496513

497-
func getDockerModelEndpointAndEngine() (endpoint, engine string, err error) {
498-
cmd := exec.Command("docker", "model", "status", "--json")
514+
func getDockerModelEndpointAndEngine(ctx context.Context) (endpoint, engine string, err error) {
515+
cmd := exec.CommandContext(ctx, "docker", "model", "status", "--json")
499516
var stdout, stderr bytes.Buffer
500517
cmd.Stdout = &stdout
501518
cmd.Stderr = &stderr
502519
if err := cmd.Run(); err != nil {
503520
return "", "", errors.New(strings.TrimSpace(stderr.String()))
504521
}
522+
505523
type status struct {
506524
Running bool `json:"running"`
507525
Backends map[string]string `json:"backends"`
@@ -512,16 +530,8 @@ func getDockerModelEndpointAndEngine() (endpoint, engine string, err error) {
512530
if err := json.Unmarshal(stdout.Bytes(), &st); err != nil {
513531
return "", "", err
514532
}
515-
endpoint = strings.TrimSpace(st.Endpoint)
516-
517-
inDockerContainer := false
518-
finfo, err := os.Stat("/.dockerenv")
519-
if err == nil && finfo.Mode().IsRegular() {
520-
inDockerContainer = true
521-
}
522533

523-
// normalize endpoint considering container environment
524-
endpoint = normalizeDMREndpoint(endpoint, inDockerContainer)
534+
endpoint = strings.TrimSpace(st.Endpoint)
525535

526536
engine = strings.TrimSpace(st.Engine)
527537
if engine == "" {
@@ -539,23 +549,8 @@ func getDockerModelEndpointAndEngine() (endpoint, engine string, err error) {
539549
if engine == "" {
540550
engine = "llama.cpp"
541551
}
542-
return endpoint, engine, nil
543-
}
544552

545-
// normalizeDMREndpoint applies an override to the endpoint reported by
546-
// `docker model status --json` to ensure the DMR client uses a reachable address
547-
// from the current environment.
548-
func normalizeDMREndpoint(endpoint string, inDockerContainer bool) string {
549-
// This env overriding might need to be updated if we end up having multiple separate DMR
550-
// engines with different endpoints running at the same time
551-
if hostEnvVar := os.Getenv("MODEL_RUNNER_HOST"); hostEnvVar != "" {
552-
return hostEnvVar
553-
}
554-
// Only override if not running in a docker container
555-
if endpoint == "http://model-runner.docker.internal/engines/v1/" && !inDockerContainer {
556-
return "http://localhost:12434/engines/llama.cpp/v1"
557-
}
558-
return endpoint
553+
return endpoint, engine, nil
559554
}
560555

561556
// buildRuntimeFlagsFromModelConfig converts standard ModelConfig fields into backend-specific
Lines changed: 9 additions & 118 deletions
Original file line number | Diff line number | Diff line change
@@ -1,61 +1,34 @@
11
package dmr
22

33
import (
4-
"context"
54
"reflect"
65
"testing"
76

87
latest "github.com/docker/cagent/pkg/config/v2"
8+
"github.com/stretchr/testify/assert"
9+
"github.com/stretchr/testify/require"
910
)
1011

11-
func TestNewClientWithDefaultBaseURL(t *testing.T) {
12-
// No base_url provided, should use default
13-
cfg := &latest.ModelConfig{
14-
Provider: "dmr",
15-
Model: "ai/qwen3",
16-
// BaseURL is empty, should use default
17-
}
18-
19-
client, err := NewClient(context.Background(), cfg)
20-
if err != nil {
21-
t.Fatalf("Expected no error, got %v", err)
22-
}
23-
24-
if client.baseURL != "http://localhost:12434/engines/llama.cpp/v1" {
25-
t.Errorf("Expected default baseURL to be 'http://localhost:12434/engines/llama.cpp/v1', got '%s'", client.baseURL)
26-
}
27-
}
28-
2912
func TestNewClientWithExplicitBaseURL(t *testing.T) {
30-
// Explicit base_url provided, should use that
31-
customURL := "https://custom.example.com:8080/api/v1"
3213
cfg := &latest.ModelConfig{
3314
Provider: "dmr",
3415
Model: "ai/qwen3",
35-
BaseURL: customURL,
36-
}
37-
38-
client, err := NewClient(context.Background(), cfg)
39-
if err != nil {
40-
t.Fatalf("Expected no error, got %v", err)
16+
BaseURL: "https://custom.example.com:8080/api/v1",
4117
}
4218

43-
if client.baseURL != customURL {
44-
t.Errorf("Expected baseURL to be '%s', got '%s'", customURL, client.baseURL)
45-
}
19+
client, err := NewClient(t.Context(), cfg)
20+
require.NoError(t, err)
21+
assert.Equal(t, "https://custom.example.com:8080/api/v1", client.baseURL)
4622
}
4723

4824
func TestNewClientWithWrongType(t *testing.T) {
49-
// Wrong model type, should return error
5025
cfg := &latest.ModelConfig{
51-
Provider: "openai", // Wrong type
26+
Provider: "openai",
5227
Model: "gpt-4",
5328
}
5429

55-
_, err := NewClient(context.Background(), cfg)
56-
if err == nil {
57-
t.Fatal("Expected error for wrong model type, got nil")
58-
}
30+
_, err := NewClient(t.Context(), cfg)
31+
require.Error(t, err)
5932
}
6033

6134
func TestBuildDockerConfigureArgs(t *testing.T) {
@@ -121,85 +94,3 @@ func TestMergeRuntimeFlagsPreferUser_WarnsAndPrefersUser(t *testing.T) {
12194
t.Fatalf("unexpected merged flags.\nexpected: %#v\nactual: %#v", expected, merged)
12295
}
12396
}
124-
125-
func TestNormalizeDMREndpoint_NoEnvOverride(t *testing.T) {
126-
tests := []struct {
127-
name string
128-
endpoint string
129-
inDockerContainer bool
130-
want string
131-
}{
132-
{
133-
name: "override when not in docker",
134-
endpoint: "http://model-runner.docker.internal/engines/v1/",
135-
inDockerContainer: false,
136-
want: "http://localhost:12434/engines/llama.cpp/v1",
137-
},
138-
{
139-
name: "no override when in docker",
140-
endpoint: "http://model-runner.docker.internal/engines/v1/",
141-
inDockerContainer: true,
142-
want: "http://model-runner.docker.internal/engines/v1/",
143-
},
144-
{
145-
name: "other endpoint unchanged",
146-
endpoint: "http://example/engines/v1/",
147-
inDockerContainer: false,
148-
want: "http://example/engines/v1/",
149-
},
150-
{
151-
name: "empty endpoint unchanged",
152-
endpoint: "",
153-
inDockerContainer: false,
154-
want: "",
155-
},
156-
}
157-
158-
for _, tt := range tests {
159-
t.Run(tt.name, func(t *testing.T) {
160-
got := normalizeDMREndpoint(tt.endpoint, tt.inDockerContainer)
161-
if got != tt.want {
162-
t.Fatalf("normalizeDMREndpoint(%q, %v) = %q, want %q", tt.endpoint, tt.inDockerContainer, got, tt.want)
163-
}
164-
})
165-
}
166-
}
167-
168-
func TestNormalizeDMREndpoint_EnvOverride(t *testing.T) {
169-
t.Setenv("MODEL_RUNNER_HOST", "http://myhost:9999/custom/v1")
170-
171-
tests := []struct {
172-
name string
173-
endpoint string
174-
inDockerContainer bool
175-
want string
176-
}{
177-
{
178-
name: "env overrides non-container default endpoint",
179-
endpoint: "http://model-runner.docker.internal/engines/v1/",
180-
inDockerContainer: false,
181-
want: "http://myhost:9999/custom/v1",
182-
},
183-
{
184-
name: "env overrides in-container default endpoint",
185-
endpoint: "http://model-runner.docker.internal/engines/v1/",
186-
inDockerContainer: true,
187-
want: "http://myhost:9999/custom/v1",
188-
},
189-
{
190-
name: "env overrides arbitrary endpoint",
191-
endpoint: "http://example/engines/v1/",
192-
inDockerContainer: false,
193-
want: "http://myhost:9999/custom/v1",
194-
},
195-
}
196-
197-
for _, tt := range tests {
198-
t.Run(tt.name, func(t *testing.T) {
199-
got := normalizeDMREndpoint(tt.endpoint, tt.inDockerContainer)
200-
if got != tt.want {
201-
t.Fatalf("normalizeDMREndpoint should prefer env var: got %q, want %q", got, tt.want)
202-
}
203-
})
204-
}
205-
}

0 commit comments

Comments (0)