Skip to content

Commit fb6347f

Browse files
authored
Merge pull request #972 from rumpl/feat-server
Session manager for the server
2 parents 42bd35a + 9290482 commit fb6347f

6 files changed

Lines changed: 369 additions & 66 deletions

File tree

cmd/root/api.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"fmt"
55
"log/slog"
66
"os"
7+
"time"
78

89
"github.com/spf13/cobra"
910

@@ -77,12 +78,11 @@ func (f *apiFlags) runAPICommand(cmd *cobra.Command, args []string) error {
7778
if err != nil {
7879
return fmt.Errorf("failed to resolve agent sources: %w", err)
7980
}
80-
s, err := server.New(sessionStore, &f.runConfig, sources)
81+
82+
s, err := server.New(ctx, sessionStore, &f.runConfig, time.Duration(f.pullIntervalMins)*time.Minute, sources)
8183
if err != nil {
8284
return fmt.Errorf("failed to create server: %w", err)
8385
}
8486

85-
// TODO(rumpl): implement pull interval
86-
8787
return s.Serve(ctx, ln)
8888
}

pkg/server/server.go

Lines changed: 9 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -22,43 +22,35 @@ import (
2222
"github.com/docker/cagent/pkg/config"
2323
"github.com/docker/cagent/pkg/runtime"
2424
"github.com/docker/cagent/pkg/session"
25-
"github.com/docker/cagent/pkg/team"
26-
"github.com/docker/cagent/pkg/teamloader"
2725
"github.com/docker/cagent/pkg/tools"
2826
)
2927

3028
type Server struct {
3129
e *echo.Echo
32-
runtimes *concurrent.Map[string, runtime.Runtime]
3330
runtimeCancels *concurrent.Map[string, context.CancelFunc]
3431
sessionStore session.Store
3532
runConfig *config.RuntimeConfig
36-
agentSources config.Sources
33+
sm *sessionManager
3734
}
3835

