Skip to content

Commit 9c709dd

Browse files
authored
Merge pull request #435 from dgageot/test-server
Add test for API server
2 parents ddf9cb3 + 71be479 commit 9c709dd

2 files changed

Lines changed: 166 additions & 0 deletions

File tree

pkg/server/server_test.go

Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
package server
2+
3+
import (
4+
"context"
5+
"encoding/json"
6+
"io"
7+
"net"
8+
"net/http"
9+
"os"
10+
"path/filepath"
11+
"strings"
12+
"testing"
13+
14+
"github.com/stretchr/testify/assert"
15+
"github.com/stretchr/testify/require"
16+
17+
"github.com/docker/cagent/pkg/api"
18+
"github.com/docker/cagent/pkg/config"
19+
latest "github.com/docker/cagent/pkg/config/v2"
20+
"github.com/docker/cagent/pkg/session"
21+
)
22+
23+
func TestServerTODO(t *testing.T) {
24+
t.Setenv("OPENAI_API_KEY", "dummy")
25+
ctx := t.Context()
26+
27+
agentsDir := prepareAgentsDir(t, "pirate.yaml")
28+
lnPath := startServer(t, ctx, agentsDir)
29+
30+
t.Run("list agents", func(t *testing.T) {
31+
buf := httpGET(t, ctx, lnPath, "/api/agents")
32+
33+
var agents []any
34+
unmarshal(t, buf, &agents)
35+
36+
assert.NotEmpty(t, agents)
37+
})
38+
39+
t.Run("get agent (no extension)", func(t *testing.T) {
40+
buf := httpGET(t, ctx, lnPath, "/api/agents/pirate")
41+
42+
var cfg latest.Config
43+
unmarshal(t, buf, &cfg)
44+
45+
assert.NotEmpty(t, cfg.Version)
46+
require.NotEmpty(t, cfg.Agents)
47+
assert.Contains(t, cfg.Agents["root"].Instruction, "pirate")
48+
})
49+
50+
t.Run("get agent", func(t *testing.T) {
51+
buf := httpGET(t, ctx, lnPath, "/api/agents/pirate.yaml")
52+
53+
var cfg latest.Config
54+
unmarshal(t, buf, &cfg)
55+
56+
assert.NotEmpty(t, cfg.Version)
57+
require.NotEmpty(t, cfg.Agents)
58+
assert.Contains(t, cfg.Agents["root"].Instruction, "pirate")
59+
})
60+
61+
t.Run("get agent's yaml (no extension)", func(t *testing.T) {
62+
content := httpGET(t, ctx, lnPath, "/api/agents/pirate/yaml")
63+
assert.Contains(t, string(content), "pirate")
64+
})
65+
66+
t.Run("get agent's yaml", func(t *testing.T) {
67+
content := httpGET(t, ctx, lnPath, "/api/agents/pirate.yaml/yaml")
68+
assert.Contains(t, string(content), "pirate")
69+
})
70+
71+
t.Run("list sessions", func(t *testing.T) {
72+
buf := httpGET(t, ctx, lnPath, "/api/sessions")
73+
74+
var sessions []api.SessionsResponse
75+
unmarshal(t, buf, &sessions)
76+
77+
assert.Empty(t, sessions)
78+
})
79+
}
80+
81+
func prepareAgentsDir(t *testing.T, testFiles ...string) string {
82+
t.Helper()
83+
84+
agentsDir := filepath.Join(t.TempDir(), "agents")
85+
err := os.MkdirAll(agentsDir, 0o700)
86+
require.NoError(t, err)
87+
88+
for _, file := range testFiles {
89+
buf, err := os.ReadFile(filepath.Join("testdata", file))
90+
require.NoError(t, err)
91+
92+
err = os.WriteFile(filepath.Join(agentsDir, filepath.Base(file)), buf, 0o600)
93+
require.NoError(t, err)
94+
}
95+
96+
return agentsDir
97+
}
98+
99+
func startServer(t *testing.T, ctx context.Context, agentsDir string) string {
100+
t.Helper()
101+
102+
var store mockStore
103+
var runConfig config.RuntimeConfig
104+
105+
srv, err := New(store, runConfig, nil, WithAgentsDir(agentsDir))
106+
require.NoError(t, err)
107+
108+
socketPath := "unix://" + filepath.Join(t.TempDir(), "test.sock")
109+
ln, err := Listen(ctx, socketPath)
110+
require.NoError(t, err)
111+
go func() {
112+
<-ctx.Done()
113+
_ = ln.Close()
114+
}()
115+
116+
go srv.Serve(ctx, ln)
117+
118+
return socketPath
119+
}
120+
121+
func httpGET(t *testing.T, ctx context.Context, socketPath, path string) []byte {
122+
t.Helper()
123+
124+
client := &http.Client{
125+
Transport: &http.Transport{
126+
DialContext: func(ctx context.Context, _, _ string) (net.Conn, error) {
127+
var d net.Dialer
128+
return d.DialContext(ctx, "unix", strings.TrimPrefix(socketPath, "unix://"))
129+
},
130+
},
131+
}
132+
133+
url := "http://_" + path
134+
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, http.NoBody)
135+
require.NoError(t, err)
136+
137+
resp, err := client.Do(req)
138+
require.NoError(t, err)
139+
defer resp.Body.Close()
140+
141+
buf, err := io.ReadAll(resp.Body)
142+
require.NoError(t, err)
143+
144+
return buf
145+
}
146+
147+
func unmarshal(t *testing.T, buf []byte, v any) {
148+
t.Helper()
149+
err := json.Unmarshal(buf, &v)
150+
require.NoError(t, err)
151+
}
152+
153+
type mockStore struct {
154+
session.Store
155+
}
156+
157+
func (s mockStore) GetSessions(ctx context.Context) ([]*session.Session, error) {
158+
return nil, nil
159+
}

pkg/server/testdata/pirate.yaml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
#!/usr/bin/env cagent run
2+
version: "2"
3+
4+
agents:
5+
root:
6+
instruction: Always answer by talking like a pirate.
7+
model: openai/gpt-4o

0 commit comments

Comments
 (0)