Skip to content

Commit ea08bcc

Browse files
committed
test(cli): add tests for max iterations auto-continue
Cover --yolo auto-continue behavior for MaxIterationsReachedEvent in both normal and JSON output modes, including safety cap enforcement. Signed-off-by: Matthieu FRONTON <m@tthieu.fr>
1 parent 7ed808f commit ea08bcc

1 file changed

Lines changed: 205 additions & 0 deletions

File tree

pkg/cli/runner_test.go

Lines changed: 205 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,205 @@
1+
package cli
2+
3+
import (
4+
"bytes"
5+
"context"
6+
"sync"
7+
"testing"
8+
9+
"gotest.tools/v3/assert"
10+
11+
"github.com/docker/cagent/pkg/runtime"
12+
"github.com/docker/cagent/pkg/session"
13+
"github.com/docker/cagent/pkg/sessiontitle"
14+
"github.com/docker/cagent/pkg/tools"
15+
mcptools "github.com/docker/cagent/pkg/tools/mcp"
16+
)
17+
18+
// mockRuntime implements runtime.Runtime for testing the CLI runner.
19+
// It emits pre-configured events from RunStream and records Resume calls.
20+
type mockRuntime struct {
21+
events []runtime.Event
22+
23+
mu sync.Mutex
24+
resumes []runtime.ResumeRequest
25+
}
26+
27+
func (m *mockRuntime) CurrentAgentName() string { return "test" }
28+
func (m *mockRuntime) CurrentAgentInfo(context.Context) runtime.CurrentAgentInfo {
29+
return runtime.CurrentAgentInfo{Name: "test"}
30+
}
31+
func (m *mockRuntime) SetCurrentAgent(string) error { return nil }
32+
func (m *mockRuntime) CurrentAgentTools(context.Context) ([]tools.Tool, error) { return nil, nil }
33+
func (m *mockRuntime) EmitStartupInfo(context.Context, chan runtime.Event) {}
34+
func (m *mockRuntime) ResetStartupInfo() {}
35+
func (m *mockRuntime) Run(context.Context, *session.Session) ([]session.Message, error) {
36+
return nil, nil
37+
}
38+
39+
func (m *mockRuntime) ResumeElicitation(context.Context, tools.ElicitationAction, map[string]any) error {
40+
return nil
41+
}
42+
func (m *mockRuntime) SessionStore() session.Store { return nil }
43+
func (m *mockRuntime) Summarize(context.Context, *session.Session, string, chan runtime.Event) {}
44+
func (m *mockRuntime) PermissionsInfo() *runtime.PermissionsInfo { return nil }
45+
func (m *mockRuntime) CurrentAgentSkillsEnabled() bool { return false }
46+
func (m *mockRuntime) CurrentMCPPrompts(context.Context) map[string]mcptools.PromptInfo {
47+
return nil
48+
}
49+
50+
func (m *mockRuntime) ExecuteMCPPrompt(context.Context, string, map[string]string) (string, error) {
51+
return "", nil
52+
}
53+
func (m *mockRuntime) UpdateSessionTitle(context.Context, *session.Session, string) error { return nil }
54+
func (m *mockRuntime) TitleGenerator() *sessiontitle.Generator { return nil }
55+
func (m *mockRuntime) Close() error { return nil }
56+
func (m *mockRuntime) RegenerateTitle(context.Context, *session.Session, chan runtime.Event) {}
57+
58+
func (m *mockRuntime) Resume(_ context.Context, req runtime.ResumeRequest) {
59+
m.mu.Lock()
60+
defer m.mu.Unlock()
61+
m.resumes = append(m.resumes, req)
62+
}
63+
64+
func (m *mockRuntime) RunStream(_ context.Context, _ *session.Session) <-chan runtime.Event {
65+
ch := make(chan runtime.Event, len(m.events))
66+
for _, e := range m.events {
67+
ch <- e
68+
}
69+
close(ch)
70+
return ch
71+
}
72+
73+
func (m *mockRuntime) getResumes() []runtime.ResumeRequest {
74+
m.mu.Lock()
75+
defer m.mu.Unlock()
76+
result := make([]runtime.ResumeRequest, len(m.resumes))
77+
copy(result, m.resumes)
78+
return result
79+
}
80+
81+
func maxIterEvent(maxIter int) *runtime.MaxIterationsReachedEvent {
82+
return &runtime.MaxIterationsReachedEvent{
83+
Type: "max_iterations_reached",
84+
MaxIterations: maxIter,
85+
}
86+
}
87+
88+
func TestMaxIterationsAutoApproveInYoloMode(t *testing.T) {
89+
t.Parallel()
90+
91+
rt := &mockRuntime{
92+
events: []runtime.Event{maxIterEvent(60)},
93+
}
94+
95+
var buf bytes.Buffer
96+
out := NewPrinter(&buf)
97+
sess := session.New()
98+
cfg := Config{AutoApprove: true}
99+
100+
err := Run(t.Context(), out, cfg, rt, sess, []string{"hello"})
101+
assert.NilError(t, err)
102+
103+
resumes := rt.getResumes()
104+
assert.Equal(t, len(resumes), 1)
105+
assert.Equal(t, resumes[0].Type, runtime.ResumeTypeApprove)
106+
}
107+
108+
func TestMaxIterationsAutoApproveSafetyCap(t *testing.T) {
109+
t.Parallel()
110+
111+
// Emit maxAutoExtensions+1 events to trigger the safety cap
112+
events := make([]runtime.Event, maxAutoExtensions+1)
113+
for i := range events {
114+
events[i] = maxIterEvent(60 + i*10)
115+
}
116+
117+
rt := &mockRuntime{events: events}
118+
119+
var buf bytes.Buffer
120+
out := NewPrinter(&buf)
121+
sess := session.New()
122+
cfg := Config{AutoApprove: true}
123+
124+
err := Run(t.Context(), out, cfg, rt, sess, []string{"hello"})
125+
assert.NilError(t, err)
126+
127+
resumes := rt.getResumes()
128+
assert.Equal(t, len(resumes), maxAutoExtensions+1)
129+
130+
// First maxAutoExtensions should be approved
131+
for i := range maxAutoExtensions {
132+
assert.Equal(t, resumes[i].Type, runtime.ResumeTypeApprove,
133+
"extension %d should be approved", i+1)
134+
}
135+
// Last one should be rejected (safety cap)
136+
assert.Equal(t, resumes[maxAutoExtensions].Type, runtime.ResumeTypeReject,
137+
"extension beyond cap should be rejected")
138+
}
139+
140+
func TestMaxIterationsAutoApproveJSONMode(t *testing.T) {
141+
t.Parallel()
142+
143+
rt := &mockRuntime{
144+
events: []runtime.Event{maxIterEvent(60)},
145+
}
146+
147+
var buf bytes.Buffer
148+
out := NewPrinter(&buf)
149+
sess := session.New()
150+
cfg := Config{AutoApprove: true, OutputJSON: true}
151+
152+
err := Run(t.Context(), out, cfg, rt, sess, []string{"hello"})
153+
assert.NilError(t, err)
154+
155+
resumes := rt.getResumes()
156+
assert.Equal(t, len(resumes), 1)
157+
assert.Equal(t, resumes[0].Type, runtime.ResumeTypeApprove)
158+
}
159+
160+
func TestMaxIterationsRejectInJSONModeWithoutYolo(t *testing.T) {
161+
t.Parallel()
162+
163+
rt := &mockRuntime{
164+
events: []runtime.Event{maxIterEvent(60)},
165+
}
166+
167+
var buf bytes.Buffer
168+
out := NewPrinter(&buf)
169+
sess := session.New()
170+
cfg := Config{AutoApprove: false, OutputJSON: true}
171+
172+
err := Run(t.Context(), out, cfg, rt, sess, []string{"hello"})
173+
assert.NilError(t, err)
174+
175+
resumes := rt.getResumes()
176+
assert.Equal(t, len(resumes), 1)
177+
assert.Equal(t, resumes[0].Type, runtime.ResumeTypeReject)
178+
}
179+
180+
func TestMaxIterationsSafetyCapJSONMode(t *testing.T) {
181+
t.Parallel()
182+
183+
events := make([]runtime.Event, maxAutoExtensions+1)
184+
for i := range events {
185+
events[i] = maxIterEvent(60 + i*10)
186+
}
187+
188+
rt := &mockRuntime{events: events}
189+
190+
var buf bytes.Buffer
191+
out := NewPrinter(&buf)
192+
sess := session.New()
193+
cfg := Config{AutoApprove: true, OutputJSON: true}
194+
195+
err := Run(t.Context(), out, cfg, rt, sess, []string{"hello"})
196+
assert.NilError(t, err)
197+
198+
resumes := rt.getResumes()
199+
assert.Equal(t, len(resumes), maxAutoExtensions+1)
200+
201+
for i := range maxAutoExtensions {
202+
assert.Equal(t, resumes[i].Type, runtime.ResumeTypeApprove)
203+
}
204+
assert.Equal(t, resumes[maxAutoExtensions].Type, runtime.ResumeTypeReject)
205+
}

0 commit comments

Comments
 (0)