Skip to content

Commit c82363a

Browse files
authored
Merge pull request #602 from dgageot/auto-pull-dmr
Auto-pull DMR models
2 parents aa1d241 + 7986fb5 commit c82363a

5 files changed

Lines changed: 108 additions & 35 deletions

File tree

cmd/root/new.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"github.com/spf13/cobra"
99

1010
"github.com/docker/cagent/pkg/creator"
11+
"github.com/docker/cagent/pkg/input"
1112
"github.com/docker/cagent/pkg/runtime"
1213
"github.com/docker/cagent/pkg/telemetry"
1314
)
@@ -83,7 +84,7 @@ func NewNewCmd() *cobra.Command {
8384
fmt.Print(blue("> "))
8485

8586
var err error
86-
prompt, err = readLine(ctx, os.Stdin)
87+
prompt, err = input.ReadLine(ctx, os.Stdin)
8788
if err != nil {
8889
return fmt.Errorf("failed to read purpose: %w", err)
8990
}

cmd/root/run.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ import (
2424
"github.com/docker/cagent/pkg/config"
2525
"github.com/docker/cagent/pkg/content"
2626
"github.com/docker/cagent/pkg/evaluation"
27+
"github.com/docker/cagent/pkg/input"
2728
"github.com/docker/cagent/pkg/remote"
2829
"github.com/docker/cagent/pkg/runtime"
2930
"github.com/docker/cagent/pkg/session"
@@ -594,7 +595,7 @@ func runWithoutTUI(ctx context.Context, agentFilename string, rt runtime.Runtime
594595
fmt.Print(blue("> "))
595596
firstQuestion = false
596597

597-
line, err := readLine(ctx, os.Stdin)
598+
line, err := input.ReadLine(ctx, os.Stdin)
598599
if err != nil {
599600
return err
600601
}
@@ -665,8 +666,8 @@ func runUserCommand(userInput string, sess *session.Session, rt runtime.Runtime,
665666

666667
// parseAttachCommand parses user input for /attach commands
667668
// Returns the message text (with /attach commands removed) and the attachment path
668-
func parseAttachCommand(input string) (messageText, attachPath string) {
669-
lines := strings.Split(input, "\n")
669+
func parseAttachCommand(userInput string) (messageText, attachPath string) {
670+
lines := strings.Split(userInput, "\n")
670671
var messageLines []string
671672

672673
for _, line := range lines {

cmd/root/run_text_utils.go

Lines changed: 4 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
package root
22

33
import (
4-
"bufio"
54
"context"
65
"encoding/json"
76
"fmt"
@@ -12,6 +11,7 @@ import (
1211
"github.com/fatih/color"
1312
"golang.org/x/term"
1413

14+
"github.com/docker/cagent/pkg/input"
1515
"github.com/docker/cagent/pkg/tools"
1616
)
1717

@@ -98,7 +98,7 @@ func printToolCallWithConfirmation(ctx context.Context, toolCall tools.ToolCall,
9898
}
9999

100100
// Fallback: line-based scanner (requires Enter)
101-
text, err := readLine(ctx, rd)
101+
text, err := input.ReadLine(ctx, rd)
102102
if err != nil {
103103
return ConfirmationReject
104104
}
@@ -125,7 +125,7 @@ func promptMaxIterationsContinue(ctx context.Context, maxIterations int) Confirm
125125
fmt.Printf("%s\n", white("This can happen with smaller or less capable models."))
126126
fmt.Printf("\n%s (y/n): ", blue("Do you want to continue for 10 more iterations?"))
127127

128-
response, err := readLine(ctx, os.Stdin)
128+
response, err := input.ReadLine(ctx, os.Stdin)
129129
if err != nil {
130130
fmt.Printf("\n%s\n", red("Failed to read input, exiting..."))
131131
return ConfirmationAbort
@@ -148,7 +148,7 @@ func promptOAuthAuthorization(ctx context.Context, serverURL string) Confirmatio
148148
fmt.Printf("%s\n", white("Your browser will open automatically to complete the authorization."))
149149
fmt.Printf("\n%s (y/n): ", blue("Do you want to authorize access?"))
150150

151-
response, err := readLine(ctx, os.Stdin)
151+
response, err := input.ReadLine(ctx, os.Stdin)
152152
if err != nil {
153153
fmt.Printf("\n%s\n", red("Failed to read input, aborting authorization..."))
154154
return ConfirmationAbort
@@ -291,30 +291,3 @@ func formatJSONValue(key string, value any) string {
291291
return fmt.Sprintf("%s: %s", bold(key), string(jsonBytes))
292292
}
293293
}
294-
295-
func readLine(ctx context.Context, rd io.Reader) (string, error) {
296-
lines := make(chan string)
297-
errs := make(chan error)
298-
299-
go func() {
300-
defer close(lines)
301-
defer close(errs)
302-
303-
reader := bufio.NewReader(rd)
304-
line, err := reader.ReadString('\n')
305-
if err != nil {
306-
errs <- err
307-
} else {
308-
lines <- line
309-
}
310-
}()
311-
312-
select {
313-
case <-ctx.Done():
314-
return "", ctx.Err()
315-
case err := <-errs:
316-
return "", err
317-
case line := <-lines:
318-
return line, nil
319-
}
320-
}

pkg/input/readline.go

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
package input
2+
3+
import (
4+
"bufio"
5+
"context"
6+
"io"
7+
)
8+
9+
func ReadLine(ctx context.Context, rd io.Reader) (string, error) {
10+
lines := make(chan string)
11+
errs := make(chan error)
12+
13+
go func() {
14+
defer close(lines)
15+
defer close(errs)
16+
17+
reader := bufio.NewReader(rd)
18+
line, err := reader.ReadString('\n')
19+
if err != nil {
20+
errs <- err
21+
} else {
22+
lines <- line
23+
}
24+
}()
25+
26+
select {
27+
case <-ctx.Done():
28+
return "", ctx.Err()
29+
case err := <-errs:
30+
return "", err
31+
case line := <-lines:
32+
return line, nil
33+
}
34+
}

pkg/model/provider/dmr/client.go

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"encoding/json"
77
"errors"
88
"fmt"
9+
"io"
910
"log/slog"
1011
"net"
1112
"net/http"
@@ -15,9 +16,11 @@ import (
1516
"strings"
1617

1718
"github.com/sashabaranov/go-openai"
19+
"golang.org/x/term"
1820

1921
"github.com/docker/cagent/pkg/chat"
2022
latest "github.com/docker/cagent/pkg/config/v2"
23+
"github.com/docker/cagent/pkg/input"
2124
"github.com/docker/cagent/pkg/model/provider/base"
2225
"github.com/docker/cagent/pkg/model/provider/options"
2326
"github.com/docker/cagent/pkg/tools"
@@ -51,6 +54,12 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, opts ...options.Opt
5154
endpoint, engine, err := getDockerModelEndpointAndEngine(ctx)
5255
if err != nil {
5356
slog.Debug("docker model status query failed", "error", err)
57+
} else {
58+
// Auto-pull the model if needed
59+
if err := pullDockerModelIfNeeded(ctx, cfg.Model); err != nil {
60+
slog.Debug("docker model pull failed", "error", err)
61+
return nil, err
62+
}
5463
}
5564

5665
clientConfig := openai.DefaultConfig("")
@@ -430,6 +439,61 @@ func parseDMRProviderOpts(cfg *latest.ModelConfig) (contextSize int, runtimeFlag
430439
return contextSize, runtimeFlags
431440
}
432441

442+
func pullDockerModelIfNeeded(ctx context.Context, model string) error {
443+
// Check if running in interactive mode (stdin is a terminal)
444+
interactive := term.IsTerminal(int(os.Stdin.Fd()))
445+
if !interactive {
446+
// In non-interactive mode (CI / Servers), do not attempt to pull the model
447+
return nil
448+
}
449+
450+
if modelExists(ctx, model) {
451+
slog.Debug("Model already exists, skipping pull", "model", model)
452+
return nil
453+
}
454+
455+
// Prompt user for confirmation in interactive mode
456+
fmt.Printf("\nModel %s not found locally.\n", model)
457+
fmt.Printf("Do you want to pull it now? ([y]es/[n]o): ")
458+
459+
response, err := input.ReadLine(ctx, os.Stdin)
460+
if err != nil {
461+
return fmt.Errorf("failed to read user input: %w", err)
462+
}
463+
464+
response = strings.TrimSpace(strings.ToLower(response))
465+
if response != "y" && response != "yes" {
466+
return fmt.Errorf("model pull declined by user")
467+
}
468+
469+
// Pull the model
470+
slog.Info("Pulling DMR model", "model", model)
471+
fmt.Printf("Pulling model %s...\n", model)
472+
cmd := exec.CommandContext(ctx, "docker", "model", "pull", model)
473+
cmd.Stdout = os.Stdout
474+
cmd.Stderr = os.Stderr
475+
if err := cmd.Run(); err != nil {
476+
return fmt.Errorf("failed to pull model %s: %w", model, err)
477+
}
478+
479+
slog.Info("Model pulled successfully", "model", model)
480+
fmt.Printf("Model %s pulled successfully.\n", model)
481+
482+
return nil
483+
}
484+
485+
func modelExists(ctx context.Context, model string) bool {
486+
cmd := exec.CommandContext(ctx, "docker", "model", "inspect", model)
487+
var stderr bytes.Buffer
488+
cmd.Stdout = io.Discard
489+
cmd.Stderr = &stderr
490+
if err := cmd.Run(); err != nil {
491+
slog.Debug("Model does not exist", "model", model, "error", strings.TrimSpace(stderr.String()))
492+
return false
493+
}
494+
return true
495+
}
496+
433497
func configureDockerModel(ctx context.Context, model string, contextSize int, runtimeFlags []string) error {
434498
args := buildDockerModelConfigureArgs(model, contextSize, runtimeFlags)
435499

0 commit comments

Comments
 (0)