77 "errors"
88 "fmt"
99 "log/slog"
10+ "net"
11+ "net/http"
1012 "os"
1113 "os/exec"
1214 "strconv"
@@ -29,7 +31,7 @@ type Client struct {
 }
 
 // NewClient creates a new DMR client from the provided configuration
-func NewClient(_ context.Context, cfg *latest.ModelConfig, opts ...options.Opt) (*Client, error) {
+func NewClient(ctx context.Context, cfg *latest.ModelConfig, opts ...options.Opt) (*Client, error) {
 	if cfg == nil {
 		slog.Error("DMR client creation failed", "error", "model configuration is required")
 		return nil, errors.New("model configuration is required")
@@ -45,49 +47,63 @@ func NewClient(_ context.Context, cfg *latest.ModelConfig, opts ...options.Opt)
 		opt(&globalOptions)
 	}
 
-	// Resolve base_url for DMR models. If not provided, configure with the docker model plugin, else fallback.
-	baseURL := cfg.BaseURL
-	if baseURL == "" {
-		endpoint, engine, err := getDockerModelEndpointAndEngine()
-		if err != nil {
-			slog.Debug("docker model status query failed", "error", err)
-		}
+	endpoint, engine, err := getDockerModelEndpointAndEngine(ctx)
+	if err != nil {
+		slog.Debug("docker model status query failed", "error", err)
+	}
 
-		// Build runtime flags from ModelConfig and engine
-		contextSize, providerRuntimeFlags := parseDMRProviderOpts(cfg)
-		configFlags := buildRuntimeFlagsFromModelConfig(engine, cfg)
-		finalFlags, warnings := mergeRuntimeFlagsPreferUser(configFlags, providerRuntimeFlags)
-		for _, w := range warnings {
-			slog.Warn(w)
-		}
-		slog.Debug("DMR provider_opts parsed", "model", cfg.Model, "context_size", contextSize, "runtime_flags", finalFlags, "engine", engine)
-		if err := configureDockerModel(cfg.Model, contextSize, finalFlags); err != nil {
-			slog.Debug("docker model configure skipped or failed", "error", err)
-		}
+	clientConfig := openai.DefaultConfig("")
 
-		if endpoint != "" {
-			baseURL = endpoint
-			slog.Debug("Using docker model endpoint for DMR base_url", "base_url", baseURL)
-		} else {
-			baseURL = "http://localhost:12434/engines/llama.cpp/v1"
-			slog.Debug("Using default DMR base_url", "base_url", baseURL)
+	switch {
+	case cfg.BaseURL != "":
+		clientConfig.BaseURL = cfg.BaseURL
+	case os.Getenv("MODEL_RUNNER_HOST") != "":
+		clientConfig.BaseURL = os.Getenv("MODEL_RUNNER_HOST")
+	case inContainer():
+		// This won't work with Docker CE but we have no way to detect that from inside the container.
+		clientConfig.BaseURL = "http://model-runner.docker.internal/engines/v1/"
+	case endpoint == "http://model-runner.docker.internal/engines/v1/":
+		// Docker Desktop
+		clientConfig.BaseURL = "http://_/exp/vDD4.40/engines/v1"
+		clientConfig.HTTPClient = &http.Client{
+			Transport: &http.Transport{
+				DialContext: func(ctx context.Context, _, _ string) (net.Conn, error) {
+					var d net.Dialer
+					return d.DialContext(ctx, "unix", "/var/run/docker.sock")
+				},
+			},
 		}
+	default:
+		// Docker CE
+		clientConfig.BaseURL = endpoint
 	}
 
-	slog.Debug("Creating DMR client config", "base_url", baseURL)
-	clientConfig := openai.DefaultConfig("")
-	clientConfig.BaseURL = baseURL
+	// Build runtime flags from ModelConfig and engine
+	contextSize, providerRuntimeFlags := parseDMRProviderOpts(cfg)
+	configFlags := buildRuntimeFlagsFromModelConfig(engine, cfg)
+	finalFlags, warnings := mergeRuntimeFlagsPreferUser(configFlags, providerRuntimeFlags)
+	for _, w := range warnings {
+		slog.Warn(w)
+	}
+	slog.Debug("DMR provider_opts parsed", "model", cfg.Model, "context_size", contextSize, "runtime_flags", finalFlags, "engine", engine)
+	if err := configureDockerModel(ctx, cfg.Model, contextSize, finalFlags); err != nil {
+		slog.Debug("docker model configure skipped or failed", "error", err)
+	}
 
-	client := openai.NewClientWithConfig(clientConfig)
-	slog.Debug("DMR client created successfully", "model", cfg.Model, "base_url", baseURL)
+	slog.Debug("DMR client created successfully", "model", cfg.Model, "base_url", clientConfig.BaseURL)
 
 	return &Client{
-		client:  client,
+		client:  openai.NewClientWithConfig(clientConfig),
 		config:  cfg,
-		baseURL: baseURL,
+		baseURL: clientConfig.BaseURL,
 	}, nil
 }
 
+func inContainer() bool {
+	finfo, err := os.Stat("/.dockerenv")
+	return err == nil && finfo.Mode().IsRegular()
+}
+
 func convertMultiContent(multiContent []chat.MessagePart) []openai.ChatMessagePart {
 	openaiMultiContent := make([]openai.ChatMessagePart, len(multiContent))
 	for i, part := range multiContent {
@@ -290,7 +306,8 @@ func (c *Client) CreateChatCompletionStream(
290306 "model" , c .config .Model ,
291307 "message_count" , len (messages ),
292308 "tool_count" , len (requestTools ),
293- "base_url" , c .baseURL )
309+ "base_url" , c .baseURL ,
310+ )
294311
295312 if len (messages ) == 0 {
296313 slog .Error ("DMR stream creation failed" , "error" , "at least one message is required" )
@@ -366,7 +383,7 @@ func (c *Client) CreateChatCompletionStream(
 		return nil, err
 	}
 
-	slog.Debug("DMR chat completion stream created successfully", "model", c.config.Model)
+	slog.Debug("DMR chat completion stream created successfully", "model", c.config.Model, "base_url", c.baseURL)
 	return newStreamAdapter(stream, trackUsage), nil
 }
 
@@ -383,7 +400,7 @@ func (c *Client) CreateChatCompletion(
 
 	response, err := c.client.CreateChatCompletion(ctx, request)
 	if err != nil {
-		slog.Error("DMR chat completion failed", "error", err, "model", c.config.Model, "base_url", c.baseURL)
+		slog.Error("DMR chat completion failed", "error", err, "model", c.config.Model)
 		return "", err
 	}
 
@@ -464,10 +481,10 @@ func parseDMRProviderOpts(cfg *latest.ModelConfig) (contextSize int, runtimeFlag
 	return contextSize, runtimeFlags
 }
 
-func configureDockerModel(model string, contextSize int, runtimeFlags []string) error {
+func configureDockerModel(ctx context.Context, model string, contextSize int, runtimeFlags []string) error {
 	args := buildDockerModelConfigureArgs(model, contextSize, runtimeFlags)
 
-	cmd := exec.Command("docker", args...)
+	cmd := exec.CommandContext(ctx, "docker", args...)
 	slog.Debug("Running docker model configure", "model", model, "args", args)
 	var stdout, stderr bytes.Buffer
 	cmd.Stdout = &stdout
@@ -494,14 +511,15 @@ func buildDockerModelConfigureArgs(model string, contextSize int, runtimeFlags [
 	return args
 }
 
-func getDockerModelEndpointAndEngine() (endpoint, engine string, err error) {
-	cmd := exec.Command("docker", "model", "status", "--json")
+func getDockerModelEndpointAndEngine(ctx context.Context) (endpoint, engine string, err error) {
+	cmd := exec.CommandContext(ctx, "docker", "model", "status", "--json")
 	var stdout, stderr bytes.Buffer
 	cmd.Stdout = &stdout
 	cmd.Stderr = &stderr
 	if err := cmd.Run(); err != nil {
 		return "", "", errors.New(strings.TrimSpace(stderr.String()))
 	}
+
 	type status struct {
 		Running  bool              `json:"running"`
 		Backends map[string]string `json:"backends"`
@@ -512,16 +530,8 @@ func getDockerModelEndpointAndEngine() (endpoint, engine string, err error) {
 	if err := json.Unmarshal(stdout.Bytes(), &st); err != nil {
 		return "", "", err
 	}
-	endpoint = strings.TrimSpace(st.Endpoint)
-
-	inDockerContainer := false
-	finfo, err := os.Stat("/.dockerenv")
-	if err == nil && finfo.Mode().IsRegular() {
-		inDockerContainer = true
-	}
 
-	// normalize endpoint considering container environment
-	endpoint = normalizeDMREndpoint(endpoint, inDockerContainer)
+	endpoint = strings.TrimSpace(st.Endpoint)
 
 	engine = strings.TrimSpace(st.Engine)
 	if engine == "" {
@@ -539,23 +549,8 @@ func getDockerModelEndpointAndEngine() (endpoint, engine string, err error) {
 	if engine == "" {
 		engine = "llama.cpp"
 	}
-	return endpoint, engine, nil
-}
 
-// normalizeDMREndpoint applies an override to the endpoint reported by
-// `docker model status --json` to ensure the DMR client uses a reachable address
-// from the current environment.
-func normalizeDMREndpoint(endpoint string, inDockerContainer bool) string {
-	// This env overriding might need to be updated if we end up having multiple separate DMR
-	// engines with different endpoints running at the same time
-	if hostEnvVar := os.Getenv("MODEL_RUNNER_HOST"); hostEnvVar != "" {
-		return hostEnvVar
-	}
-	// Only override if not running in a docker container
-	if endpoint == "http://model-runner.docker.internal/engines/v1/" && !inDockerContainer {
-		return "http://localhost:12434/engines/llama.cpp/v1"
-	}
-	return endpoint
+	return endpoint, engine, nil
 }
 
 // buildRuntimeFlagsFromModelConfig converts standard ModelConfig fields into backend-specific