Skip to content

Commit 51a4255

Browse files
authored
Merge pull request #1056 from rumpl/runtime-session-store
Add the session store to the runtime
2 parents 783768f + 02daf90 commit 51a4255

6 files changed

Lines changed: 118 additions & 27 deletions

File tree

cmd/root/run.go

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,15 @@ import (
66
"io"
77
"log/slog"
88
"os"
9+
"path/filepath"
910

1011
"github.com/mattn/go-isatty"
1112
"github.com/spf13/cobra"
1213
"go.opentelemetry.io/otel"
1314

1415
"github.com/docker/cagent/pkg/cli"
1516
"github.com/docker/cagent/pkg/config"
17+
"github.com/docker/cagent/pkg/paths"
1618
"github.com/docker/cagent/pkg/runtime"
1719
"github.com/docker/cagent/pkg/session"
1820
"github.com/docker/cagent/pkg/team"
@@ -28,6 +30,7 @@ type runExecFlags struct {
2830
modelOverrides []string
2931
dryRun bool
3032
runConfig config.RuntimeConfig
33+
sessionDB string
3134

3235
// Exec only
3336
hideToolCalls bool
@@ -64,6 +67,7 @@ func addRunOrExecFlags(cmd *cobra.Command, flags *runExecFlags) {
6467
cmd.PersistentFlags().StringArrayVar(&flags.modelOverrides, "model", nil, "Override agent model: [agent=]provider/model (repeatable)")
6568
cmd.PersistentFlags().BoolVar(&flags.dryRun, "dry-run", false, "Initialize the agent without executing anything")
6669
cmd.PersistentFlags().StringVar(&flags.remoteAddress, "remote", "", "Use remote runtime with specified address")
70+
cmd.PersistentFlags().StringVarP(&flags.sessionDB, "session-db", "s", filepath.Join(paths.GetHomeDir(), ".cagent", "session.db"), "Path to the session database")
6771
}
6872

6973
func (f *runExecFlags) runRunCommand(cmd *cobra.Command, args []string) error {
@@ -105,7 +109,7 @@ func (f *runExecFlags) runOrExec(ctx context.Context, out *cli.Printer, args []s
105109
return err
106110
}
107111

108-
rt, sess, err = f.createLocalRuntimeAndSession(t)
112+
rt, sess, err = f.createLocalRuntimeAndSession(ctx, t)
109113
if err != nil {
110114
return err
111115
}
@@ -166,26 +170,35 @@ func (f *runExecFlags) createRemoteRuntimeAndSession(ctx context.Context, origin
166170
return remoteRt, sess, nil
167171
}
168172

169-
func (f *runExecFlags) createLocalRuntimeAndSession(t *team.Team) (runtime.Runtime, *session.Session, error) {
173+
func (f *runExecFlags) createLocalRuntimeAndSession(ctx context.Context, t *team.Team) (runtime.Runtime, *session.Session, error) {
170174
agent, err := t.Agent(f.agentName)
171175
if err != nil {
172176
return nil, nil, err
173177
}
174178

175-
sess := session.New(
176-
session.WithMaxIterations(agent.MaxIterations()),
177-
session.WithToolsApproved(f.autoApprove),
178-
)
179+
sessStore, err := session.NewSQLiteSessionStore(f.sessionDB)
180+
if err != nil {
181+
return nil, nil, fmt.Errorf("failed to create session store: %w", err)
182+
}
179183

180184
localRt, err := runtime.New(t,
185+
runtime.WithSessionStore(sessStore),
181186
runtime.WithCurrentAgent(f.agentName),
182187
runtime.WithTracer(otel.Tracer(AppName)),
183-
runtime.WithRootSessionID(sess.ID),
184188
)
185189
if err != nil {
186190
return nil, nil, fmt.Errorf("failed to create runtime: %w", err)
187191
}
188192

193+
sess := session.New(
194+
session.WithMaxIterations(agent.MaxIterations()),
195+
session.WithToolsApproved(f.autoApprove),
196+
)
197+
198+
if err := sessStore.AddSession(ctx, sess); err != nil {
199+
return nil, nil, err
200+
}
201+
189202
slog.Debug("Using local runtime", "agent", f.agentName)
190203
return localRt, sess, nil
191204
}

e2e/cagent_exec_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ func cagentExec(t *testing.T, moreArgs ...string) string {
104104

105105
// Start a recording AI proxy to record and replay traffic.
106106
svr, _ := startRecordingAIProxy(t)
107-
args = append(args, "--models-gateway", svr.URL)
107+
args = append(args, "--models-gateway", svr.URL, "--session-db", "/tmp/session.db")
108108

109109
// Run cagent exec
110110
var stdout bytes.Buffer

pkg/mcp/server.go

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,6 @@ func CreateToolHandler(t *team.Team, agentName string) func(context.Context, *mc
167167

168168
rt, err := runtime.New(t,
169169
runtime.WithCurrentAgent(agentName),
170-
runtime.WithRootSessionID(sess.ID),
171170
)
172171
if err != nil {
173172
return nil, ToolOutput{}, fmt.Errorf("failed to create runtime: %w", err)

pkg/runtime/runtime.go

Lines changed: 31 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,10 @@ import (
3232
mcptools "github.com/docker/cagent/pkg/tools/mcp"
3333
)
3434

35+
type SessionStore interface {
36+
UpdateSession(ctx context.Context, sess *session.Session) error
37+
}
38+
3539
// UnwrapMCPToolset extracts an MCP toolset from a potentially wrapped StartableToolSet.
3640
// Returns the MCP toolset if found, or nil if the toolset is not an MCP toolset.
3741
func UnwrapMCPToolset(toolset tools.ToolSet) *mcptools.Toolset {
@@ -94,16 +98,18 @@ type Runtime interface {
9498
CurrentWelcomeMessage(ctx context.Context) string
9599
// EmitStartupInfo emits initial agent, team, and toolset information for immediate display
96100
EmitStartupInfo(ctx context.Context, events chan Event)
101+
97102
// RunStream starts the agent's interaction loop and returns a channel of events
98103
RunStream(ctx context.Context, sess *session.Session) <-chan Event
99104
// Run starts the agent's interaction loop and returns the final messages
100105
Run(ctx context.Context, sess *session.Session) ([]session.Message, error)
101106
// Resume allows resuming execution after user confirmation
102107
Resume(ctx context.Context, confirmationType ResumeType)
103-
// Summarize generates a summary for the session
104-
Summarize(ctx context.Context, sess *session.Session, events chan Event)
105108
// ResumeElicitation sends an elicitation response back to a waiting elicitation request
106109
ResumeElicitation(_ context.Context, action tools.ElicitationAction, content map[string]any) error
110+
111+
// Summarize generates a summary for the session
112+
Summarize(ctx context.Context, sess *session.Session, events chan Event)
107113
}
108114

109115
type ModelStore interface {
@@ -121,7 +127,6 @@ type LocalRuntime struct {
121127
toolMap map[string]ToolHandler
122128
team *team.Team
123129
currentAgent string
124-
rootSessionID string // Root session ID for OAuth state encoding (preserved across sub-sessions)
125130
resumeChan chan ResumeType
126131
tracer trace.Tracer
127132
modelsStore ModelStore
@@ -133,6 +138,7 @@ type LocalRuntime struct {
133138
elicitationEventsChannelMux sync.RWMutex // Protects elicitationEventsChannel
134139
ragInitialized atomic.Bool
135140
titleGen *titleGenerator
141+
sessionStore SessionStore
136142
}
137143

138144
type streamResult struct {
@@ -158,12 +164,6 @@ func WithManagedOAuth(managed bool) Opt {
158164
}
159165
}
160166

161-
func WithRootSessionID(sessionID string) Opt {
162-
return func(r *LocalRuntime) {
163-
r.rootSessionID = sessionID
164-
}
165-
}
166-
167167
// WithTracer sets a custom OpenTelemetry tracer; if not provided, tracing is disabled (no-op).
168168
func WithTracer(t trace.Tracer) Opt {
169169
return func(r *LocalRuntime) {
@@ -183,6 +183,12 @@ func WithModelStore(store ModelStore) Opt {
183183
}
184184
}
185185

186+
func WithSessionStore(store SessionStore) Opt {
187+
return func(r *LocalRuntime) {
188+
r.sessionStore = store
189+
}
190+
}
191+
186192
// New creates a new runtime for an agent and its team
187193
func New(agents *team.Team, opts ...Opt) (*LocalRuntime, error) {
188194
modelsStore, err := modelsdev.NewStore()
@@ -199,6 +205,7 @@ func New(agents *team.Team, opts ...Opt) (*LocalRuntime, error) {
199205
modelsStore: modelsStore,
200206
sessionCompaction: true,
201207
managedOAuth: true,
208+
sessionStore: session.NewInMemorySessionStore(),
202209
}
203210

204211
for _, opt := range opts {
@@ -610,6 +617,7 @@ func (r *LocalRuntime) RunStream(ctx context.Context, sess *session.Session) <-c
610617
CreatedAt: time.Now().Format(time.RFC3339),
611618
}
612619
sess.AddMessage(session.NewAgentMessage(a, &assistantMessage))
620+
_ = r.sessionStore.UpdateSession(ctx, sess)
613621
return
614622
}
615623
case <-ctx.Done():
@@ -707,6 +715,7 @@ func (r *LocalRuntime) RunStream(ctx context.Context, sess *session.Session) <-c
707715
}
708716

709717
sess.AddMessage(session.NewAgentMessage(a, &assistantMessage))
718+
_ = r.sessionStore.UpdateSession(ctx, sess)
710719
slog.Debug("Added assistant message to session", "agent", a.Name(), "total_messages", len(sess.GetAllMessages()))
711720
} else {
712721
slog.Debug("Skipping empty assistant message (no content and no tool calls)", "agent", a.Name())
@@ -1023,14 +1032,14 @@ func (r *LocalRuntime) processToolCalls(ctx context.Context, sess *session.Sessi
10231032
r.runAgentTool(callCtx, def.handler, sess, toolCall, def.tool, events, a)
10241033
case ResumeTypeReject:
10251034
slog.Debug("Resume signal received, rejecting tool handler", "tool", toolCall.Function.Name, "session_id", sess.ID)
1026-
r.addToolRejectedResponse(sess, toolCall, def.tool, events)
1035+
r.addToolRejectedResponse(ctx, sess, toolCall, def.tool, events)
10271036
}
10281037
case <-callCtx.Done():
10291038
slog.Debug("Context cancelled while waiting for resume", "tool", toolCall.Function.Name, "session_id", sess.ID)
10301039
// Synthesize cancellation responses for the current and any remaining tool calls
1031-
r.addToolCancelledResponse(sess, toolCall, def.tool, events)
1040+
r.addToolCancelledResponse(ctx, sess, toolCall, def.tool, events)
10321041
for j := i + 1; j < len(calls); j++ {
1033-
r.addToolCancelledResponse(sess, calls[j], def.tool, events)
1042+
r.addToolCancelledResponse(ctx, sess, calls[j], def.tool, events)
10341043
}
10351044
callSpan.SetStatus(codes.Ok, "tool call canceled by user")
10361045
return
@@ -1066,17 +1075,17 @@ func (r *LocalRuntime) processToolCalls(ctx context.Context, sess *session.Sessi
10661075
r.runTool(callCtx, tool, toolCall, events, sess, a)
10671076
case ResumeTypeReject:
10681077
slog.Debug("Resume signal received, rejecting tool handler", "tool", toolCall.Function.Name, "session_id", sess.ID)
1069-
r.addToolRejectedResponse(sess, toolCall, tool, events)
1078+
r.addToolRejectedResponse(ctx, sess, toolCall, tool, events)
10701079
}
10711080

10721081
slog.Debug("Added tool response to session", "tool", toolCall.Function.Name, "session_id", sess.ID, "total_messages", len(sess.GetAllMessages()))
10731082
break toolLoop
10741083
case <-callCtx.Done():
10751084
slog.Debug("Context cancelled while waiting for resume", "tool", toolCall.Function.Name, "session_id", sess.ID)
10761085
// Synthesize cancellation responses for the current and any remaining tool calls
1077-
r.addToolCancelledResponse(sess, toolCall, tool, events)
1086+
r.addToolCancelledResponse(ctx, sess, toolCall, tool, events)
10781087
for j := i + 1; j < len(calls); j++ {
1079-
r.addToolCancelledResponse(sess, calls[j], tool, events)
1088+
r.addToolCancelledResponse(ctx, sess, calls[j], tool, events)
10801089
}
10811090
callSpan.SetStatus(codes.Ok, "tool call canceled by user")
10821091
return
@@ -1146,6 +1155,7 @@ func (r *LocalRuntime) runTool(ctx context.Context, tool tools.Tool, toolCall to
11461155
CreatedAt: time.Now().Format(time.RFC3339),
11471156
}
11481157
sess.AddMessage(session.NewAgentMessage(a, &toolResponseMsg))
1158+
_ = r.sessionStore.UpdateSession(ctx, sess)
11491159
}
11501160

11511161
func (r *LocalRuntime) runAgentTool(ctx context.Context, handler ToolHandlerFunc, sess *session.Session, toolCall tools.ToolCall, tool tools.Tool, events chan Event, a *agent.Agent) {
@@ -1199,9 +1209,10 @@ func (r *LocalRuntime) runAgentTool(ctx context.Context, handler ToolHandlerFunc
11991209
CreatedAt: time.Now().Format(time.RFC3339),
12001210
}
12011211
sess.AddMessage(session.NewAgentMessage(a, &toolResponseMsg))
1212+
_ = r.sessionStore.UpdateSession(ctx, sess)
12021213
}
12031214

1204-
func (r *LocalRuntime) addToolRejectedResponse(sess *session.Session, toolCall tools.ToolCall, tool tools.Tool, events chan Event) {
1215+
func (r *LocalRuntime) addToolRejectedResponse(ctx context.Context, sess *session.Session, toolCall tools.ToolCall, tool tools.Tool, events chan Event) {
12051216
a := r.CurrentAgent()
12061217

12071218
result := "The user rejected the tool call."
@@ -1215,9 +1226,10 @@ func (r *LocalRuntime) addToolRejectedResponse(sess *session.Session, toolCall t
12151226
CreatedAt: time.Now().Format(time.RFC3339),
12161227
}
12171228
sess.AddMessage(session.NewAgentMessage(a, &toolResponseMsg))
1229+
_ = r.sessionStore.UpdateSession(ctx, sess)
12181230
}
12191231

1220-
func (r *LocalRuntime) addToolCancelledResponse(sess *session.Session, toolCall tools.ToolCall, tool tools.Tool, events chan Event) {
1232+
func (r *LocalRuntime) addToolCancelledResponse(ctx context.Context, sess *session.Session, toolCall tools.ToolCall, tool tools.Tool, events chan Event) {
12211233
a := r.CurrentAgent()
12221234

12231235
result := "The tool call was canceled by the user."
@@ -1231,6 +1243,7 @@ func (r *LocalRuntime) addToolCancelledResponse(sess *session.Session, toolCall
12311243
CreatedAt: time.Now().Format(time.RFC3339),
12321244
}
12331245
sess.AddMessage(session.NewAgentMessage(a, &toolResponseMsg))
1246+
_ = r.sessionStore.UpdateSession(ctx, sess)
12341247
}
12351248

12361249
// startSpan wraps tracer.Start, returning a no-op span if the tracer is nil.
@@ -1418,6 +1431,7 @@ func (r *LocalRuntime) Summarize(ctx context.Context, sess *session.Session, eve
14181431
}
14191432
// Add the summary to the session as a summary item
14201433
sess.Messages = append(sess.Messages, session.Item{Summary: summary})
1434+
_ = r.sessionStore.UpdateSession(ctx, sess)
14211435
slog.Debug("Generated session summary", "session_id", sess.ID, "summary_length", len(summary))
14221436
events <- SessionSummary(sess.ID, summary, r.currentAgent)
14231437
}

pkg/server/session_manager.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,7 @@ func (sm *sessionManager) RunSession(ctx context.Context, sessionID, agentFilena
132132
for _, msg := range messages {
133133
sess.AddMessage(session.UserMessage(msg.Content, msg.MultiContent...))
134134
}
135+
135136
if err := sm.sessionStore.UpdateSession(ctx, sess); err != nil {
136137
return nil, err
137138
}
@@ -228,7 +229,7 @@ func (sm *sessionManager) runtimeForSession(ctx context.Context, sess *session.S
228229
opts := []runtime.Opt{
229230
runtime.WithCurrentAgent(currentAgent),
230231
runtime.WithManagedOAuth(false),
231-
runtime.WithRootSessionID(sess.ID),
232+
runtime.WithSessionStore(sm.sessionStore),
232233
}
233234
run, err := runtime.New(t, opts...)
234235
if err != nil {

pkg/session/store.go

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ import (
99
"time"
1010

1111
_ "modernc.org/sqlite"
12+
13+
"github.com/docker/cagent/pkg/concurrent"
1214
)
1315

1416
var (
@@ -34,6 +36,68 @@ type Store interface {
3436
UpdateSession(ctx context.Context, session *Session) error
3537
}
3638

39+
type InMemorySessionStore struct {
40+
sessions *concurrent.Map[string, *Session]
41+
}
42+
43+
func NewInMemorySessionStore() Store {
44+
return &InMemorySessionStore{
45+
sessions: concurrent.NewMap[string, *Session](),
46+
}
47+
}
48+
49+
func (s *InMemorySessionStore) AddSession(ctx context.Context, session *Session) error {
50+
if session.ID == "" {
51+
return ErrEmptyID
52+
}
53+
s.sessions.Store(session.ID, session)
54+
return nil
55+
}
56+
57+
func (s *InMemorySessionStore) GetSession(ctx context.Context, id string) (*Session, error) {
58+
if id == "" {
59+
return nil, ErrEmptyID
60+
}
61+
session, exists := s.sessions.Load(id)
62+
if !exists {
63+
return nil, ErrNotFound
64+
}
65+
return session, nil
66+
}
67+
68+
func (s *InMemorySessionStore) GetSessions(ctx context.Context) ([]*Session, error) {
69+
sessions := make([]*Session, 0, s.sessions.Length())
70+
s.sessions.Range(func(key string, value *Session) bool {
71+
sessions = append(sessions, value)
72+
return true
73+
})
74+
return sessions, nil
75+
}
76+
77+
func (s *InMemorySessionStore) DeleteSession(ctx context.Context, id string) error {
78+
if id == "" {
79+
return ErrEmptyID
80+
}
81+
_, exists := s.sessions.Load(id)
82+
if !exists {
83+
return ErrNotFound
84+
}
85+
s.sessions.Delete(id)
86+
return nil
87+
}
88+
89+
func (s *InMemorySessionStore) UpdateSession(ctx context.Context, session *Session) error {
90+
if session.ID == "" {
91+
return ErrEmptyID
92+
}
93+
_, exists := s.sessions.Load(session.ID)
94+
if !exists {
95+
return ErrNotFound
96+
}
97+
s.sessions.Store(session.ID, session)
98+
return nil
99+
}
100+
37101
// SQLiteSessionStore implements Store using SQLite
38102
type SQLiteSessionStore struct {
39103
db *sql.DB

0 commit comments

Comments
 (0)