Skip to content

Commit 4ef52cb

Browse files
committed
Session tool approval toggle
Add an endpoint to toggle the automatic tool aproval of a session. Got carried away and extracted all the real logic away form the server into a session manager... Signed-off-by: Djordje Lukic <djordje.lukic@docker.com>
1 parent 10ce3f6 commit 4ef52cb

2 files changed

Lines changed: 212 additions & 126 deletions

File tree

pkg/server/server.go

Lines changed: 24 additions & 110 deletions
Original file line numberDiff line numberDiff line change
@@ -7,29 +7,20 @@ import (
77
"log/slog"
88
"net"
99
"net/http"
10-
"os"
11-
"path/filepath"
1210
"sort"
13-
"strings"
1411
"time"
1512

1613
"github.com/labstack/echo/v4"
1714
"github.com/labstack/echo/v4/middleware"
1815

1916
"github.com/docker/cagent/pkg/api"
20-
"github.com/docker/cagent/pkg/concurrent"
2117
"github.com/docker/cagent/pkg/config"
22-
"github.com/docker/cagent/pkg/runtime"
2318
"github.com/docker/cagent/pkg/session"
24-
"github.com/docker/cagent/pkg/tools"
2519
)
2620

2721
type Server struct {
28-
e *echo.Echo
29-
runtimeCancels *concurrent.Map[string, context.CancelFunc]
30-
sessionStore session.Store
31-
runConfig *config.RuntimeConfig
32-
sm *sessionManager
22+
e *echo.Echo
23+
sm *sessionManager
3324
}
3425

3526
func New(ctx context.Context, sessionStore session.Store, runConfig *config.RuntimeConfig, refreshInterval time.Duration, agentSources config.Sources) (*Server, error) {
@@ -38,11 +29,8 @@ func New(ctx context.Context, sessionStore session.Store, runConfig *config.Runt
3829
e.Use(middleware.Logger())
3930

4031
s := &Server{
41-
e: e,
42-
runtimeCancels: concurrent.NewMap[string, context.CancelFunc](),
43-
sessionStore: sessionStore,
44-
runConfig: runConfig,
45-
sm: newSessionManager(ctx, agentSources, refreshInterval),
32+
e: e,
33+
sm: newSessionManager(ctx, agentSources, sessionStore, refreshInterval, runConfig),
4634
}
4735

4836
group := e.Group("/api")
@@ -56,6 +44,8 @@ func New(ctx context.Context, sessionStore session.Store, runConfig *config.Runt
5644
group.GET("/sessions/:id", s.getSession)
5745
// Resume a session by id
5846
group.POST("/sessions/:id/resume", s.resumeSession)
47+
// Toggle YOLO mode for a session
48+
group.POST("/sessions/:id/tools/toggle", s.toggleSessionYolo)
5949
// Create a new session
6050
group.POST("/sessions", s.createSession)
6151
// Delete a session
@@ -125,7 +115,6 @@ func (s *Server) getAgents(c echo.Context) error {
125115
}
126116
}
127117

128-
// Sort agents by name
129118
sort.Slice(agents, func(i, j int) bool {
130119
return agents[i].Name < agents[j].Name
131120
})
@@ -134,7 +123,7 @@ func (s *Server) getAgents(c echo.Context) error {
134123
}
135124

136125
func (s *Server) getSessions(c echo.Context) error {
137-
sessions, err := s.sessionStore.GetSessions(c.Request().Context())
126+
sessions, err := s.sm.GetSessions(c.Request().Context())
138127
if err != nil {
139128
return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("failed to get sessions: %v", err))
140129
}
@@ -160,42 +149,16 @@ func (s *Server) createSession(c echo.Context) error {
160149
return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("invalid request body: %v", err))
161150
}
162151

163-
var opts []session.Opt
164-
opts = append(opts,
165-
session.WithMaxIterations(sessionTemplate.MaxIterations),
166-
session.WithToolsApproved(sessionTemplate.ToolsApproved),
167-
)
168-
169-
if wd := strings.TrimSpace(sessionTemplate.WorkingDir); wd != "" {
170-
absWd, err := filepath.Abs(wd)
171-
if err != nil {
172-
slog.Error("Invalid working directory", "error", err)
173-
return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("invalid working directory: %v", err))
174-
}
175-
info, err := os.Stat(absWd)
176-
if err != nil {
177-
slog.Error("Working directory not accessible", "error", err)
178-
return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("working directory not accessible: %v", err))
179-
}
180-
if !info.IsDir() {
181-
slog.Error("Working directory is not a directory")
182-
return echo.NewHTTPError(http.StatusBadRequest, "working directory must be a directory")
183-
}
184-
opts = append(opts, session.WithWorkingDir(absWd))
185-
}
186-
187-
sess := session.New(opts...)
188-
189-
if err := s.sessionStore.AddSession(c.Request().Context(), sess); err != nil {
190-
slog.Error("Failed to persist session", "session_id", sess.ID, "error", err)
152+
sess, err := s.sm.CreateSession(c.Request().Context(), &sessionTemplate)
153+
if err != nil {
191154
return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("failed to create session: %v", err))
192155
}
193156

