Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion pkg/modelsdev/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package modelsdev
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"log/slog"
Expand All @@ -23,6 +24,9 @@ const (
refreshInterval = 24 * time.Hour
)

// ErrProviderNotFound is returned when a requested provider is not found in the database.
// It is wrapped with the quoted provider ID appended to the message, so callers
// should detect it with errors.Is rather than by direct comparison.
var ErrProviderNotFound = errors.New("provider not found")

// Store manages access to the models.dev data.
// All methods are safe for concurrent use.
//
Expand Down Expand Up @@ -91,7 +95,7 @@ func (s *Store) getProvider(ctx context.Context, providerID string) (*Provider,

provider, exists := db.Providers[providerID]
if !exists {
return nil, fmt.Errorf("provider %q not found", providerID)
return nil, fmt.Errorf("%w: %q", ErrProviderNotFound, providerID)
}

return &provider, nil
Expand Down
37 changes: 31 additions & 6 deletions pkg/runtime/session_compaction.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"github.com/docker/docker-agent/pkg/compaction"
"github.com/docker/docker-agent/pkg/model/provider"
"github.com/docker/docker-agent/pkg/model/provider/options"
"github.com/docker/docker-agent/pkg/modelsdev"
"github.com/docker/docker-agent/pkg/session"
"github.com/docker/docker-agent/pkg/team"
)
Expand All @@ -26,7 +27,9 @@ const maxKeepTokens = 20_000
// persistence, token count updates). The agent is used to extract the
// conversation from the session and to obtain the model for summarization.
func (r *LocalRuntime) doCompact(ctx context.Context, sess *session.Session, a *agent.Agent, additionalPrompt string, events chan Event) {
slog.Debug("Generating summary for session", "session_id", sess.ID)
lg := slog.With("session_id", sess.ID, "agent", a.Name(), "action", "compaction")

lg.Debug("Generating summary for session")
events <- SessionCompaction(sess.ID, "started", a.Name())
defer func() {
events <- SessionCompaction(sess.ID, "completed", a.Name())
Expand All @@ -37,10 +40,32 @@ func (r *LocalRuntime) doCompact(ctx context.Context, sess *session.Session, a *
options.WithStructuredOutput(nil),
options.WithMaxTokens(maxSummaryTokens),
)

m, err := r.modelsStore.GetModel(ctx, summaryModel.ID())
if err != nil && errors.Is(err, modelsdev.ErrProviderNotFound) {
lg.Debug("Provider not found; attempting to find by model name", "error", err)

db, dberr := r.modelsStore.GetDatabase(ctx)
if dberr != nil {
lg.Error("Provider not found and failed to find by model name", "error", dberr)
events <- Error("Failed to get db to find model definition: " + dberr.Error())
return
}

// Find the lowest context limit for this model, regardless of the provider.
for _, provider := range db.Providers {
if v, ok := provider.Models[summaryModel.BaseConfig().ModelConfig.Model]; ok {
if m == nil || v.Limit.Context < m.Limit.Context {
m = &v
err = nil
}
}
}
}

if err != nil {
slog.Error("Failed to generate session summary", "error", errors.New("failed to get model definition"))
events <- Error("Failed to get model definition")
lg.Error("Failed to get model definition to generate session summary", "error", err)
events <- Error("Failed to get model definition: " + err.Error())
return
}

Expand All @@ -58,12 +83,12 @@ func (r *LocalRuntime) doCompact(ctx context.Context, sess *session.Session, a *
t := team.New(team.WithAgents(compactionAgent))
rt, err := New(t, WithSessionCompaction(false))
if err != nil {
slog.Error("Failed to generate session summary", "error", err)
lg.Error("Failed to generate session summary", "error", err)
events <- Error(err.Error())
return
}
if _, err = rt.Run(ctx, compactionSession); err != nil {
slog.Error("Failed to generate session summary", "error", err)
lg.Error("Failed to generate session summary", "error", err)
events <- Error(err.Error())
return
}
Expand All @@ -83,7 +108,7 @@ func (r *LocalRuntime) doCompact(ctx context.Context, sess *session.Session, a *
})
_ = r.sessionStore.UpdateSession(ctx, sess)

slog.Debug("Generated session summary", "session_id", sess.ID, "summary_length", len(summary))
lg.Debug("Generated session summary", "summary_length", len(summary))
events <- SessionSummary(sess.ID, summary, a.Name(), firstKeptEntry)
}

Expand Down