Skip to content

Commit 5881f7d

Browse files
authored
feat: add e2e test for ACP (#190)
1 parent 5d6b259 commit 5881f7d

5 files changed

Lines changed: 231 additions & 6 deletions

File tree

e2e/acp_echo.go

Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
//go:build ignore
2+
3+
package main
4+
5+
import (
6+
"context"
7+
"encoding/json"
8+
"fmt"
9+
"os"
10+
"os/signal"
11+
"strings"
12+
13+
acp "github.com/coder/acp-go-sdk"
14+
)
15+
16+
// ScriptEntry is one step of the scripted conversation. The script file is a
// JSON array whose elements decode into this type.
type ScriptEntry struct {
	// ExpectMessage is the prompt text this step expects to receive. Prompt
	// treats an empty (or all-whitespace) value as "match anything".
	ExpectMessage string `json:"expectMessage"`

	// ThinkDurationMS is decoded from the "thinkDurationMS" key.
	// NOTE(review): presumably a simulated thinking delay in milliseconds, but
	// nothing in this file reads it — confirm whether it should be honored.
	ThinkDurationMS int64 `json:"thinkDurationMS"`

	// ResponseMessage is the text the agent sends back for this step.
	ResponseMessage string `json:"responseMessage"`
}
22+
23+
// acpEchoAgent implements the ACP Agent interface for testing.
24+
type acpEchoAgent struct {
25+
script []ScriptEntry
26+
scriptIndex int
27+
conn *acp.AgentSideConnection
28+
sessionID acp.SessionId
29+
}
30+
31+
var _ acp.Agent = (*acpEchoAgent)(nil)
32+
33+
func main() {
34+
if len(os.Args) != 2 {
35+
fmt.Fprintln(os.Stderr, "Usage: acp_echo <script.json>")
36+
os.Exit(1)
37+
}
38+
39+
script, err := loadScript(os.Args[1])
40+
if err != nil {
41+
fmt.Fprintf(os.Stderr, "Error loading script: %v\n", err)
42+
os.Exit(1)
43+
}
44+
45+
if len(script) == 0 {
46+
fmt.Fprintln(os.Stderr, "Script is empty")
47+
os.Exit(1)
48+
}
49+
50+
sigCh := make(chan os.Signal, 1)
51+
signal.Notify(sigCh, os.Interrupt)
52+
go func() {
53+
<-sigCh
54+
os.Exit(0)
55+
}()
56+
57+
agent := &acpEchoAgent{
58+
script: script,
59+
}
60+
61+
conn := acp.NewAgentSideConnection(agent, os.Stdout, os.Stdin)
62+
agent.conn = conn
63+
64+
<-conn.Done()
65+
}
66+
67+
func (a *acpEchoAgent) Initialize(_ context.Context, _ acp.InitializeRequest) (acp.InitializeResponse, error) {
68+
return acp.InitializeResponse{
69+
ProtocolVersion: acp.ProtocolVersionNumber,
70+
AgentCapabilities: acp.AgentCapabilities{},
71+
}, nil
72+
}
73+
74+
func (a *acpEchoAgent) Authenticate(_ context.Context, _ acp.AuthenticateRequest) (acp.AuthenticateResponse, error) {
75+
return acp.AuthenticateResponse{}, nil
76+
}
77+
78+
func (a *acpEchoAgent) Cancel(_ context.Context, _ acp.CancelNotification) error {
79+
return nil
80+
}
81+
82+
func (a *acpEchoAgent) NewSession(_ context.Context, _ acp.NewSessionRequest) (acp.NewSessionResponse, error) {
83+
a.sessionID = "test-session"
84+
return acp.NewSessionResponse{
85+
SessionId: a.sessionID,
86+
}, nil
87+
}
88+
89+
func (a *acpEchoAgent) Prompt(ctx context.Context, params acp.PromptRequest) (acp.PromptResponse, error) {
90+
// Extract text from prompt
91+
var promptText string
92+
for _, block := range params.Prompt {
93+
if block.Text != nil {
94+
promptText = block.Text.Text
95+
break
96+
}
97+
}
98+
promptText = strings.TrimSpace(promptText)
99+
100+
if a.scriptIndex >= len(a.script) {
101+
return acp.PromptResponse{
102+
StopReason: acp.StopReasonEndTurn,
103+
}, nil
104+
}
105+
106+
entry := a.script[a.scriptIndex]
107+
expected := strings.TrimSpace(entry.ExpectMessage)
108+
109+
// Empty ExpectMessage matches any prompt
110+
if expected != "" && expected != promptText {
111+
return acp.PromptResponse{}, fmt.Errorf("expected message %q but got %q", expected, promptText)
112+
}
113+
114+
a.scriptIndex++
115+
116+
// Send response via session update
117+
if err := a.conn.SessionUpdate(ctx, acp.SessionNotification{
118+
SessionId: params.SessionId,
119+
Update: acp.UpdateAgentMessageText(entry.ResponseMessage),
120+
}); err != nil {
121+
return acp.PromptResponse{}, err
122+
}
123+
124+
return acp.PromptResponse{
125+
StopReason: acp.StopReasonEndTurn,
126+
}, nil
127+
}
128+
129+
func (a *acpEchoAgent) SetSessionMode(_ context.Context, _ acp.SetSessionModeRequest) (acp.SetSessionModeResponse, error) {
130+
return acp.SetSessionModeResponse{}, nil
131+
}
132+
133+
func loadScript(scriptPath string) ([]ScriptEntry, error) {
134+
data, err := os.ReadFile(scriptPath)
135+
if err != nil {
136+
return nil, fmt.Errorf("failed to read script file: %w", err)
137+
}
138+
139+
var script []ScriptEntry
140+
if err := json.Unmarshal(data, &script); err != nil {
141+
return nil, fmt.Errorf("failed to parse script JSON: %w", err)
142+
}
143+
144+
return script, nil
145+
}

e2e/echo_test.go

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,34 @@ func TestE2E(t *testing.T) {
100100
require.Equal(t, script[0].ExpectMessage, strings.TrimSpace(msgResp.Messages[1].Content))
101101
require.Equal(t, script[0].ResponseMessage, strings.TrimSpace(msgResp.Messages[2].Content))
102102
})
103+
104+
t.Run("acp_basic", func(t *testing.T) {
105+
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
106+
defer cancel()
107+
108+
script, apiClient := setup(ctx, t, &params{
109+
cmdFn: func(ctx context.Context, t testing.TB, serverPort int, binaryPath, cwd, scriptFilePath string) (string, []string) {
110+
return binaryPath, []string{
111+
"server",
112+
fmt.Sprintf("--port=%d", serverPort),
113+
"--experimental-acp",
114+
"--", "go", "run", filepath.Join(cwd, "acp_echo.go"), scriptFilePath,
115+
}
116+
},
117+
})
118+
messageReq := agentapisdk.PostMessageParams{
119+
Content: "This is a test message.",
120+
Type: agentapisdk.MessageTypeUser,
121+
}
122+
_, err := apiClient.PostMessage(ctx, messageReq)
123+
require.NoError(t, err, "Failed to send message via SDK")
124+
require.NoError(t, waitAgentAPIStable(ctx, t, apiClient, operationTimeout, "post message"))
125+
msgResp, err := apiClient.GetMessages(ctx)
126+
require.NoError(t, err, "Failed to get messages via SDK")
127+
require.Len(t, msgResp.Messages, 2)
128+
require.Equal(t, script[0].ExpectMessage, strings.TrimSpace(msgResp.Messages[0].Content))
129+
require.Equal(t, script[0].ResponseMessage, strings.TrimSpace(msgResp.Messages[1].Content))
130+
})
103131
}
104132

