Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions pkg/cli/weights.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package cli
import (
"context"
"fmt"
"path"
"path/filepath"
"time"

Expand Down Expand Up @@ -123,14 +124,21 @@ func weightsImportCommand(cmd *cobra.Command, args []string, dryRun, verbose boo
builder = model.NewWeightBuilder(src, fileStore, lockPath)

console.Infof("Building %d weight(s)...", len(weightSpecs))
buildProgress := docker.NewProgressWriter()
builder.SetProgressFn(func(prog model.WeightBuildProgress) {
writeWeightBuildProgress(buildProgress, prog)
})

release, err := src.DotCog.Lock(ctx)
if err != nil {
buildProgress.Close()
return err
}
defer release()

artifacts, err := buildWeightArtifactsFromPlans(ctx, builder, weightSpecs, plans)
buildProgress.Close()
builder.SetProgressFn(nil)
if err != nil {
return err
}
Expand All @@ -150,6 +158,27 @@ func weightsImportCommand(cmd *cobra.Command, args []string, dryRun, verbose boo
return pushWeightArtifacts(ctx, repo, artifacts, "Imported")
}

func writeWeightBuildProgress(pw *docker.ProgressWriter, prog model.WeightBuildProgress) {
id := prog.WeightName
if prog.FilePath != "" {
file := path.Base(prog.FilePath)
if id == "" {
id = file
} else {
id += "/" + file
}
}
if id == "" {
id = model.ShortDigest(prog.FileDigest)
}

if prog.Done {
Comment thread
anish-sahoo marked this conversation as resolved.
pw.WriteStatus(id, "Download complete")
return
}
pw.Write(id, "Downloading", prog.Complete, prog.Total)
}

// planWeightImports runs PlanImport for each spec without side effects.
func planWeightImports(ctx context.Context, builder *model.WeightBuilder, specs []*model.WeightSpec) ([]*model.WeightImportPlan, error) {
plans := make([]*model.WeightImportPlan, 0, len(specs))
Expand Down
32 changes: 30 additions & 2 deletions pkg/cli/weights_pull.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@ package cli

import (
"fmt"
"path"
"path/filepath"

"github.com/spf13/cobra"

"github.com/replicate/cog/pkg/docker"
"github.com/replicate/cog/pkg/model"
"github.com/replicate/cog/pkg/paths"
"github.com/replicate/cog/pkg/util/console"
Expand Down Expand Up @@ -78,14 +80,16 @@ func weightsPullCommand(cmd *cobra.Command, args []string, verbose bool) error {
console.Info("")
}

results, err := mgr.Pull(ctx, args, pullEventPrinter(verbose))
progress := docker.NewProgressWriter()
results, err := mgr.Pull(ctx, args, pullEventPrinter(verbose, progress))
progress.Close()
printPullSummary(results, verbose)
return err
}

// pullEventPrinter returns a PullEvent handler that writes progress to
// the console. Verbose mode adds per-layer / per-file detail.
func pullEventPrinter(verbose bool) func(weights.PullEvent) {
func pullEventPrinter(verbose bool, progress *docker.ProgressWriter) func(weights.PullEvent) {
return func(e weights.PullEvent) {
switch e.Kind {
case weights.PullEventWeightStart:
Expand All @@ -109,7 +113,15 @@ func pullEventPrinter(verbose bool) func(weights.PullEvent) {
size = formatSize(e.LayerSize)
}
console.Infof(" layer %s (%s)", model.ShortDigest(e.LayerDigest), size)
case weights.PullEventFileProgress:
if progress == nil {
return
}
progress.Write(pullProgressID(e), "Downloading", e.FileComplete, e.FileSize)
case weights.PullEventFileStored:
if progress != nil {
progress.WriteStatus(pullProgressID(e), "Download complete")
}
if !verbose {
return
}
Expand All @@ -126,6 +138,22 @@ func pullEventPrinter(verbose bool) func(weights.PullEvent) {
}
}

// pullProgressID derives the progress-line key for a file-level pull event.
//
// Preference order: "<weight>/<file base name>" when both are present, then
// whichever of the two is present on its own, and finally
// model.ShortDigest(e.FileDigest) when the event names neither.
func pullProgressID(e weights.PullEvent) string {
	hasFile := e.FilePath != ""
	hasWeight := e.Weight != ""

	switch {
	case hasFile && hasWeight:
		return e.Weight + "/" + path.Base(e.FilePath)
	case hasFile:
		// path.Base never returns "" for a non-empty path, so no further
		// fallback is needed here.
		return path.Base(e.FilePath)
	case hasWeight:
		return e.Weight
	}
	return model.ShortDigest(e.FileDigest)
}

func printPullSummary(results []weights.PullResult, verbose bool) {
if len(results) == 0 {
return
Expand Down
101 changes: 97 additions & 4 deletions pkg/model/packer.go
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,28 @@ func packedFilesFromPlan(layers []packedLayer) []packedFile {
// silently producing a tar whose member digest disagrees with the
// inventory.
func ingressFromInventory(ctx context.Context, owners map[string]weightsource.Source, st store.Store, inv weightsource.Inventory) error {
	// Delegate to the progress-aware variant with an empty weight name and a
	// nil callback, i.e. with progress reporting switched off.
	return ingressFromInventoryWithProgress(ctx, "", owners, st, inv, nil)
}

// WeightBuildProgress reports progress while missing source files are fetched
// into the local content-addressed weight store during import.
type WeightBuildProgress struct {
	// WeightName is the weight name supplied by the caller of the ingress
	// step; it is empty when the caller did not provide one.
	WeightName string
	// FilePath is the inventory path of the file this event refers to.
	FilePath string
	// FileDigest is the inventory digest of the file this event refers to.
	FileDigest string
	// Complete is the number of bytes fetched so far in the current ingress
	// pass: the sizes of the files already finished plus the bytes read from
	// the current file.
	Complete int64
	// Total is the summed size of all files queued for fetching in the
	// current ingress pass.
	Total int64
	// Done is true only on the event emitted after the file has been stored
	// successfully.
	Done bool
}

func ingressFromInventoryWithProgress(ctx context.Context, weightName string, owners map[string]weightsource.Source, st store.Store, inv weightsource.Inventory, progressFn func(WeightBuildProgress)) error {
type missingFile struct {
source weightsource.Source
file weightsource.InventoryFile
}

missing := make([]missingFile, 0, len(inv.Files))
var total int64
for _, f := range inv.Files {
if err := ctx.Err(); err != nil {
return err
Expand All @@ -397,20 +419,91 @@ func ingressFromInventory(ctx context.Context, owners map[string]weightsource.So
if !ok {
return fmt.Errorf("no source owner for file %s", f.Path)
}
if err := ingressOne(ctx, src, st, f); err != nil {
return fmt.Errorf("ingress %s: %w", f.Path, err)
missing = append(missing, missingFile{source: src, file: f})
total += f.Size
}

var complete int64
for _, m := range missing {
if err := ingressOne(ctx, weightName, m.source, st, m.file, complete, total, progressFn); err != nil {
return fmt.Errorf("ingress %s: %w", m.file.Path, err)
}
complete += m.file.Size
}
return nil
}

func ingressOne(ctx context.Context, src weightsource.Source, st store.Store, f weightsource.InventoryFile) error {
func ingressOne(ctx context.Context, weightName string, src weightsource.Source, st store.Store, f weightsource.InventoryFile, baseComplete, total int64, progressFn func(WeightBuildProgress)) error {
if progressFn != nil {
progressFn(WeightBuildProgress{
WeightName: weightName,
FilePath: f.Path,
FileDigest: f.Digest,
Complete: baseComplete,
Total: total,
})
}

rc, err := src.Open(ctx, f.Path)
if err != nil {
return fmt.Errorf("open source: %w", err)
}
defer rc.Close() //nolint:errcheck // best-effort close on read path
return st.PutFile(ctx, f.Digest, f.Size, rc)

var r io.Reader = rc
if progressFn != nil {
r = &progressReader{
r: rc,
interval: 250 * time.Millisecond,
fn: func(complete int64) {
progressFn(WeightBuildProgress{
WeightName: weightName,
FilePath: f.Path,
FileDigest: f.Digest,
Complete: baseComplete + complete,
Total: total,
})
},
}
}

if err := st.PutFile(ctx, f.Digest, f.Size, r); err != nil {
return err
}
if progressFn != nil {
progressFn(WeightBuildProgress{
WeightName: weightName,
FilePath: f.Path,
FileDigest: f.Digest,
Complete: baseComplete + f.Size,
Total: total,
Done: true,
})
}
return nil
}

// progressReader wraps an io.Reader and invokes fn with the running byte
// count as data is read. The first successful read always triggers fn; after
// that, fn is invoked at most once per interval.
type progressReader struct {
	// r is the wrapped stream.
	r io.Reader
	// complete is the total number of bytes read through this wrapper.
	complete int64
	// lastReported is the value of complete at the most recent fn call; it
	// stays zero until the first call and doubles as the "never reported" flag.
	lastReported int64
	// lastUpdate is the time of the most recent fn call.
	lastUpdate time.Time
	// interval is the minimum time between fn calls after the first.
	interval time.Duration
	// fn receives the cumulative byte count.
	fn func(int64)
}

// Read forwards to the wrapped reader, adds any bytes returned to the running
// total, and calls fn when a report is due. The wrapped reader's byte count
// and error are returned unchanged.
func (pr *progressReader) Read(p []byte) (int, error) {
	n, err := pr.r.Read(p)
	if n <= 0 {
		// Nothing was read, so there is nothing new to report.
		return n, err
	}

	pr.complete += int64(n)
	now := time.Now()

	neverReported := pr.lastReported == 0
	if !neverReported && now.Sub(pr.lastUpdate) < pr.interval {
		// Still inside the throttle window.
		return n, err
	}

	pr.lastReported, pr.lastUpdate = pr.complete, now
	pr.fn(pr.complete)
	return n, err
}

// writeLayer writes the in-tar layout for a layer: deterministic
Expand Down
37 changes: 37 additions & 0 deletions pkg/model/packer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,43 @@ func TestPack_SingleSmallFile(t *testing.T) {
assert.Equal(t, "config.json", entries[0])
}

// TestIngressFromInventoryReportsProgress ingests a single file from a local
// file source into a fresh file store and checks the shape of the first and
// last progress events delivered to the callback.
func TestIngressFromInventoryReportsProgress(t *testing.T) {
	const (
		weightName = "test-weight"
		relPath    = "model.safetensors"
	)
	content := []byte("download progress")
	wantSize := int64(len(content))

	// Source directory holding exactly one file.
	srcDir := t.TempDir()
	require.NoError(t, os.WriteFile(filepath.Join(srcDir, relPath), content, 0o644))

	src, err := weightsource.NewFileSource("file://"+srcDir, "")
	require.NoError(t, err)
	inv, err := src.Inventory(t.Context())
	require.NoError(t, err)
	require.Len(t, inv.Files, 1)

	// Empty destination store, so the file has to be fetched.
	st, err := store.NewFileStore(t.TempDir())
	require.NoError(t, err)

	var events []WeightBuildProgress
	record := func(event WeightBuildProgress) {
		events = append(events, event)
	}
	err = ingressFromInventoryWithProgress(t.Context(), weightName, sourceOwners(src, inv), st, inv, record)
	require.NoError(t, err)
	require.NotEmpty(t, events)

	first, last := events[0], events[len(events)-1]

	// Identity and total are the same on both ends of the stream.
	for _, ev := range []WeightBuildProgress{first, last} {
		assert.Equal(t, weightName, ev.WeightName)
		assert.Equal(t, relPath, ev.FilePath)
		assert.Equal(t, wantSize, ev.Total)
	}

	// The opening event starts at zero and is not terminal.
	assert.Equal(t, int64(0), first.Complete)
	assert.False(t, first.Done)

	// The closing event accounts for every byte and is flagged Done.
	assert.Equal(t, wantSize, last.Complete)
	assert.True(t, last.Done)
}

func TestPack_SingleLargeFile_Incompressible(t *testing.T) {
dir := t.TempDir()
createTestFile(t, dir, "model.safetensors", 100*1024*1024) // 100 MB
Expand Down
14 changes: 10 additions & 4 deletions pkg/model/weight_builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,9 +137,10 @@ func (b *WeightBuilder) resolveInventory(ctx context.Context, ws *WeightSpec) (*
// digest it writes into the artifact descriptor is a sha256 of the
// serialized manifest bytes.
type WeightBuilder struct {
	// source is the project source handed to NewWeightBuilder.
	source *Source
	// store is the local weight store that ingress writes fetched files into.
	store store.Store
	// lockPath is the path handed to NewWeightBuilder.
	// NOTE(review): presumably the weights lockfile location — confirm.
	lockPath string
	// progressFn, when non-nil, receives file fetch progress during import.
	// It is installed or cleared through SetProgressFn.
	progressFn func(WeightBuildProgress)
}

// NewWeightBuilder creates a WeightBuilder.
Expand All @@ -150,6 +151,11 @@ func NewWeightBuilder(source *Source, st store.Store, lockPath string) *WeightBu
return &WeightBuilder{source: source, store: st, lockPath: lockPath}
}

// SetProgressFn sets an optional callback for import-time file fetch progress.
// Passing nil clears any previously installed callback, which turns progress
// reporting off for subsequent builds.
func (b *WeightBuilder) SetProgressFn(fn func(WeightBuildProgress)) {
	b.progressFn = fn
}

// Build runs the full import pipeline for one weight:
//
// 1. Inventory the source.
Expand Down Expand Up @@ -200,7 +206,7 @@ func (b *WeightBuilder) buildWithResolved(ctx context.Context, spec ArtifactSpec
inv := weightsource.Inventory{Files: resolved.mergedFiles}

// Step 2: ingress the filtered files into the local store.
if err := ingressFromInventory(ctx, resolved.owners, b.store, inv); err != nil {
if err := ingressFromInventoryWithProgress(ctx, ws.Name(), resolved.owners, b.store, inv, b.progressFn); err != nil {
return nil, fmt.Errorf("populate store for weight %q: %w", ws.Name(), err)
}

Expand Down
Loading
Loading