diff --git a/pkg/cli/weights.go b/pkg/cli/weights.go index de9c1941a4..b207e57db6 100644 --- a/pkg/cli/weights.go +++ b/pkg/cli/weights.go @@ -3,6 +3,7 @@ package cli import ( "context" "fmt" + "path" "path/filepath" "time" @@ -123,14 +124,21 @@ func weightsImportCommand(cmd *cobra.Command, args []string, dryRun, verbose boo builder = model.NewWeightBuilder(src, fileStore, lockPath) console.Infof("Building %d weight(s)...", len(weightSpecs)) + buildProgress := docker.NewProgressWriter() + builder.SetProgressFn(func(prog model.WeightBuildProgress) { + writeWeightBuildProgress(buildProgress, prog) + }) release, err := src.DotCog.Lock(ctx) if err != nil { + buildProgress.Close() return err } defer release() artifacts, err := buildWeightArtifactsFromPlans(ctx, builder, weightSpecs, plans) + buildProgress.Close() + builder.SetProgressFn(nil) if err != nil { return err } @@ -150,6 +158,27 @@ func weightsImportCommand(cmd *cobra.Command, args []string, dryRun, verbose boo return pushWeightArtifacts(ctx, repo, artifacts, "Imported") } +func writeWeightBuildProgress(pw *docker.ProgressWriter, prog model.WeightBuildProgress) { + id := prog.WeightName + if prog.FilePath != "" { + file := path.Base(prog.FilePath) + if id == "" { + id = file + } else { + id += "/" + file + } + } + if id == "" { + id = model.ShortDigest(prog.FileDigest) + } + + if prog.Done { + pw.WriteStatus(id, "Download complete") + return + } + pw.Write(id, "Downloading", prog.Complete, prog.Total) +} + // planWeightImports runs PlanImport for each spec without side effects. func planWeightImports(ctx context.Context, builder *model.WeightBuilder, specs []*model.WeightSpec) ([]*model.WeightImportPlan, error) { plans := make([]*model.WeightImportPlan, 0, len(specs)) diff --git a/pkg/cli/weights_pull.go b/pkg/cli/weights_pull.go index 97ad2b79f1..2869a3f61d 100644 --- a/pkg/cli/weights_pull.go +++ b/pkg/cli/weights_pull.go @@ -2,10 +2,12 @@ package cli import ( "fmt" + "path" "path/filepath" "github.com/spf13/cobra" + "github.com/replicate/cog/pkg/docker" "github.com/replicate/cog/pkg/model" "github.com/replicate/cog/pkg/paths" "github.com/replicate/cog/pkg/util/console" @@ -78,14 +80,16 @@ func weightsPullCommand(cmd *cobra.Command, args []string, verbose bool) error { console.Info("") } - results, err := mgr.Pull(ctx, args, pullEventPrinter(verbose)) + progress := docker.NewProgressWriter() + results, err := mgr.Pull(ctx, args, pullEventPrinter(verbose, progress)) + progress.Close() printPullSummary(results, verbose) return err } // pullEventPrinter returns a PullEvent handler that writes progress to // the console. Verbose mode adds per-layer / per-file detail. -func pullEventPrinter(verbose bool) func(weights.PullEvent) { +func pullEventPrinter(verbose bool, progress *docker.ProgressWriter) func(weights.PullEvent) { return func(e weights.PullEvent) { switch e.Kind { case weights.PullEventWeightStart: @@ -109,7 +113,15 @@ func pullEventPrinter(verbose bool) func(weights.PullEvent) { size = formatSize(e.LayerSize) } console.Infof(" layer %s (%s)", model.ShortDigest(e.LayerDigest), size) + case weights.PullEventFileProgress: + if progress == nil { + return + } + progress.Write(pullProgressID(e), "Downloading", e.FileComplete, e.FileSize) case weights.PullEventFileStored: + if progress != nil { + progress.WriteStatus(pullProgressID(e), "Download complete") + } if !verbose { return } @@ -126,6 +138,22 @@ func pullEventPrinter(verbose bool) func(weights.PullEvent) { } } +func pullProgressID(e weights.PullEvent) string { + id := e.Weight + if e.FilePath != "" { + file := path.Base(e.FilePath) + if id == "" { + id = file + } else { + id += "/" + file + } + } + if id == "" { + id = model.ShortDigest(e.FileDigest) + } + return id +} + func printPullSummary(results []weights.PullResult, verbose bool) { if len(results) == 0 { return diff --git a/pkg/model/packer.go b/pkg/model/packer.go index 8026495dc3..6f8b3fcaa7 100644 --- a/pkg/model/packer.go +++ b/pkg/model/packer.go @@ -382,6 +382,28 @@ func packedFilesFromPlan(layers []packedLayer) []packedFile { // silently producing a tar whose member digest disagrees with the // inventory. func ingressFromInventory(ctx context.Context, owners map[string]weightsource.Source, st store.Store, inv weightsource.Inventory) error { + return ingressFromInventoryWithProgress(ctx, "", owners, st, inv, nil) +} + +// WeightBuildProgress reports progress while missing source files are fetched +// into the local content-addressed weight store during import. +type WeightBuildProgress struct { + WeightName string + FilePath string + FileDigest string + Complete int64 + Total int64 + Done bool +} + +func ingressFromInventoryWithProgress(ctx context.Context, weightName string, owners map[string]weightsource.Source, st store.Store, inv weightsource.Inventory, progressFn func(WeightBuildProgress)) error { + type missingFile struct { + source weightsource.Source + file weightsource.InventoryFile + } + + missing := make([]missingFile, 0, len(inv.Files)) + var total int64 for _, f := range inv.Files { if err := ctx.Err(); err != nil { return err @@ -397,20 +419,91 @@ func ingressFromInventory(ctx context.Context, owners map[string]weightsource.So if !ok { return fmt.Errorf("no source owner for file %s", f.Path) } - if err := ingressOne(ctx, src, st, f); err != nil { - return fmt.Errorf("ingress %s: %w", f.Path, err) + missing = append(missing, missingFile{source: src, file: f}) + total += f.Size + } + + var complete int64 + for _, m := range missing { + if err := ingressOne(ctx, weightName, m.source, st, m.file, complete, total, progressFn); err != nil { + return fmt.Errorf("ingress %s: %w", m.file.Path, err) } + complete += m.file.Size } return nil } -func ingressOne(ctx context.Context, src weightsource.Source, st store.Store, f weightsource.InventoryFile) error { +func ingressOne(ctx context.Context, weightName string, src weightsource.Source, st store.Store, f weightsource.InventoryFile, baseComplete, total int64, progressFn func(WeightBuildProgress)) error { + if progressFn != nil { + progressFn(WeightBuildProgress{ + WeightName: weightName, + FilePath: f.Path, + FileDigest: f.Digest, + Complete: baseComplete, + Total: total, + }) + } + rc, err := src.Open(ctx, f.Path) if err != nil { return fmt.Errorf("open source: %w", err) } defer rc.Close() //nolint:errcheck // best-effort close on read path - return st.PutFile(ctx, f.Digest, f.Size, rc) + + var r io.Reader = rc + if progressFn != nil { + r = &progressReader{ + r: rc, + interval: 250 * time.Millisecond, + fn: func(complete int64) { + progressFn(WeightBuildProgress{ + WeightName: weightName, + FilePath: f.Path, + FileDigest: f.Digest, + Complete: baseComplete + complete, + Total: total, + }) + }, + } + } + + if err := st.PutFile(ctx, f.Digest, f.Size, r); err != nil { + return err + } + if progressFn != nil { + progressFn(WeightBuildProgress{ + WeightName: weightName, + FilePath: f.Path, + FileDigest: f.Digest, + Complete: baseComplete + f.Size, + Total: total, + Done: true, + }) + } + return nil +} + +type progressReader struct { + r io.Reader + complete int64 + lastReported int64 + lastUpdate time.Time + interval time.Duration + fn func(int64) +} + +func (r *progressReader) Read(p []byte) (int, error) { + n, err := r.r.Read(p) + if n > 0 { + r.complete += int64(n) + now := time.Now() + if r.lastReported == 0 || now.Sub(r.lastUpdate) >= r.interval { + r.lastReported = r.complete + r.lastUpdate = now + r.fn(r.complete) + } + } + return n, err } // writeLayer writes the in-tar layout for a layer: deterministic diff --git a/pkg/model/packer_test.go b/pkg/model/packer_test.go index 91183ed864..31c29611ee 100644 --- a/pkg/model/packer_test.go +++ b/pkg/model/packer_test.go @@ -146,6 +146,43 @@ func TestPack_SingleSmallFile(t *testing.T) { assert.Equal(t, "config.json", entries[0]) } +func TestIngressFromInventoryReportsProgress(t *testing.T) { + dir := t.TempDir() + content := []byte("download progress") + relPath := "model.safetensors" + require.NoError(t, os.WriteFile(filepath.Join(dir, relPath), content, 0o644)) + + src, err := weightsource.NewFileSource("file://"+dir, "") + require.NoError(t, err) + inv, err := src.Inventory(t.Context()) + require.NoError(t, err) + require.Len(t, inv.Files, 1) + + st, err := store.NewFileStore(t.TempDir()) + require.NoError(t, err) + + var events []WeightBuildProgress + err = ingressFromInventoryWithProgress(t.Context(), "test-weight", sourceOwners(src, inv), st, inv, func(event WeightBuildProgress) { + events = append(events, event) + }) + require.NoError(t, err) + require.NotEmpty(t, events) + + first := events[0] + assert.Equal(t, "test-weight", first.WeightName) + assert.Equal(t, relPath, first.FilePath) + assert.Equal(t, int64(0), first.Complete) + assert.Equal(t, int64(len(content)), first.Total) + assert.False(t, first.Done) + + last := events[len(events)-1] + assert.Equal(t, "test-weight", last.WeightName) + assert.Equal(t, relPath, last.FilePath) + assert.Equal(t, int64(len(content)), last.Complete) + assert.Equal(t, int64(len(content)), last.Total) + assert.True(t, last.Done) +} + func TestPack_SingleLargeFile_Incompressible(t *testing.T) { dir := t.TempDir() createTestFile(t, dir, "model.safetensors", 100*1024*1024) // 100 MB diff --git a/pkg/model/weight_builder.go b/pkg/model/weight_builder.go index 28d2189d3d..5cebb0a05b 100644 --- a/pkg/model/weight_builder.go +++ b/pkg/model/weight_builder.go @@ -137,9 +137,10 @@ func (b *WeightBuilder) resolveInventory(ctx context.Context, ws *WeightSpec) (* // digest it writes into the artifact descriptor is a sha256 of the // serialized manifest bytes. type WeightBuilder struct { - source *Source - store store.Store - lockPath string + source *Source + store store.Store + lockPath string + progressFn func(WeightBuildProgress) } // NewWeightBuilder creates a WeightBuilder. @@ -150,6 +151,11 @@ func NewWeightBuilder(source *Source, st store.Store, lockPath string) *WeightBu return &WeightBuilder{source: source, store: st, lockPath: lockPath} } +// SetProgressFn sets an optional callback for import-time file fetch progress. +func (b *WeightBuilder) SetProgressFn(fn func(WeightBuildProgress)) { + b.progressFn = fn +} + // Build runs the full import pipeline for one weight: // // 1. Inventory the source. @@ -200,7 +206,7 @@ func (b *WeightBuilder) buildWithResolved(ctx context.Context, spec ArtifactSpec inv := weightsource.Inventory{Files: resolved.mergedFiles} // Step 2: ingress the filtered files into the local store. - if err := ingressFromInventory(ctx, resolved.owners, b.store, inv); err != nil { + if err := ingressFromInventoryWithProgress(ctx, ws.Name(), resolved.owners, b.store, inv, b.progressFn); err != nil { return nil, fmt.Errorf("populate store for weight %q: %w", ws.Name(), err) } diff --git a/pkg/weights/pull.go b/pkg/weights/pull.go index fb71ba1357..8ad6d416f9 100644 --- a/pkg/weights/pull.go +++ b/pkg/weights/pull.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "io" + "time" v1 "github.com/google/go-containerregistry/pkg/v1" @@ -53,6 +54,10 @@ type PullEvent struct { FileDigest string FileSize int64 + // FileProgress: per-file byte progress while streaming a file into + // the local store. + FileComplete int64 + // WeightDone: cumulative totals for the weight. FullyCached is // true when no registry I/O happened. BytesFetched int64 @@ -69,6 +74,7 @@ const ( PullEventUnknown PullEventKind = iota PullEventWeightStart PullEventLayerStart + PullEventFileProgress PullEventFileStored PullEventLayerDone PullEventWeightDone @@ -248,7 +254,32 @@ func (m *Manager) pullLayer( return fmt.Errorf("layer %s: unexpected file %q not in lockfile", layerDigest, hdr.Name) } - if err := m.store.PutFile(ctx, file.Digest, file.Size, tr); err != nil { + emit(PullEvent{ + Kind: PullEventFileProgress, + Weight: weightName, + LayerDigest: layerDigest, + FilePath: file.Path, + FileDigest: file.Digest, + FileSize: file.Size, + FileComplete: 0, + }) + + reader := &pullProgressReader{ + r: tr, + fn: func(complete int64) { + emit(PullEvent{ + Kind: PullEventFileProgress, + Weight: weightName, + LayerDigest: layerDigest, + FilePath: file.Path, + FileDigest: file.Digest, + FileSize: file.Size, + FileComplete: complete, + }) + }, + } + + if err := m.store.PutFile(ctx, file.Digest, file.Size, reader); err != nil { return fmt.Errorf("store %s (%s): %w", file.Path, file.Digest, err) } written[file.Path] = true @@ -271,3 +302,30 @@ func (m *Manager) pullLayer( emit(PullEvent{Kind: PullEventLayerDone, Weight: weightName, LayerDigest: layerDigest}) return nil } + +type pullProgressReader struct { + r io.Reader + complete int64 + lastReported int64 + lastUpdate time.Time + interval time.Duration + fn func(int64) +} + +func (r *pullProgressReader) Read(p []byte) (int, error) { + n, err := r.r.Read(p) + if n > 0 { + r.complete += int64(n) + now := time.Now() + interval := r.interval + if interval == 0 { + interval = 250 * time.Millisecond + } + if r.lastReported == 0 || now.Sub(r.lastUpdate) >= interval { + r.lastReported = r.complete + r.lastUpdate = now + r.fn(r.complete) + } + } + return n, err +} diff --git a/pkg/weights/pull_test.go b/pkg/weights/pull_test.go index ee8ee2ac77..8ba456238a 100644 --- a/pkg/weights/pull_test.go +++ b/pkg/weights/pull_test.go @@ -10,6 +10,7 @@ import ( "fmt" "io" "testing" + "time" v1 "github.com/google/go-containerregistry/pkg/v1" "github.com/google/go-containerregistry/pkg/v1/empty" @@ -572,21 +573,10 @@ func TestManager_Pull_EmitsEvents(t *testing.T) { _, err := mgr.Pull(ctx, nil, func(e PullEvent) { events = append(events, e) }) require.NoError(t, err) - // Expected sequence for a single weight with one layer of two - // files: WeightStart, LayerStart, FileStored x2, LayerDone, - // WeightDone. - kinds := make([]PullEventKind, len(events)) - for i, e := range events { - kinds[i] = e.Kind - } - require.Equal(t, []PullEventKind{ - PullEventWeightStart, - PullEventLayerStart, - PullEventFileStored, - PullEventFileStored, - PullEventLayerDone, - PullEventWeightDone, - }, kinds) + require.Equal(t, PullEventWeightStart, events[0].Kind) + require.Equal(t, PullEventLayerStart, events[1].Kind) + require.Equal(t, PullEventLayerDone, events[len(events)-2].Kind) + require.Equal(t, PullEventWeightDone, events[len(events)-1].Kind) // WeightStart carries the manifest reference and file counts. start := events[0] @@ -596,10 +586,29 @@ func TestManager_Pull_EmitsEvents(t *testing.T) { assert.Equal(t, 2, start.MissingFiles) assert.Equal(t, testRepo+"@"+entry.Digest, start.ManifestRef) - // FileStored events carry path + digest. - for _, e := range events[2:4] { + var progressEvents, storedEvents []PullEvent + for _, e := range events { + switch e.Kind { + case PullEventFileProgress: + progressEvents = append(progressEvents, e) + case PullEventFileStored: + storedEvents = append(storedEvents, e) + } + } + require.NotEmpty(t, progressEvents) + require.Len(t, storedEvents, 2) + + // File progress + stored events carry path, digest, and byte counts. + for _, e := range progressEvents { assert.NotEmpty(t, e.FilePath) assert.NotEmpty(t, e.FileDigest) + assert.GreaterOrEqual(t, e.FileComplete, int64(0)) + assert.LessOrEqual(t, e.FileComplete, e.FileSize) + } + for _, e := range storedEvents { + assert.NotEmpty(t, e.FilePath) + assert.NotEmpty(t, e.FileDigest) + assert.Greater(t, e.FileSize, int64(0)) } } @@ -634,6 +643,29 @@ func TestManager_Pull_EmitsFullyCachedEvent(t *testing.T) { assert.True(t, events[1].FullyCached) } +func TestPullProgressReaderThrottlesEvents(t *testing.T) { + t.Parallel() + + var events []int64 + r := &pullProgressReader{ + r: bytes.NewReader([]byte("abcdef")), + interval: time.Hour, + fn: func(complete int64) { + events = append(events, complete) + }, + } + + buf := make([]byte, 2) + n, err := r.Read(buf) + require.NoError(t, err) + assert.Equal(t, 2, n) + n, err = r.Read(buf) + require.NoError(t, err) + assert.Equal(t, 2, n) + + assert.Equal(t, []int64{2}, events) +} + func TestNewManager_RequiresStore(t *testing.T) { t.Parallel() _, err := NewManager(ManagerOptions{