39-
type Opt func(*Server) error
40-
41-
func New(sessionStore session.Store, runConfig *config.RuntimeConfig, agentSources config.Sources) (*Server, error) {
36+
func New(ctx context.Context, sessionStore session.Store, runConfig *config.RuntimeConfig, refreshInterval time.Duration, agentSources config.Sources) (*Server, error) {
4237
e := echo.New()
4338
e.Use(middleware.CORS())
4439
e.Use(middleware.Logger())
4540

4641
s := &Server{
4742
e: e,
48-
runtimes: concurrent.NewMap[string, runtime.Runtime](),
4943
runtimeCancels: concurrent.NewMap[string, context.CancelFunc](),
5044
sessionStore: sessionStore,
5145
runConfig: runConfig,
52-
agentSources: agentSources,
46+
sm: newSessionManager(ctx, agentSources, refreshInterval),
5347
}
5448

5549
group := e.Group("/api")
5650

5751
// List all available agents
5852
group.GET("/agents", s.getAgents)
5953

60-
// SESSIONS
61-
6254
// List all sessions
6355
group.GET("/sessions", s.getSessions)
6456
// Get a session by id
@@ -95,11 +87,9 @@ func (s *Server) Serve(ctx context.Context, ln net.Listener) error {
9587
return nil
9688
}
9789

98-
// API handlers
99-
10090
func (s *Server) getAgents(c echo.Context) error {
10191
agents := []api.Agent{}
102-
for k, agentSource := range s.agentSources {
92+
for k, agentSource := range s.sm.sources {
10393
slog.Debug("API source", "source", agentSource.Name())
10494

10595
c, err := config.Load(c.Request().Context(), agentSource)
@@ -252,7 +242,7 @@ func (s *Server) resumeSession(c echo.Context) error {
252242
return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("invalid request body: %v", err))
253243
}
254244

255-
rt, exists := s.runtimes.Load(sessionID)
245+
rt, exists := s.sm.runtimes.Load(sessionID)
256246
if !exists {
257247
return echo.NewHTTPError(http.StatusNotFound, fmt.Sprintf("runtime not found: %s", sessionID))
258248
}
@@ -273,9 +263,9 @@ func (s *Server) deleteSession(c echo.Context) error {
273263
}
274264

275265
// Clean up the runtime
276-
if _, exists := s.runtimes.Load(sessionID); exists {
266+
if _, exists := s.sm.runtimes.Load(sessionID); exists {
277267
slog.Debug("Removing runtime for session", "session_id", sessionID)
278-
s.runtimes.Delete(sessionID)
268+
s.sm.runtimes.Delete(sessionID)
279269
}
280270

281271
// Delete the session from storage
@@ -307,7 +297,7 @@ func (s *Server) runAgent(c echo.Context) error {
307297
rc := s.runConfig.Clone()
308298
rc.WorkingDir = sess.WorkingDir
309299

310-
rt, err := s.runtimeForSession(c.Request().Context(), sess, agentFilename, currentAgent, rc)
300+
rt, err := s.sm.runtimeForSession(c.Request().Context(), sess, agentFilename, currentAgent, rc)
311301
if err != nil {
312302
return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("failed to get runtime for session: %v", err))
313303
}
@@ -332,7 +322,6 @@ func (s *Server) runAgent(c echo.Context) error {
332322
c.Response().Header().Set("Connection", "keep-alive")
333323
c.Response().WriteHeader(http.StatusOK)
334324

335-
// Create a cancellable context for this stream
336325
streamCtx, cancel := context.WithCancel(c.Request().Context())
337326
s.runtimeCancels.Store(sess.ID, cancel)
338327
defer func() {
@@ -356,56 +345,14 @@ func (s *Server) runAgent(c echo.Context) error {
356345
return nil
357346
}
358347

359-
func (s *Server) runtimeForSession(ctx context.Context, sess *session.Session, agentFilename, currentAgent string, runConfig *config.RuntimeConfig) (runtime.Runtime, error) {
360-
rt, exists := s.runtimes.Load(sess.ID)
361-
if exists {
362-
return rt, nil
363-
}
364-
365-
t, err := s.loadTeam(ctx, agentFilename, runConfig)
366-
if err != nil {
367-
return nil, err
368-
}
369-
370-
agent, err := t.Agent(currentAgent)
371-
if err != nil {
372-
return nil, echo.NewHTTPError(http.StatusNotFound, fmt.Sprintf("agent not found: %v", err))
373-
}
374-
sess.MaxIterations = agent.MaxIterations()
375-
376-
opts := []runtime.Opt{
377-
runtime.WithCurrentAgent(currentAgent),
378-
runtime.WithManagedOAuth(false),
379-
runtime.WithRootSessionID(sess.ID),
380-
}
381-
rt, err = runtime.New(t, opts...)
382-
if err != nil {
383-
slog.Error("Failed to create runtime", "error", err)
384-
return nil, echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("failed to create runtime: %v", err))
385-
}
386-
s.runtimes.Store(sess.ID, rt)
387-
slog.Debug("Runtime created for session", "session_id", sess.ID)
388-
389-
return rt, nil
390-
}
391-
392-
func (s *Server) loadTeam(ctx context.Context, agentFilename string, runConfig *config.RuntimeConfig) (*team.Team, error) {
393-
agentSource, found := s.agentSources[agentFilename]
394-
if !found {
395-
return nil, fmt.Errorf("agent not found: %s", agentFilename)
396-
}
397-
398-
return teamloader.Load(ctx, agentSource, runConfig)
399-
}
400-
401348
func (s *Server) elicitation(c echo.Context) error {
402349
sessionID := c.Param("id")
403350
var req api.ResumeElicitationRequest
404351
if err := c.Bind(&req); err != nil {
405352
return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("invalid request body: %v", err))
406353
}
407354

408-
rt, exists := s.runtimes.Load(sessionID)
355+
rt, exists := s.sm.runtimes.Load(sessionID)
409356
if !exists {
410357
return c.JSON(http.StatusNotFound, map[string]string{"error": fmt.Sprintf("runtime not found: %s", sessionID)})
411358
}

pkg/server/server_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ func startServer(t *testing.T, ctx context.Context, agentsDir string) string {
9595

9696
sources, err := config.ResolveSources(agentsDir)
9797
require.NoError(t, err)
98-
srv, err := New(store, &runConfig, sources)
98+
srv, err := New(ctx, store, &runConfig, 0, sources)
9999
require.NoError(t, err)
100100

101101
socketPath := "unix://" + filepath.Join(t.TempDir(), "sock")

pkg/server/session_manager.go

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
package server
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"log/slog"
7+
"net/http"
8+
"time"
9+
10+
"github.com/labstack/echo/v4"
11+
12+
"github.com/docker/cagent/pkg/concurrent"
13+
"github.com/docker/cagent/pkg/config"
14+
"github.com/docker/cagent/pkg/runtime"
15+
"github.com/docker/cagent/pkg/session"
16+
"github.com/docker/cagent/pkg/team"
17+
"github.com/docker/cagent/pkg/teamloader"
18+
)
19+
20+
type sessionManager struct {
21+
runtimes *concurrent.Map[string, runtime.Runtime]
22+
sources config.Sources
23+
24+
refreshInterval time.Duration
25+
}
26+
27+
func newSessionManager(ctx context.Context, sources config.Sources, refreshInterval time.Duration) *sessionManager {
28+
loaders := make(config.Sources)
29+
for name, source := range sources {
30+
loaders[name] = newSourceLoader(ctx, source, refreshInterval)
31+
}
32+
33+
sm := &sessionManager{
34+
runtimes: concurrent.NewMap[string, runtime.Runtime](),
35+
sources: loaders,
36+
refreshInterval: refreshInterval,
37+
}
38+
39+
return sm
40+
}
41+
42+
func (sm *sessionManager) runtimeForSession(ctx context.Context, sess *session.Session, agentFilename, currentAgent string, rc *config.RuntimeConfig) (runtime.Runtime, error) {
43+
rt, exists := sm.runtimes.Load(sess.ID)
44+
if exists {
45+
return rt, nil
46+
}
47+
48+
t, err := sm.loadTeam(ctx, agentFilename, rc)
49+
if err != nil {
50+
return nil, err
51+
}
52+
53+
agent, err := t.Agent(currentAgent)
54+
if err != nil {
55+
return nil, echo.NewHTTPError(http.StatusNotFound, fmt.Sprintf("agent not found: %v", err))
56+
}
57+
sess.MaxIterations = agent.MaxIterations()
58+
59+
opts := []runtime.Opt{
60+
runtime.WithCurrentAgent(currentAgent),
61+
runtime.WithManagedOAuth(false),
62+
runtime.WithRootSessionID(sess.ID),
63+
}
64+
rt, err = runtime.New(t, opts...)
65+
if err != nil {
66+
slog.Error("Failed to create runtime", "error", err)
67+
return nil, echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("failed to create runtime: %v", err))
68+
}
69+
sm.runtimes.Store(sess.ID, rt)
70+
slog.Debug("Runtime created for session", "session_id", sess.ID)
71+
72+
return rt, nil
73+
}
74+
75+
func (sm *sessionManager) loadTeam(ctx context.Context, agentFilename string, runConfig *config.RuntimeConfig) (*team.Team, error) {
76+
agentSource, found := sm.sources[agentFilename]
77+
if !found {
78+
return nil, fmt.Errorf("agent not found: %s", agentFilename)
79+
}
80+
81+
return teamloader.Load(ctx, agentSource, runConfig)
82+
}

pkg/server/source_loader.go

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
package server
2+
3+
import (
4+
"context"
5+
"log/slog"
6+
"sync"
7+
"time"
8+
9+
"github.com/docker/cagent/pkg/config"
10+
)
11+
12+
type sourceLoader struct {
13+
inner config.Source
14+
refreshInterval time.Duration
15+
16+
mu sync.RWMutex
17+
data []byte
18+
err error
19+
}
20+
21+
func newSourceLoader(ctx context.Context, inner config.Source, refreshInterval time.Duration) *sourceLoader {
22+
sl := &sourceLoader{
23+
inner: inner,
24+
refreshInterval: refreshInterval,
25+
}
26+
27+
sl.load(ctx)
28+
29+
if refreshInterval > 0 {
30+
go sl.refreshLoop(ctx)
31+
}
32+
33+
return sl
34+
}
35+
36+
func (sl *sourceLoader) Name() string {
37+
return sl.inner.Name()
38+
}
39+
40+
func (sl *sourceLoader) ParentDir() string {
41+
return sl.inner.ParentDir()
42+
}
43+
44+
func (sl *sourceLoader) Read(_ context.Context) ([]byte, error) {
45+
sl.mu.RLock()
46+
defer sl.mu.RUnlock()
47+
return sl.data, sl.err
48+
}
49+
50+
func (sl *sourceLoader) load(ctx context.Context) {
51+
data, err := sl.inner.Read(ctx)
52+
53+
sl.mu.Lock()
54+
defer sl.mu.Unlock()
55+
56+
if err != nil {
57+
// Only log errors, keep previous data if available
58+
slog.Warn("Failed to refresh source",
59+
"source", sl.inner.Name(),
60+
"error", err)
61+
// Only update error if we don't have data yet
62+
if len(sl.data) == 0 {
63+
sl.err = err
64+
}
65+
} else {
66+
sl.data = data
67+
sl.err = nil
68+
}
69+
}
70+
71+
func (sl *sourceLoader) refreshLoop(ctx context.Context) {
72+
ticker := time.NewTicker(sl.refreshInterval)
73+
defer ticker.Stop()
74+
75+
for {
76+
select {
77+
case <-ctx.Done():
78+
return
79+
case <-ticker.C:
80+
sl.load(ctx)
81+
}
82+
}
83+
}

0 commit comments

Comments
 (0)