Skip to content

Commit 8740605

Browse files
authored
Merge pull request #1737 from frntn/frntn/yolo-auto-continue-max-iterations
fix(cli): auto-continue max iterations in --yolo mode
2 parents 3e4f113 + 3634385 commit 8740605

2 files changed

Lines changed: 263 additions & 7 deletions

File tree

pkg/cli/runner.go

Lines changed: 58 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -31,6 +31,39 @@ func (e RuntimeError) Unwrap() error {
3131
return e.Err
3232
}
3333

34+
// maxAutoExtensions is the maximum number of times --yolo mode will
// auto-continue when max iterations is reached, to prevent infinite loops.
const maxAutoExtensions = 5

// maxIterAction describes what the caller should do after a MaxIterationsReachedEvent.
type maxIterAction int

const (
	maxIterContinue maxIterAction = iota // auto-approved, keep running
	maxIterStop                          // safety cap reached, caller should stop
	maxIterPrompt                        // not in yolo mode, caller should prompt the user
)

// handleMaxIterationsAutoApprove decides whether to auto-extend iterations in
// --yolo mode.
//
// autoApprove reports whether auto-approve (--yolo) mode is active.
// autoExtensions points at the caller-owned count of extensions granted so
// far; it must be non-nil and is incremented only when an extension is
// actually granted, so it never exceeds maxAutoExtensions. maxIter is the
// iteration limit that was just reached and is used for logging only.
//
// It returns maxIterPrompt when autoApprove is false (the counter is left
// untouched), maxIterContinue while fewer than maxAutoExtensions extensions
// have been granted, and maxIterStop once the cap has been reached.
func handleMaxIterationsAutoApprove(autoApprove bool, autoExtensions *int, maxIter int) maxIterAction {
	if !autoApprove {
		return maxIterPrompt
	}

	// Check the cap before counting so that the counter (and the value logged
	// below) reflects extensions that were really granted, not refused attempts.
	if *autoExtensions >= maxAutoExtensions {
		slog.Warn("Max auto-extensions reached in yolo mode, stopping",
			"total_extensions", *autoExtensions)
		return maxIterStop
	}

	*autoExtensions++
	slog.Info("Auto-extending iterations in yolo mode",
		"extension", *autoExtensions,
		"max_extensions", maxAutoExtensions,
		"current_max", maxIter)
	return maxIterContinue
}
66+
3467
// Config holds configuration for running an agent in CLI mode
3568
type Config struct {
3669
AppName string
@@ -60,6 +93,8 @@ func Run(ctx context.Context, out *Printer, cfg Config, rt runtime.Runtime, sess
6093
var lastErr error
6194

6295
oneLoop := func(text string, rd io.Reader) error {
96+
autoExtensions := 0
97+
6398
userInput := strings.TrimSpace(text)
6499
if userInput == "" {
65100
return nil
@@ -74,6 +109,14 @@ func Run(ctx context.Context, out *Printer, cfg Config, rt runtime.Runtime, sess
74109
if !cfg.AutoApprove {
75110
rt.Resume(ctx, runtime.ResumeReject(""))
76111
}
112+
case *runtime.MaxIterationsReachedEvent:
113+
switch handleMaxIterationsAutoApprove(cfg.AutoApprove, &autoExtensions, e.MaxIterations) {
114+
case maxIterContinue:
115+
rt.Resume(ctx, runtime.ResumeApprove())
116+
default: // maxIterStop or maxIterPrompt (no interactive prompt in JSON mode)
117+
rt.Resume(ctx, runtime.ResumeReject(""))
118+
return nil
119+
}
77120
case *runtime.ErrorEvent:
78121
return fmt.Errorf("%s", e.Error)
79122
}
@@ -153,16 +196,24 @@ func Run(ctx context.Context, out *Printer, cfg Config, rt runtime.Runtime, sess
153196
out.PrintError(lastErr)
154197
}
155198
case *runtime.MaxIterationsReachedEvent:
156-
result := out.PromptMaxIterationsContinue(ctx, e.MaxIterations)
157-
switch result {
158-
case ConfirmationApprove:
199+
switch handleMaxIterationsAutoApprove(cfg.AutoApprove, &autoExtensions, e.MaxIterations) {
200+
case maxIterContinue:
159201
rt.Resume(ctx, runtime.ResumeApprove())
160-
case ConfirmationReject:
161-
rt.Resume(ctx, runtime.ResumeReject(""))
162-
return nil
163-
case ConfirmationAbort:
202+
case maxIterStop:
164203
rt.Resume(ctx, runtime.ResumeReject(""))
165204
return nil
205+
case maxIterPrompt:
206+
result := out.PromptMaxIterationsContinue(ctx, e.MaxIterations)
207+
switch result {
208+
case ConfirmationApprove:
209+
rt.Resume(ctx, runtime.ResumeApprove())
210+
case ConfirmationReject:
211+
rt.Resume(ctx, runtime.ResumeReject(""))
212+
return nil
213+
case ConfirmationAbort:
214+
rt.Resume(ctx, runtime.ResumeReject(""))
215+
return nil
216+
}
166217
}
167218
case *runtime.ElicitationRequestEvent:
168219
serverURL, ok := e.Meta["cagent/server_url"].(string)

pkg/cli/runner_test.go

Lines changed: 205 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,205 @@
1+
package cli
2+
3+
import (
4+
"bytes"
5+
"context"
6+
"sync"
7+
"testing"
8+
9+
"gotest.tools/v3/assert"
10+
11+
"github.com/docker/cagent/pkg/runtime"
12+
"github.com/docker/cagent/pkg/session"
13+
"github.com/docker/cagent/pkg/sessiontitle"
14+
"github.com/docker/cagent/pkg/tools"
15+
mcptools "github.com/docker/cagent/pkg/tools/mcp"
16+
)
17+
18+
// mockRuntime implements runtime.Runtime for testing the CLI runner.
19+
// It emits pre-configured events from RunStream and records Resume calls.
20+
type mockRuntime struct {
21+
events []runtime.Event
22+
23+
mu sync.Mutex
24+
resumes []runtime.ResumeRequest
25+
}
26+
27+
func (m *mockRuntime) CurrentAgentName() string { return "test" }
28+
func (m *mockRuntime) CurrentAgentInfo(context.Context) runtime.CurrentAgentInfo {
29+
return runtime.CurrentAgentInfo{Name: "test"}
30+
}
31+
func (m *mockRuntime) SetCurrentAgent(string) error { return nil }
32+
func (m *mockRuntime) CurrentAgentTools(context.Context) ([]tools.Tool, error) { return nil, nil }
33+
func (m *mockRuntime) EmitStartupInfo(context.Context, chan runtime.Event) {}
34+
func (m *mockRuntime) ResetStartupInfo() {}
35+
func (m *mockRuntime) Run(context.Context, *session.Session) ([]session.Message, error) {
36+
return nil, nil
37+
}
38+
39+
func (m *mockRuntime) ResumeElicitation(context.Context, tools.ElicitationAction, map[string]any) error {
40+
return nil
41+
}
42+
func (m *mockRuntime) SessionStore() session.Store { return nil }
43+
func (m *mockRuntime) Summarize(context.Context, *session.Session, string, chan runtime.Event) {}
44+
func (m *mockRuntime) PermissionsInfo() *runtime.PermissionsInfo { return nil }
45+
func (m *mockRuntime) CurrentAgentSkillsEnabled() bool { return false }
46+
func (m *mockRuntime) CurrentMCPPrompts(context.Context) map[string]mcptools.PromptInfo {
47+
return nil
48+
}
49+
50+
func (m *mockRuntime) ExecuteMCPPrompt(context.Context, string, map[string]string) (string, error) {
51+
return "", nil
52+
}
53+
func (m *mockRuntime) UpdateSessionTitle(context.Context, *session.Session, string) error { return nil }
54+
func (m *mockRuntime) TitleGenerator() *sessiontitle.Generator { return nil }
55+
func (m *mockRuntime) Close() error { return nil }
56+
func (m *mockRuntime) RegenerateTitle(context.Context, *session.Session, chan runtime.Event) {}
57+
58+
func (m *mockRuntime) Resume(_ context.Context, req runtime.ResumeRequest) {
59+
m.mu.Lock()
60+
defer m.mu.Unlock()
61+
m.resumes = append(m.resumes, req)
62+
}
63+
64+
func (m *mockRuntime) RunStream(_ context.Context, _ *session.Session) <-chan runtime.Event {
65+
ch := make(chan runtime.Event, len(m.events))
66+
for _, e := range m.events {
67+
ch <- e
68+
}
69+
close(ch)
70+
return ch
71+
}
72+
73+
func (m *mockRuntime) getResumes() []runtime.ResumeRequest {
74+
m.mu.Lock()
75+
defer m.mu.Unlock()
76+
result := make([]runtime.ResumeRequest, len(m.resumes))
77+
copy(result, m.resumes)
78+
return result
79+
}
80+
81+
func maxIterEvent(maxIter int) *runtime.MaxIterationsReachedEvent {
82+
return &runtime.MaxIterationsReachedEvent{
83+
Type: "max_iterations_reached",
84+
MaxIterations: maxIter,
85+
}
86+
}
87+
88+
func TestMaxIterationsAutoApproveInYoloMode(t *testing.T) {
89+
t.Parallel()
90+
91+
rt := &mockRuntime{
92+
events: []runtime.Event{maxIterEvent(60)},
93+
}
94+
95+
var buf bytes.Buffer
96+
out := NewPrinter(&buf)
97+
sess := session.New()
98+
cfg := Config{AutoApprove: true}
99+
100+
err := Run(t.Context(), out, cfg, rt, sess, []string{"hello"})
101+
assert.NilError(t, err)
102+
103+
resumes := rt.getResumes()
104+
assert.Equal(t, len(resumes), 1)
105+
assert.Equal(t, resumes[0].Type, runtime.ResumeTypeApprove)
106+
}
107+
108+
func TestMaxIterationsAutoApproveSafetyCap(t *testing.T) {
109+
t.Parallel()
110+
111+
// Emit maxAutoExtensions+1 events to trigger the safety cap
112+
events := make([]runtime.Event, maxAutoExtensions+1)
113+
for i := range events {
114+
events[i] = maxIterEvent(60 + i*10)
115+
}
116+
117+
rt := &mockRuntime{events: events}
118+
119+
var buf bytes.Buffer
120+
out := NewPrinter(&buf)
121+
sess := session.New()
122+
cfg := Config{AutoApprove: true}
123+
124+
err := Run(t.Context(), out, cfg, rt, sess, []string{"hello"})
125+
assert.NilError(t, err)
126+
127+
resumes := rt.getResumes()
128+
assert.Equal(t, len(resumes), maxAutoExtensions+1)
129+
130+
// First maxAutoExtensions should be approved
131+
for i := range maxAutoExtensions {
132+
assert.Equal(t, resumes[i].Type, runtime.ResumeTypeApprove,
133+
"extension %d should be approved", i+1)
134+
}
135+
// Last one should be rejected (safety cap)
136+
assert.Equal(t, resumes[maxAutoExtensions].Type, runtime.ResumeTypeReject,
137+
"extension beyond cap should be rejected")
138+
}
139+
140+
func TestMaxIterationsAutoApproveJSONMode(t *testing.T) {
141+
t.Parallel()
142+
143+
rt := &mockRuntime{
144+
events: []runtime.Event{maxIterEvent(60)},
145+
}
146+
147+
var buf bytes.Buffer
148+
out := NewPrinter(&buf)
149+
sess := session.New()
150+
cfg := Config{AutoApprove: true, OutputJSON: true}
151+
152+
err := Run(t.Context(), out, cfg, rt, sess, []string{"hello"})
153+
assert.NilError(t, err)
154+
155+
resumes := rt.getResumes()
156+
assert.Equal(t, len(resumes), 1)
157+
assert.Equal(t, resumes[0].Type, runtime.ResumeTypeApprove)
158+
}
159+
160+
func TestMaxIterationsRejectInJSONModeWithoutYolo(t *testing.T) {
161+
t.Parallel()
162+
163+
rt := &mockRuntime{
164+
events: []runtime.Event{maxIterEvent(60)},
165+
}
166+
167+
var buf bytes.Buffer
168+
out := NewPrinter(&buf)
169+
sess := session.New()
170+
cfg := Config{AutoApprove: false, OutputJSON: true}
171+
172+
err := Run(t.Context(), out, cfg, rt, sess, []string{"hello"})
173+
assert.NilError(t, err)
174+
175+
resumes := rt.getResumes()
176+
assert.Equal(t, len(resumes), 1)
177+
assert.Equal(t, resumes[0].Type, runtime.ResumeTypeReject)
178+
}
179+
180+
func TestMaxIterationsSafetyCapJSONMode(t *testing.T) {
181+
t.Parallel()
182+
183+
events := make([]runtime.Event, maxAutoExtensions+1)
184+
for i := range events {
185+
events[i] = maxIterEvent(60 + i*10)
186+
}
187+
188+
rt := &mockRuntime{events: events}
189+
190+
var buf bytes.Buffer
191+
out := NewPrinter(&buf)
192+
sess := session.New()
193+
cfg := Config{AutoApprove: true, OutputJSON: true}
194+
195+
err := Run(t.Context(), out, cfg, rt, sess, []string{"hello"})
196+
assert.NilError(t, err)
197+
198+
resumes := rt.getResumes()
199+
assert.Equal(t, len(resumes), maxAutoExtensions+1)
200+
201+
for i := range maxAutoExtensions {
202+
assert.Equal(t, resumes[i].Type, runtime.ResumeTypeApprove)
203+
}
204+
assert.Equal(t, resumes[maxAutoExtensions].Type, runtime.ResumeTypeReject)
205+
}

0 commit comments

Comments
 (0)