194157
return c.JSON(http.StatusOK, sess)
195158
}
196159

197160
func (s *Server) getSession(c echo.Context) error {
198-
sess, err := s.sessionStore.GetSession(c.Request().Context(), c.Param("id"))
161+
sess, err := s.sm.GetSession(c.Request().Context(), c.Param("id"))
199162
if err != nil {
200163
return echo.NewHTTPError(http.StatusNotFound, fmt.Sprintf("session not found: %v", err))
201164
}
@@ -215,41 +178,29 @@ func (s *Server) getSession(c echo.Context) error {
215178
}
216179

217180
func (s *Server) resumeSession(c echo.Context) error {
218-
sessionID := c.Param("id")
219181
var req api.ResumeSessionRequest
220182
if err := c.Bind(&req); err != nil {
221183
return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("invalid request body: %v", err))
222184
}
223185

224-
rt, exists := s.sm.runtimes.Load(sessionID)
225-
if !exists {
226-
return echo.NewHTTPError(http.StatusNotFound, fmt.Sprintf("runtime not found: %s", sessionID))
186+
if err := s.sm.ResumeSession(c.Request().Context(), c.Param("id"), req.Confirmation); err != nil {
187+
return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("failed to resume session: %v", err))
227188
}
228189

229-
rt.Resume(c.Request().Context(), runtime.ResumeType(req.Confirmation))
230-
231190
return c.JSON(http.StatusOK, map[string]string{"message": "session resumed"})
232191
}
233192

