Skip to content

Commit c8bc287

Browse files
committed
Fix session creation in the api command when using an ociRef and passing a working_dir in the sess creation request
Signed-off-by: Christopher Petito <chrisjpetito@gmail.com>
1 parent 0301d5b commit c8bc287

3 files changed

Lines changed: 268 additions & 29 deletions

File tree

cmd/root/api.go

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -103,11 +103,16 @@ func (f *apiFlags) runAPICommand(cmd *cobra.Command, args []string) error {
103103
return fmt.Errorf("failed to load teams: %w", err)
104104
}
105105

106-
// For OCI refs: clean up the temp file immediately after loading
107-
// We don't need it anymore since teams are now in memory
106+
// For OCI refs: store the reference for later per-session reloading, then clean up temp file
108107
if agentfile.IsOCIReference(agentsPath) {
109-
_ = os.Remove(resolvedPath)
110-
slog.Debug("Cleaned up temporary OCI file", "path", resolvedPath)
108+
teamKey := filepath.Base(resolvedPath)
109+
opts = append(opts, server.WithOCIRef(teamKey, agentsPath))
110+
111+
if err := os.Remove(resolvedPath); err != nil {
112+
slog.Warn("Failed to remove temporary OCI file", "path", resolvedPath, "error", err)
113+
} else {
114+
slog.Debug("Cleaned up temporary OCI file", "path", resolvedPath)
115+
}
111116
}
112117