105133
type params struct {

e2e/testdata/acp_basic.json

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
[
2+
{
3+
"expectMessage": "This is a test message.",
4+
"responseMessage": "Echo: This is a test message."
5+
}
6+
]

x/acpio/acp_conversation.go

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"slices"
77
"strings"
88
"sync"
9+
"time"
910

1011
st "github.com/coder/agentapi/lib/screentracker"
1112
"github.com/coder/quartz"
@@ -31,7 +32,8 @@ type ACPConversation struct {
3132
agentIO ChunkableAgentIO
3233
messages []st.ConversationMessage
3334
nextID int // monotonically increasing message ID
34-
prompting bool // true while agent is processing
35+
prompting bool // true while agent is processing
36+
chunkReceived chan struct{} // signals that handleChunk has accumulated a chunk
3537
streamingResponse strings.Builder
3638
logger *slog.Logger
3739
emitter st.Emitter
@@ -68,6 +70,7 @@ func NewACPConversation(ctx context.Context, agentIO ChunkableAgentIO, logger *s
6870
initialPrompt: initialPrompt,
6971
emitter: emitter,
7072
clock: clock,
73+
chunkReceived: make(chan struct{}, 1),
7174
}
7275
return c
7376
}
@@ -202,13 +205,25 @@ func (c *ACPConversation) handleChunk(chunk string) {
202205
screen := c.streamingResponse.String()
203206
c.mu.Unlock()
204207

208+
// Signal that a chunk has been received (non-blocking; a pending signal is sufficient).
209+
select {
210+
case c.chunkReceived <- struct{}{}:
211+
default:
212+
}
213+
205214
c.emitter.EmitMessages(messages)
206215
c.emitter.EmitStatus(status)
207216
c.emitter.EmitScreen(screen)
208217
}
209218

210219
// executePrompt runs the actual agent request and returns any error.
211220
func (c *ACPConversation) executePrompt(messageParts []st.MessagePart) error {
221+
// Drain any stale signal before sending the prompt.
222+
select {
223+
case <-c.chunkReceived:
224+
default:
225+
}
226+
212227
var err error
213228
for _, part := range messageParts {
214229
if c.ctx.Err() != nil {
@@ -221,6 +236,15 @@ func (c *ACPConversation) executePrompt(messageParts []st.MessagePart) error {
221236
}
222237
}
223238

239+
// The ACP SDK dispatches SessionUpdate notifications as goroutines, so
240+
// the chunk may arrive after conn.Prompt() returns. Wait up to 100ms.
241+
timer := c.clock.NewTimer(100 * time.Millisecond)
242+
select {
243+
case <-c.chunkReceived:
244+
case <-timer.C:
245+
}
246+
timer.Stop()
247+
224248
c.mu.Lock()
225249
c.prompting = false
226250

x/acpio/acp_conversation_test.go

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,9 @@ func Test_Send_AddsUserMessage(t *testing.T) {
227227
assert.Equal(t, "hello", messages[0].Message)
228228
assert.Equal(t, screentracker.ConversationRoleAgent, messages[1].Role)
229229

230+
// Signal a chunk so executePrompt's timer wait doesn't hang on the mock clock.
231+
mock.SimulateChunks("hello response")
232+
230233
// Unblock the write to let Send complete
231234
close(done)
232235
require.NoError(t, <-errCh)
@@ -290,6 +293,9 @@ func Test_Send_RejectsDuplicateSend(t *testing.T) {
290293
err := conv.Send(screentracker.MessagePartText{Content: "second"})
291294
assert.ErrorIs(t, err, screentracker.ErrMessageValidationChanging)
292295

296+
// Signal a chunk so executePrompt's timer wait doesn't hang on the mock clock.
297+
mock.SimulateChunks("first response")
298+
293299
// Unblock the write to let the test complete cleanly
294300
close(done)
295301
require.NoError(t, <-errCh)
@@ -318,6 +324,9 @@ func Test_Status_ChangesWhileProcessing(t *testing.T) {
318324
// Status should be changing while processing
319325
assert.Equal(t, screentracker.ConversationStatusChanging, conv.Status())
320326

327+
// Signal a chunk so executePrompt's timer wait doesn't hang on the mock clock.
328+
mock.SimulateChunks("test response")
329+
321330
// Unblock the write
322331
close(done)
323332

@@ -428,6 +437,9 @@ func Test_InitialPrompt_SentOnStart(t *testing.T) {
428437
assert.Equal(t, screentracker.ConversationRoleUser, messages[0].Role)
429438
assert.Equal(t, "initial prompt", messages[0].Message)
430439

440+
// Signal a chunk so executePrompt's timer wait doesn't hang on the mock clock.
441+
mock.SimulateChunks("initial response")
442+
431443
// Unblock the write to let the test complete cleanly
432444
close(done)
433445
}
@@ -457,6 +469,9 @@ func Test_Messages_AreCopied(t *testing.T) {
457469
originalMessages := conv.Messages()
458470
assert.Equal(t, "test", originalMessages[0].Message)
459471

472+
// Signal a chunk so executePrompt's timer wait doesn't hang on the mock clock.
473+
mock.SimulateChunks("test response")
474+
460475
// Unblock the write to let Send complete
461476
close(done)
462477
require.NoError(t, <-errCh)
@@ -518,12 +533,15 @@ func Test_ErrorRemovesPartialMessage(t *testing.T) {
518533
// Send a second message — IDs must not reuse the removed agent message's ID (1).
519534
mock.mu.Lock()
520535
mock.writeErr = nil
521-
mock.writeBlock = nil
522-
mock.writeStarted = nil
523536
mock.mu.Unlock()
524-
525-
err := conv.Send(screentracker.MessagePartText{Content: "retry"})
526-
require.NoError(t, err)
537+
started2, done2 := mock.BlockWrite()
538+
errCh2 := make(chan error, 1)
539+
go func() { errCh2 <- conv.Send(screentracker.MessagePartText{Content: "retry"}) }()
540+
<-started2
541+
// Signal a chunk so executePrompt's timer wait doesn't hang on the mock clock.
542+
mock.SimulateChunks("retry response")
543+
close(done2)
544+
require.NoError(t, <-errCh2)
527545

528546
messages = conv.Messages()
529547
require.Len(t, messages, 3, "first user + second user + second agent")
@@ -548,6 +566,10 @@ func Test_LateChunkAfterError_DoesNotCorruptUserMessage(t *testing.T) {
548566
mock.mu.Lock()
549567
mock.writeErr = assert.AnError
550568
mock.mu.Unlock()
569+
570+
// Signal a chunk before unblocking; the error path still waits on chunkReceived
571+
// or the timer, so pre-signaling avoids a hang on the mock clock.
572+
mock.SimulateChunks("unexpected chunk")
551573
close(done)
552574

553575
require.ErrorIs(t, <-errCh, assert.AnError)

0 commit comments

Comments
 (0)