234-
func (s *Server) deleteSession(c echo.Context) error {
235-
sessionID := c.Param("id")
236-
237-
// Cancel the runtime context if it's still running
238-
if cancel, exists := s.runtimeCancels.Load(sessionID); exists {
239-
slog.Debug("Cancelling runtime for session", "session_id", sessionID)
240-
cancel()
241-
s.runtimeCancels.Delete(sessionID)
193+
func (s *Server) toggleSessionYolo(c echo.Context) error {
194+
if err := s.sm.ToggleToolApproval(c.Request().Context(), c.Param("id")); err != nil {
195+
return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("failed to toggle session tool approval mode: %v", err))
242196
}
197+
return c.JSON(http.StatusOK, nil)
198+
}
243199

244-
// Clean up the runtime
245-
if _, exists := s.sm.runtimes.Load(sessionID); exists {
246-
slog.Debug("Removing runtime for session", "session_id", sessionID)
247-
s.sm.runtimes.Delete(sessionID)
248-
}
200+
func (s *Server) deleteSession(c echo.Context) error {
201+
sessionID := c.Param("id")
249202

250-
// Delete the session from storage
251-
if err := s.sessionStore.DeleteSession(c.Request().Context(), sessionID); err != nil {
252-
slog.Error("Failed to delete session", "session_id", sessionID, "error", err)
203+
if err := s.sm.DeleteSession(c.Request().Context(), sessionID); err != nil {
253204
return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("failed to delete session: %v", err))
254205
}
255206

@@ -266,48 +217,20 @@ func (s *Server) runAgent(c echo.Context) error {
266217

267218
slog.Debug("Running agent", "agent_filename", agentFilename, "session_id", sessionID, "current_agent", currentAgent)
268219

269-
// Build a per-session team so Filesystem tool can be bound to session working dir
270-
sess, err := s.sessionStore.GetSession(c.Request().Context(), sessionID)
271-
if err != nil {
272-
return echo.NewHTTPError(http.StatusNotFound, fmt.Sprintf("session not found: %v", err))
273-
}
274-
275-
// Copy runConfig and inject per-session working dir override
276-
rc := s.runConfig.Clone()
277-
rc.WorkingDir = sess.WorkingDir
278-
279-
rt, err := s.sm.runtimeForSession(c.Request().Context(), sess, agentFilename, currentAgent, rc)
280-
if err != nil {
281-
return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("failed to get runtime for session: %v", err))
282-
}
283-
284-
// Receive messages from the API client
285220
var messages []api.Message
286221
if err := json.NewDecoder(c.Request().Body).Decode(&messages); err != nil {
287222
return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("invalid request body: %v", err))
288223
}
289224

290-
for _, msg := range messages {
291-
sess.AddMessage(session.UserMessage(msg.Content, msg.MultiContent...))
292-
}
293-
294-
if err := s.sessionStore.UpdateSession(c.Request().Context(), sess); err != nil {
295-
slog.Error("Failed to update session in store", "session_id", sess.ID, "error", err)
296-
return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("failed to update session: %v", err))
225+
streamChan, err := s.sm.RunSession(c.Request().Context(), sessionID, agentFilename, currentAgent, messages)
226+
if err != nil {
227+
return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("failed to run session: %v", err))
297228
}
298229

299230
c.Response().Header().Set("Content-Type", "text/event-stream")
300231
c.Response().Header().Set("Cache-Control", "no-cache")
301232
c.Response().Header().Set("Connection", "keep-alive")
302233
c.Response().WriteHeader(http.StatusOK)
303-
304-
streamCtx, cancel := context.WithCancel(c.Request().Context())
305-
s.runtimeCancels.Store(sess.ID, cancel)
306-
defer func() {
307-
s.runtimeCancels.Delete(sess.ID)
308-
}()
309-
310-
streamChan := rt.RunStream(streamCtx, sess)
311234
for event := range streamChan {
312235
data, err := json.Marshal(event)
313236
if err != nil {
@@ -317,10 +240,6 @@ func (s *Server) runAgent(c echo.Context) error {
317240
c.Response().Flush()
318241
}
319242

320-
if err := s.sessionStore.UpdateSession(c.Request().Context(), sess); err != nil {
321-
slog.Error("Failed to final update session in store", "session_id", sess.ID, "error", err)
322-
}
323-
324243
return nil
325244
}
326245

@@ -331,12 +250,7 @@ func (s *Server) elicitation(c echo.Context) error {
331250
return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("invalid request body: %v", err))
332251
}
333252

334-
rt, exists := s.sm.runtimes.Load(sessionID)
335-
if !exists {
336-
return c.JSON(http.StatusNotFound, map[string]string{"error": fmt.Sprintf("runtime not found: %s", sessionID)})
337-
}
338-
339-
if err := rt.ResumeElicitation(c.Request().Context(), tools.ElicitationAction(req.Action), req.Content); err != nil {
253+
if err := s.sm.ResumeElicitation(c.Request().Context(), sessionID, req.Action, req.Content); err != nil {
340254
return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("failed to resume elicitation: %v", err))
341255
}
342256

0 commit comments

Comments
 (0)