@@ -7,29 +7,20 @@ import (
77 "log/slog"
88 "net"
99 "net/http"
10- "os"
11- "path/filepath"
1210 "sort"
13- "strings"
1411 "time"
1512
1613 "github.com/labstack/echo/v4"
1714 "github.com/labstack/echo/v4/middleware"
1815
1916 "github.com/docker/cagent/pkg/api"
20- "github.com/docker/cagent/pkg/concurrent"
2117 "github.com/docker/cagent/pkg/config"
22- "github.com/docker/cagent/pkg/runtime"
2318 "github.com/docker/cagent/pkg/session"
24- "github.com/docker/cagent/pkg/tools"
2519)
2620
2721type Server struct {
28- e * echo.Echo
29- runtimeCancels * concurrent.Map [string , context.CancelFunc ]
30- sessionStore session.Store
31- runConfig * config.RuntimeConfig
32- sm * sessionManager
22+ e * echo.Echo
23+ sm * sessionManager
3324}
3425
3526func New (ctx context.Context , sessionStore session.Store , runConfig * config.RuntimeConfig , refreshInterval time.Duration , agentSources config.Sources ) (* Server , error ) {
@@ -38,11 +29,8 @@ func New(ctx context.Context, sessionStore session.Store, runConfig *config.Runt
3829 e .Use (middleware .Logger ())
3930
4031 s := & Server {
41- e : e ,
42- runtimeCancels : concurrent .NewMap [string , context.CancelFunc ](),
43- sessionStore : sessionStore ,
44- runConfig : runConfig ,
45- sm : newSessionManager (ctx , agentSources , refreshInterval ),
32+ e : e ,
33+ sm : newSessionManager (ctx , agentSources , sessionStore , refreshInterval , runConfig ),
4634 }
4735
4836 group := e .Group ("/api" )
@@ -56,6 +44,8 @@ func New(ctx context.Context, sessionStore session.Store, runConfig *config.Runt
5644 group .GET ("/sessions/:id" , s .getSession )
5745 // Resume a session by id
5846 group .POST ("/sessions/:id/resume" , s .resumeSession )
47+ // Toggle YOLO mode for a session
48+ group .POST ("/sessions/:id/tools/toggle" , s .toggleSessionYolo )
5949 // Create a new session
6050 group .POST ("/sessions" , s .createSession )
6151 // Delete a session
@@ -125,7 +115,6 @@ func (s *Server) getAgents(c echo.Context) error {
125115 }
126116 }
127117
128- // Sort agents by name
129118 sort .Slice (agents , func (i , j int ) bool {
130119 return agents [i ].Name < agents [j ].Name
131120 })
@@ -134,7 +123,7 @@ func (s *Server) getAgents(c echo.Context) error {
134123}
135124
136125func (s * Server ) getSessions (c echo.Context ) error {
137- sessions , err := s .sessionStore .GetSessions (c .Request ().Context ())
126+ sessions , err := s .sm .GetSessions (c .Request ().Context ())
138127 if err != nil {
139128 return echo .NewHTTPError (http .StatusInternalServerError , fmt .Sprintf ("failed to get sessions: %v" , err ))
140129 }
@@ -160,42 +149,16 @@ func (s *Server) createSession(c echo.Context) error {
160149 return echo .NewHTTPError (http .StatusBadRequest , fmt .Sprintf ("invalid request body: %v" , err ))
161150 }
162151
163- var opts []session.Opt
164- opts = append (opts ,
165- session .WithMaxIterations (sessionTemplate .MaxIterations ),
166- session .WithToolsApproved (sessionTemplate .ToolsApproved ),
167- )
168-
169- if wd := strings .TrimSpace (sessionTemplate .WorkingDir ); wd != "" {
170- absWd , err := filepath .Abs (wd )
171- if err != nil {
172- slog .Error ("Invalid working directory" , "error" , err )
173- return echo .NewHTTPError (http .StatusBadRequest , fmt .Sprintf ("invalid working directory: %v" , err ))
174- }
175- info , err := os .Stat (absWd )
176- if err != nil {
177- slog .Error ("Working directory not accessible" , "error" , err )
178- return echo .NewHTTPError (http .StatusBadRequest , fmt .Sprintf ("working directory not accessible: %v" , err ))
179- }
180- if ! info .IsDir () {
181- slog .Error ("Working directory is not a directory" )
182- return echo .NewHTTPError (http .StatusBadRequest , "working directory must be a directory" )
183- }
184- opts = append (opts , session .WithWorkingDir (absWd ))
185- }
186-
187- sess := session .New (opts ... )
188-
189- if err := s .sessionStore .AddSession (c .Request ().Context (), sess ); err != nil {
190- slog .Error ("Failed to persist session" , "session_id" , sess .ID , "error" , err )
152+ sess , err := s .sm .CreateSession (c .Request ().Context (), & sessionTemplate )
153+ if err != nil {
191154 return echo .NewHTTPError (http .StatusInternalServerError , fmt .Sprintf ("failed to create session: %v" , err ))
192155 }
193156
194157 return c .JSON (http .StatusOK , sess )
195158}
196159
197160func (s * Server ) getSession (c echo.Context ) error {
198- sess , err := s .sessionStore .GetSession (c .Request ().Context (), c .Param ("id" ))
161+ sess , err := s .sm .GetSession (c .Request ().Context (), c .Param ("id" ))
199162 if err != nil {
200163 return echo .NewHTTPError (http .StatusNotFound , fmt .Sprintf ("session not found: %v" , err ))
201164 }
@@ -215,41 +178,29 @@ func (s *Server) getSession(c echo.Context) error {
215178}
216179
217180func (s * Server ) resumeSession (c echo.Context ) error {
218- sessionID := c .Param ("id" )
219181 var req api.ResumeSessionRequest
220182 if err := c .Bind (& req ); err != nil {
221183 return echo .NewHTTPError (http .StatusBadRequest , fmt .Sprintf ("invalid request body: %v" , err ))
222184 }
223185
224- rt , exists := s .sm .runtimes .Load (sessionID )
225- if ! exists {
226- return echo .NewHTTPError (http .StatusNotFound , fmt .Sprintf ("runtime not found: %s" , sessionID ))
186+ if err := s .sm .ResumeSession (c .Request ().Context (), c .Param ("id" ), req .Confirmation ); err != nil {
187+ return echo .NewHTTPError (http .StatusInternalServerError , fmt .Sprintf ("failed to resume session: %v" , err ))
227188 }
228189
229- rt .Resume (c .Request ().Context (), runtime .ResumeType (req .Confirmation ))
230-
231190 return c .JSON (http .StatusOK , map [string ]string {"message" : "session resumed" })
232191}
233192
234- func (s * Server ) deleteSession (c echo.Context ) error {
235- sessionID := c .Param ("id" )
236-
237- // Cancel the runtime context if it's still running
238- if cancel , exists := s .runtimeCancels .Load (sessionID ); exists {
239- slog .Debug ("Cancelling runtime for session" , "session_id" , sessionID )
240- cancel ()
241- s .runtimeCancels .Delete (sessionID )
193+ func (s * Server ) toggleSessionYolo (c echo.Context ) error {
194+ if err := s .sm .ToggleToolApproval (c .Request ().Context (), c .Param ("id" )); err != nil {
195+ return echo .NewHTTPError (http .StatusInternalServerError , fmt .Sprintf ("failed to toggle session tool approval mode: %v" , err ))
242196 }
197+ return c .JSON (http .StatusOK , nil )
198+ }
243199
244- // Clean up the runtime
245- if _ , exists := s .sm .runtimes .Load (sessionID ); exists {
246- slog .Debug ("Removing runtime for session" , "session_id" , sessionID )
247- s .sm .runtimes .Delete (sessionID )
248- }
200+ func (s * Server ) deleteSession (c echo.Context ) error {
201+ sessionID := c .Param ("id" )
249202
250- // Delete the session from storage
251- if err := s .sessionStore .DeleteSession (c .Request ().Context (), sessionID ); err != nil {
252- slog .Error ("Failed to delete session" , "session_id" , sessionID , "error" , err )
203+ if err := s .sm .DeleteSession (c .Request ().Context (), sessionID ); err != nil {
253204 return echo .NewHTTPError (http .StatusInternalServerError , fmt .Sprintf ("failed to delete session: %v" , err ))
254205 }
255206
@@ -266,48 +217,20 @@ func (s *Server) runAgent(c echo.Context) error {
266217
267218 slog .Debug ("Running agent" , "agent_filename" , agentFilename , "session_id" , sessionID , "current_agent" , currentAgent )
268219
269- // Build a per-session team so Filesystem tool can be bound to session working dir
270- sess , err := s .sessionStore .GetSession (c .Request ().Context (), sessionID )
271- if err != nil {
272- return echo .NewHTTPError (http .StatusNotFound , fmt .Sprintf ("session not found: %v" , err ))
273- }
274-
275- // Copy runConfig and inject per-session working dir override
276- rc := s .runConfig .Clone ()
277- rc .WorkingDir = sess .WorkingDir
278-
279- rt , err := s .sm .runtimeForSession (c .Request ().Context (), sess , agentFilename , currentAgent , rc )
280- if err != nil {
281- return echo .NewHTTPError (http .StatusInternalServerError , fmt .Sprintf ("failed to get runtime for session: %v" , err ))
282- }
283-
284- // Receive messages from the API client
285220 var messages []api.Message
286221 if err := json .NewDecoder (c .Request ().Body ).Decode (& messages ); err != nil {
287222 return echo .NewHTTPError (http .StatusBadRequest , fmt .Sprintf ("invalid request body: %v" , err ))
288223 }
289224
290- for _ , msg := range messages {
291- sess .AddMessage (session .UserMessage (msg .Content , msg .MultiContent ... ))
292- }
293-
294- if err := s .sessionStore .UpdateSession (c .Request ().Context (), sess ); err != nil {
295- slog .Error ("Failed to update session in store" , "session_id" , sess .ID , "error" , err )
296- return echo .NewHTTPError (http .StatusInternalServerError , fmt .Sprintf ("failed to update session: %v" , err ))
225+ streamChan , err := s .sm .RunSession (c .Request ().Context (), sessionID , agentFilename , currentAgent , messages )
226+ if err != nil {
227+ return echo .NewHTTPError (http .StatusInternalServerError , fmt .Sprintf ("failed to run session: %v" , err ))
297228 }
298229
299230 c .Response ().Header ().Set ("Content-Type" , "text/event-stream" )
300231 c .Response ().Header ().Set ("Cache-Control" , "no-cache" )
301232 c .Response ().Header ().Set ("Connection" , "keep-alive" )
302233 c .Response ().WriteHeader (http .StatusOK )
303-
304- streamCtx , cancel := context .WithCancel (c .Request ().Context ())
305- s .runtimeCancels .Store (sess .ID , cancel )
306- defer func () {
307- s .runtimeCancels .Delete (sess .ID )
308- }()
309-
310- streamChan := rt .RunStream (streamCtx , sess )
311234 for event := range streamChan {
312235 data , err := json .Marshal (event )
313236 if err != nil {
@@ -317,10 +240,6 @@ func (s *Server) runAgent(c echo.Context) error {
317240 c .Response ().Flush ()
318241 }
319242
320- if err := s .sessionStore .UpdateSession (c .Request ().Context (), sess ); err != nil {
321- slog .Error ("Failed to final update session in store" , "session_id" , sess .ID , "error" , err )
322- }
323-
324243 return nil
325244}
326245
@@ -331,12 +250,7 @@ func (s *Server) elicitation(c echo.Context) error {
331250 return echo .NewHTTPError (http .StatusBadRequest , fmt .Sprintf ("invalid request body: %v" , err ))
332251 }
333252
334- rt , exists := s .sm .runtimes .Load (sessionID )
335- if ! exists {
336- return c .JSON (http .StatusNotFound , map [string ]string {"error" : fmt .Sprintf ("runtime not found: %s" , sessionID )})
337- }
338-
339- if err := rt .ResumeElicitation (c .Request ().Context (), tools .ElicitationAction (req .Action ), req .Content ); err != nil {
253+ if err := s .sm .ResumeElicitation (c .Request ().Context (), sessionID , req .Action , req .Content ); err != nil {
340254 return echo .NewHTTPError (http .StatusInternalServerError , fmt .Sprintf ("failed to resume elicitation: %v" , err ))
341255 }
342256
0 commit comments