113118
defer func() {

pkg/server/server.go

Lines changed: 75 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import (
2323
"github.com/labstack/echo/v4"
2424
"github.com/labstack/echo/v4/middleware"
2525

26+
"github.com/docker/cagent/pkg/agentfile"
2627
"github.com/docker/cagent/pkg/api"
2728
"github.com/docker/cagent/pkg/config"
2829
"github.com/docker/cagent/pkg/config/latest"
@@ -48,6 +49,8 @@ type Server struct {
4849
runConfig *config.RuntimeConfig
4950
teams map[string]*team.Team
5051
teamsMu sync.RWMutex
52+
ociRef string // OCI reference, set once at startup for OCI-based servers
53+
ociTeamKey string
5154
agentsDir string
5255
agentsPath string // For local files: specific file path to reload (instead of scanning agentsDir)
5356
rootFS *os.Root
@@ -75,6 +78,19 @@ func WithAgentsPath(agentPath string) Opt {
7578
}
7679
}
7780

81+
func WithOCIRef(teamKey, ociRef string) Opt {
82+
return func(s *Server) error {
83+
if teamKey == "" || ociRef == "" {
84+
return nil
85+
}
86+
87+
s.ociTeamKey = teamKey
88+
s.ociRef = ociRef
89+
90+
return nil
91+
}
92+
}
93+
7894
func New(sessionStore session.Store, runConfig *config.RuntimeConfig, teams map[string]*team.Team, opts ...Opt) (*Server, error) {
7995
e := echo.New()
8096
e.Use(middleware.CORS())
@@ -217,6 +233,13 @@ func (s *Server) replaceAllTeams(teams map[string]*team.Team) map[string]*team.T
217233
return oldTeams
218234
}
219235

236+
func (s *Server) getOCIRef(key string) (string, bool) {
237+
if s.ociRef == "" || s.ociTeamKey != key {
238+
return "", false
239+
}
240+
return s.ociRef, true
241+
}
242+
220243
// countTeams returns the number of teams with read lock
221244
func (s *Server) countTeams() int {
222245
s.teamsMu.RLock()
@@ -1188,27 +1211,9 @@ func (s *Server) runAgent(c echo.Context) error {
11881211
rc.WorkingDir = sess.WorkingDir
11891212

11901213
// Load team - either reload from disk (local) or use in-memory team (OCI refs)
1191-
var t *team.Team
1192-
1193-
if s.hasAgentsDir() {
1194-
// Has local agents path or directory: reload from disk to pick up changes
1195-
loadPath := s.agentsPath
1196-
if loadPath == "" {
1197-
loadPath = filepath.Join(s.agentsDir, p)
1198-
}
1199-
1200-
var loadErr error
1201-
t, loadErr = teamloader.Load(c.Request().Context(), loadPath, rc)
1202-
if loadErr != nil {
1203-
return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("failed to load agent for session: %v", loadErr))
1204-
}
1205-
} else {
1206-
// No local directory: use the already-loaded team from memory (OCI refs)
1207-
var exists bool
1208-
t, exists = s.getTeam(p)
1209-
if !exists {
1210-
return echo.NewHTTPError(http.StatusNotFound, fmt.Sprintf("agent not found: %s", agentFilename))
1211-
}
1214+
t, err := s.loadTeamForSession(c.Request().Context(), p, sess, rc)
1215+
if err != nil {
1216+
return err
12121217
}
12131218

12141219
agent, err := t.Agent(currentAgent)
@@ -1286,6 +1291,55 @@ func (s *Server) runAgent(c echo.Context) error {
12861291
return nil
12871292
}
12881293

1294+
// loadTeamForSession loads the appropriate team for a session, handling both
1295+
// local files (always reload) and OCI refs (reload only if session has custom workingDir).
1296+
func (s *Server) loadTeamForSession(ctx context.Context, agentFilename string, sess *session.Session, rc *config.RuntimeConfig) (*team.Team, error) {
1297+
if s.hasAgentsDir() {
1298+
// Has local agents path or directory: reload from disk to pick up changes
1299+
loadPath := s.agentsPath
1300+
if loadPath == "" {
1301+
loadPath = filepath.Join(s.agentsDir, agentFilename)
1302+
}
1303+
1304+
t, err := teamloader.Load(ctx, loadPath, rc)
1305+
if err != nil {
1306+
return nil, echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("failed to load agent for session: %v", err))
1307+
}
1308+
return t, nil
1309+
}
1310+
1311+
// No local directory: use the already-loaded team from memory (OCI refs)
1312+
t, exists := s.getTeam(agentFilename)
1313+
if !exists {
1314+
return nil, echo.NewHTTPError(http.StatusNotFound, fmt.Sprintf("agent not found: %s", agentFilename))
1315+
}
1316+
1317+
// If session has a custom working dir, reload the agent to pick it up
1318+
if sess.WorkingDir != "" {
1319+
if ociRef, ok := s.getOCIRef(agentFilename); ok {
1320+
yamlContent, err := agentfile.FromStore(ociRef)
1321+
if err != nil {
1322+
slog.Error("Failed to load OCI agent from store", "agent", agentFilename, "oci_ref", ociRef, "error", err)
1323+
return t, nil // Fall back to cached team
1324+
}
1325+
1326+
reloaded, err := teamloader.LoadFrom(
1327+
ctx,
1328+
teamloader.NewBytesSource(sess.WorkingDir, []byte(yamlContent)),
1329+
rc,
1330+
teamloader.WithID(agentFilename),
1331+
)
1332+
if err != nil {
1333+
slog.Error("Failed to reload OCI agent with session working dir", "agent", agentFilename, "error", err)
1334+
return t, nil // Fall back to cached team
1335+
}
1336+
return reloaded, nil
1337+
}
1338+
}
1339+
1340+
return t, nil
1341+
}
1342+
12891343
func fromStore(reference string) (string, error) {
12901344
store, err := content.NewStore()
12911345
if err != nil {

pkg/server/server_test.go

Lines changed: 184 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,10 @@ import (
1919
"github.com/docker/cagent/pkg/api"
2020
"github.com/docker/cagent/pkg/config"
2121
"github.com/docker/cagent/pkg/config/latest"
22+
"github.com/docker/cagent/pkg/content"
23+
"github.com/docker/cagent/pkg/oci"
2224
"github.com/docker/cagent/pkg/session"
25+
"github.com/docker/cagent/pkg/team"
2326
"github.com/docker/cagent/pkg/teamloader"
2427
)
2528

@@ -674,9 +677,9 @@ agents:
674677
count := srv.countTeams()
675678
assert.Equal(t, 1, count, "should only load from agentsPath")
676679

677-
team, exists := srv.getTeam("pirate.yaml")
680+
tm, exists := srv.getTeam("pirate.yaml")
678681
require.True(t, exists)
679-
agent, err := team.Agent("root")
682+
agent, err := tm.Agent("root")
680683
require.NoError(t, err)
681684
assert.Contains(t, agent.Instruction(), "MODIFIED", "should have loaded modified pirate from agentsPath")
682685

@@ -744,10 +747,10 @@ agents:
744747
assert.Equal(t, 1, count, "should only have the OCI agent, not files from /tmp")
745748

746749
// Verify it's the correct agent
747-
team, exists := srv.getTeam("docker.io_myorg_pirate_v1.yaml")
750+
tm, exists := srv.getTeam("docker.io_myorg_pirate_v1.yaml")
748751
require.True(t, exists, "should have the OCI agent")
749752

750-
agent, err := team.Agent("root")
753+
agent, err := tm.Agent("root")
751754
require.NoError(t, err)
752755
assert.Contains(t, agent.Instruction(), "pirate", "should be the pirate agent")
753756

@@ -935,3 +938,180 @@ type mockStore struct {
935938
func (s mockStore) GetSessions(context.Context) ([]*session.Session, error) {
936939
return nil, nil
937940
}
941+
942+
// TestServer_OCIRef_WithOCIRefCaching verifies that OCI agents store their
943+
// reference so per-session working directories can reload from the content store.
944+
func TestServer_OCIRef_WithOCIRefCaching(t *testing.T) {
945+
t.Setenv("OPENAI_API_KEY", "dummy")
946+
947+
runConfig := config.RuntimeConfig{}
948+
teams := map[string]*team.Team{}
949+
950+
ociFilename := "docker.io_myorg_pirate_v1.yaml"
951+
ociRef := "docker.io/myorg/pirate:v1"
952+
953+
var store mockStore
954+
srv, err := New(store, &runConfig, teams, WithOCIRef(ociFilename, ociRef))
955+
require.NoError(t, err)
956+
957+
// Verify that the OCI ref was cached
958+
cachedRef, exists := srv.getOCIRef(ociFilename)
959+
require.True(t, exists, "OCI ref should be cached")
960+
assert.Equal(t, ociRef, cachedRef, "cached ref should match original")
961+
}
962+
963+
// TestServer_LocalAgent_HasAgentsDirSet verifies that local file/directory agents
964+
// have hasAgentsDir() == true, ensuring they reload from disk (not use OCI path).
965+
func TestServer_LocalAgent_HasAgentsDirSet(t *testing.T) {
966+
t.Setenv("OPENAI_API_KEY", "dummy")
967+
968+
ctx := t.Context()
969+
runConfig := config.RuntimeConfig{}
970+
971+
// Test 1: Directory of agents
972+
agentsDir := prepareAgentsDir(t, "pirate.yaml")
973+
teams, err := teamloader.LoadTeams(ctx, agentsDir, &runConfig)
974+
require.NoError(t, err)
975+
976+
var store mockStore
977+
srv, err := New(store, &runConfig, teams, WithAgentsDir(agentsDir))
978+
require.NoError(t, err)
979+
980+
assert.True(t, srv.hasAgentsDir(), "directory of agents should have agentsDir set")
981+
assert.Empty(t, srv.ociRef, "local agents should not have OCI ref")
982+
assert.Empty(t, srv.ociTeamKey, "local agents should not have OCI team key")
983+
984+
// Test 2: Single file agent
985+
agentFile := filepath.Join(agentsDir, "pirate.yaml")
986+
teams2, err := teamloader.LoadTeams(ctx, agentFile, &runConfig)
987+
require.NoError(t, err)
988+
989+
srv2, err := New(store, &runConfig, teams2,
990+
WithAgentsPath(agentFile),
991+
WithAgentsDir(filepath.Dir(agentFile)))
992+
require.NoError(t, err)
993+
994+
assert.True(t, srv2.hasAgentsDir(), "single file agent should have agentsDir set")
995+
assert.Empty(t, srv2.ociRef, "local agents should not have OCI ref")
996+
assert.Empty(t, srv2.ociTeamKey, "local agents should not have OCI team key")
997+
}
998+
999+
// TestServer_loadTeamForSession_LocalAgent verifies that local agents always reload from disk.
1000+
func TestServer_loadTeamForSession_LocalAgent(t *testing.T) {
1001+
t.Setenv("OPENAI_API_KEY", "dummy")
1002+
1003+
ctx := t.Context()
1004+
runConfig := config.RuntimeConfig{}
1005+
1006+
agentsDir := prepareAgentsDir(t, "pirate.yaml")
1007+
teams, err := teamloader.LoadTeams(ctx, agentsDir, &runConfig)
1008+
require.NoError(t, err)
1009+
1010+
var store mockStore
1011+
srv, err := New(store, &runConfig, teams, WithAgentsDir(agentsDir))
1012+
require.NoError(t, err)
1013+
1014+
sess := session.New()
1015+
rc := runConfig.Clone()
1016+
1017+
// Call the extracted method
1018+
tm, err := srv.loadTeamForSession(ctx, "pirate.yaml", sess, rc)
1019+
require.NoError(t, err)
1020+
require.NotNil(t, tm)
1021+
1022+
agent, err := tm.Agent("root")
1023+
require.NoError(t, err)
1024+
assert.Contains(t, agent.Instruction(), "pirate")
1025+
}
1026+
1027+
// TestServer_loadTeamForSession_OCIRef_NoWorkingDir verifies OCI agents use cached team
1028+
// when session has no custom working directory.
1029+
func TestServer_loadTeamForSession_OCIRef_NoWorkingDir(t *testing.T) {
1030+
t.Setenv("OPENAI_API_KEY", "dummy")
1031+
1032+
ctx := t.Context()
1033+
tmpDir := t.TempDir()
1034+
1035+
pirateContent, err := os.ReadFile(filepath.Join("testdata", "pirate.yaml"))
1036+
require.NoError(t, err)
1037+
1038+
ociFilename := "docker.io_myorg_pirate_v1.yaml"
1039+
ociFile := filepath.Join(tmpDir, ociFilename)
1040+
err = os.WriteFile(ociFile, pirateContent, 0o600)
1041+
require.NoError(t, err)
1042+
1043+
ociRef := "docker.io/myorg/pirate:v1"
1044+
runConfig := config.RuntimeConfig{}
1045+
teams, err := teamloader.LoadTeams(ctx, ociFile, &runConfig)
1046+
require.NoError(t, err)
1047+
1048+
var store mockStore
1049+
srv, err := New(store, &runConfig, teams, WithOCIRef(ociFilename, ociRef))
1050+
require.NoError(t, err)
1051+
1052+
// Session without working dir
1053+
sess := session.New()
1054+
rc := runConfig.Clone()
1055+
1056+
// Should use cached team
1057+
tm, err := srv.loadTeamForSession(ctx, ociFilename, sess, rc)
1058+
require.NoError(t, err)
1059+
require.NotNil(t, tm)
1060+
1061+
// Verify it's the same team instance (pointer equality)
1062+
cachedTeam, _ := srv.getTeam(ociFilename)
1063+
assert.Same(t, cachedTeam, tm, "should return cached team when no custom workingDir")
1064+
}
1065+
1066+
// TestServer_loadTeamForSession_OCIRef_WithWorkingDir verifies OCI agents reload
1067+
// when session has a custom working directory (requires content store).
1068+
func TestServer_loadTeamForSession_OCIRef_WithWorkingDir(t *testing.T) {
1069+
t.Setenv("OPENAI_API_KEY", "dummy")
1070+
1071+
ctx := t.Context()
1072+
tmpDir := t.TempDir()
1073+
1074+
// Create agent with filesystem tool to show working dir matters
1075+
agentWithFS := `version: "2"
1076+
agents:
1077+
root:
1078+
model: openai/gpt-4o
1079+
instruction: Test agent
1080+
toolsets:
1081+
- type: filesystem`
1082+
1083+
ociFilename := "docker.io_test_fs_v1.yaml"
1084+
ociFile := filepath.Join(tmpDir, ociFilename)
1085+
err := os.WriteFile(ociFile, []byte(agentWithFS), 0o600)
1086+
require.NoError(t, err)
1087+
1088+
// Push to content store so FromStore can retrieve it
1089+
ociRef := "docker.io/test/fs:v1"
1090+
contentStore, err := content.NewStore()
1091+
require.NoError(t, err)
1092+
_, err = oci.PackageFileAsOCIToStore(ctx, ociFile, ociRef, contentStore)
1093+
require.NoError(t, err)
1094+
1095+
runConfig := config.RuntimeConfig{}
1096+
teams, err := teamloader.LoadTeams(ctx, ociFile, &runConfig)
1097+
require.NoError(t, err)
1098+
1099+
var sessionStore mockStore
1100+
srv, err := New(sessionStore, &runConfig, teams, WithOCIRef(ociFilename, ociRef))
1101+
require.NoError(t, err)
1102+
1103+
// Session WITH working dir
1104+
customWorkingDir := t.TempDir()
1105+
sess := session.New(session.WithWorkingDir(customWorkingDir))
1106+
rc := runConfig.Clone()
1107+
rc.WorkingDir = customWorkingDir
1108+
1109+
// Should reload from content store
1110+
tm, err := srv.loadTeamForSession(ctx, ociFilename, sess, rc)
1111+
require.NoError(t, err)
1112+
require.NotNil(t, tm)
1113+
1114+
// Verify it's NOT the same team instance (was reloaded)
1115+
cachedTeam, _ := srv.getTeam(ociFilename)
1116+
assert.NotSame(t, cachedTeam, tm, "should return new team instance when custom workingDir is set")
1117+
}

0 commit comments

Comments
 (0)