Skip to content

Commit b6b6f50

Browse files
committed
Don't override DMR endpoint if we're running in a docker container
closes #118 Signed-off-by: Christopher Petito <chrisjpetito@gmail.com>
1 parent e19a9f3 commit b6b6f50

2 files changed

Lines changed: 112 additions & 6 deletions

File tree

pkg/model/provider/dmr/client.go

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"errors"
88
"fmt"
99
"log/slog"
10+
"os"
1011
"os/exec"
1112
"strconv"
1213
"strings"
@@ -501,10 +502,16 @@ func getDockerModelEndpointAndEngine() (endpoint, engine string, err error) {
501502
return "", "", err
502503
}
503504
endpoint = strings.TrimSpace(st.Endpoint)
504-
// TODO(krissetto): temporary override for the internal dmr endpoint that `docker model status --json` currently returns
505-
if endpoint == "http://model-runner.docker.internal/engines/v1/" {
506-
endpoint = "http://localhost:12434/engines/llama.cpp/v1"
505+
506+
inDockerContainer := false
507+
finfo, err := os.Stat("/.dockerenv")
508+
if err == nil && finfo.Mode().IsRegular() {
509+
inDockerContainer = true
507510
}
511+
512+
// normalize endpoint considering container environment
513+
endpoint = normalizeDMREndpoint(endpoint, inDockerContainer)
514+
508515
engine = strings.TrimSpace(st.Engine)
509516
if engine == "" {
510517
if st.Backends != nil {
@@ -524,6 +531,22 @@ func getDockerModelEndpointAndEngine() (endpoint, engine string, err error) {
524531
return endpoint, engine, nil
525532
}
526533

534+
// normalizeDMREndpoint chooses the address the DMR client should use, starting
// from the endpoint reported by `docker model status --json`.
//
// Resolution order:
//  1. If the MODEL_RUNNER_HOST environment variable holds a non-blank value,
//     that value (with surrounding whitespace removed) is returned and the
//     reported endpoint is ignored.
//  2. If the reported endpoint is exactly
//     "http://model-runner.docker.internal/engines/v1/" and inDockerContainer
//     is false, "http://localhost:12434/engines/llama.cpp/v1" is returned.
//  3. Otherwise endpoint is returned unchanged.
func normalizeDMREndpoint(endpoint string, inDockerContainer bool) string {
	const (
		internalEndpoint = "http://model-runner.docker.internal/engines/v1/"
		hostEndpoint     = "http://localhost:12434/engines/llama.cpp/v1"
	)

	// This env overriding might need to be updated if we end up having multiple
	// separate DMR engines with different endpoints running at the same time.
	//
	// The value is trimmed so that a whitespace-only variable, or one carrying a
	// stray trailing newline, does not end up being used as the endpoint.
	if hostEnvVar := strings.TrimSpace(os.Getenv("MODEL_RUNNER_HOST")); hostEnvVar != "" {
		return hostEnvVar
	}

	// Only override if not running in a docker container.
	if !inDockerContainer && endpoint == internalEndpoint {
		return hostEndpoint
	}
	return endpoint
}
549+
527550
// buildRuntimeFlagsFromModelConfig converts standard ModelConfig fields into backend-specific
528551
// runtime flags that the model-runner understands when launching the engine.
529552
// Currently supports the default engine "llama.cpp". Unknown/unsupported fields are ignored.

pkg/model/provider/dmr/client_test.go

Lines changed: 86 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package dmr
22

33
import (
4+
"context"
45
"reflect"
56
"testing"
67

@@ -15,7 +16,7 @@ func TestNewClientWithDefaultBaseURL(t *testing.T) {
1516
// BaseURL is empty, should use default
1617
}
1718

18-
client, err := NewClient(t.Context(), cfg)
19+
client, err := NewClient(context.Background(), cfg)
1920
if err != nil {
2021
t.Fatalf("Expected no error, got %v", err)
2122
}
@@ -34,7 +35,7 @@ func TestNewClientWithExplicitBaseURL(t *testing.T) {
3435
BaseURL: customURL,
3536
}
3637

37-
client, err := NewClient(t.Context(), cfg)
38+
client, err := NewClient(context.Background(), cfg)
3839
if err != nil {
3940
t.Fatalf("Expected no error, got %v", err)
4041
}
@@ -51,7 +52,7 @@ func TestNewClientWithWrongType(t *testing.T) {
5152
Model: "gpt-4",
5253
}
5354

54-
_, err := NewClient(t.Context(), cfg)
55+
_, err := NewClient(context.Background(), cfg)
5556
if err == nil {
5657
t.Fatal("Expected error for wrong model type, got nil")
5758
}
@@ -120,3 +121,85 @@ func TestMergeRuntimeFlagsPreferUser_WarnsAndPrefersUser(t *testing.T) {
120121
t.Fatalf("unexpected merged flags.\nexpected: %#v\nactual: %#v", expected, merged)
121122
}
122123
}
124+
125+
func TestNormalizeDMREndpoint_NoEnvOverride(t *testing.T) {
126+
tests := []struct {
127+
name string
128+
endpoint string
129+
inDockerContainer bool
130+
want string
131+
}{
132+
{
133+
name: "override when not in docker",
134+
endpoint: "http://model-runner.docker.internal/engines/v1/",
135+
inDockerContainer: false,
136+
want: "http://localhost:12434/engines/llama.cpp/v1",
137+
},
138+
{
139+
name: "no override when in docker",
140+
endpoint: "http://model-runner.docker.internal/engines/v1/",
141+
inDockerContainer: true,
142+
want: "http://model-runner.docker.internal/engines/v1/",
143+
},
144+
{
145+
name: "other endpoint unchanged",
146+
endpoint: "http://example/engines/v1/",
147+
inDockerContainer: false,
148+
want: "http://example/engines/v1/",
149+
},
150+
{
151+
name: "empty endpoint unchanged",
152+
endpoint: "",
153+
inDockerContainer: false,
154+
want: "",
155+
},
156+
}
157+
158+
for _, tt := range tests {
159+
t.Run(tt.name, func(t *testing.T) {
160+
got := normalizeDMREndpoint(tt.endpoint, tt.inDockerContainer)
161+
if got != tt.want {
162+
t.Fatalf("normalizeDMREndpoint(%q, %v) = %q, want %q", tt.endpoint, tt.inDockerContainer, got, tt.want)
163+
}
164+
})
165+
}
166+
}
167+
168+
func TestNormalizeDMREndpoint_EnvOverride(t *testing.T) {
169+
t.Setenv("MODEL_RUNNER_HOST", "http://myhost:9999/custom/v1")
170+
171+
tests := []struct {
172+
name string
173+
endpoint string
174+
inDockerContainer bool
175+
want string
176+
}{
177+
{
178+
name: "env overrides non-container default endpoint",
179+
endpoint: "http://model-runner.docker.internal/engines/v1/",
180+
inDockerContainer: false,
181+
want: "http://myhost:9999/custom/v1",
182+
},
183+
{
184+
name: "env overrides in-container default endpoint",
185+
endpoint: "http://model-runner.docker.internal/engines/v1/",
186+
inDockerContainer: true,
187+
want: "http://myhost:9999/custom/v1",
188+
},
189+
{
190+
name: "env overrides arbitrary endpoint",
191+
endpoint: "http://example/engines/v1/",
192+
inDockerContainer: false,
193+
want: "http://myhost:9999/custom/v1",
194+
},
195+
}
196+
197+
for _, tt := range tests {
198+
t.Run(tt.name, func(t *testing.T) {
199+
got := normalizeDMREndpoint(tt.endpoint, tt.inDockerContainer)
200+
if got != tt.want {
201+
t.Fatalf("normalizeDMREndpoint should prefer env var: got %q, want %q", got, tt.want)
202+
}
203+
})
204+
}
205+
}

0 commit comments

Comments
 (0)