Skip to content

Commit 32d94ad

Browse files
authored
Merge pull request #484 from dgageot/ctrl-c
Handle ctrl-c for all commands
2 parents 589e96f + 451dcfc commit 32d94ad

7 files changed

Lines changed: 87 additions & 42 deletions

File tree

cmd/root/new.go

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
package root
22

33
import (
4-
"bufio"
54
"fmt"
65
"os"
76
"strings"
@@ -78,15 +77,13 @@ func NewNewCmd() *cobra.Command {
7877
if len(args) > 0 {
7978
prompt = strings.Join(args, " ")
8079
} else {
81-
reader := bufio.NewReader(os.Stdin)
82-
8380
fmt.Printf("%s\n", blue("------- Welcome to %s! -------", bold(AppName)))
8481
fmt.Printf("%s\n\n", white(" (Ctrl+C to exit)"))
8582
fmt.Printf("%s\n\n", blue("What should your agent/agent team do? (describe its purpose)"))
8683
fmt.Print(blue("> "))
8784

8885
var err error
89-
prompt, err = reader.ReadString('\n')
86+
prompt, err = readLine(ctx)
9087
if err != nil {
9188
return fmt.Errorf("failed to read purpose: %w", err)
9289
}
@@ -133,7 +130,7 @@ func NewNewCmd() *cobra.Command {
133130
llmIsTyping = false
134131
}
135132

136-
result := promptMaxIterationsContinue(e.MaxIterations)
133+
result := promptMaxIterationsContinue(ctx, e.MaxIterations)
137134
switch result {
138135
case ConfirmationApprove:
139136
rt.Resume(ctx, string(runtime.ResumeTypeApprove))

cmd/root/root.go

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,16 @@
11
package root
22

33
import (
4+
"context"
45
"errors"
56
"fmt"
67
"io"
78
"log/slog"
89
"os"
10+
"os/signal"
911
"path/filepath"
1012
"strings"
13+
"syscall"
1114

1215
"github.com/spf13/cobra"
1316

@@ -124,11 +127,10 @@ func NewRootCmd() *cobra.Command {
124127
return cmd
125128
}
126129

127-
func Run() {
128-
Execute()
129-
}
130+
func Execute() error {
131+
ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
132+
defer cancel()
130133

131-
func Execute() {
132134
// Set the version for automatic telemetry initialization
133135
telemetry.SetGlobalTelemetryVersion(version.Version)
134136

@@ -147,11 +149,13 @@ We collect anonymous usage data to help improve cagent. To disable:
147149
}
148150

149151
rootCmd := NewRootCmd()
150-
if err := rootCmd.Execute(); err != nil {
152+
if err := rootCmd.ExecuteContext(ctx); err != nil {
151153
envErr := &environment.RequiredEnvError{}
152154
runtimeErr := RuntimeError{}
153155

154156
switch {
157+
case ctx.Err() != nil:
158+
return ctx.Err()
155159
case errors.As(err, &envErr):
156160
fmt.Fprintln(os.Stderr, "The following environment variables must be set:")
157161
for _, v := range envErr.Missing {
@@ -170,8 +174,10 @@ We collect anonymous usage data to help improve cagent. To disable:
170174
}
171175
}
172176

173-
os.Exit(1)
177+
return err
174178
}
179+
180+
return nil
175181
}
176182

177183
// setupLogging configures slog logging behavior.

cmd/root/run.go

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -506,7 +506,7 @@ func runWithoutTUI(ctx context.Context, agentFilename string, rt runtime.Runtime
506506
llmIsTyping = false
507507
}
508508

509-
result := promptMaxIterationsContinue(e.MaxIterations)
509+
result := promptMaxIterationsContinue(ctx, e.MaxIterations)
510510
switch result {
511511
case ConfirmationApprove:
512512
rt.Resume(ctx, string(runtime.ResumeTypeApprove))
@@ -524,11 +524,13 @@ func runWithoutTUI(ctx context.Context, agentFilename string, rt runtime.Runtime
524524
}
525525

526526
serverURL := e.Meta["cagent/server_url"].(string)
527-
result := promptOAuthAuthorization(serverURL)
528-
switch result {
529-
case ConfirmationApprove:
527+
result := promptOAuthAuthorization(ctx, serverURL)
528+
switch {
529+
case ctx.Err() != nil:
530+
return ctx.Err()
531+
case result == ConfirmationApprove:
530532
_ = rt.ResumeElicitation(ctx, "accept", nil)
531-
case ConfirmationReject:
533+
case result == ConfirmationReject:
532534
_ = rt.ResumeElicitation(ctx, "decline", nil)
533535
return fmt.Errorf("OAuth authorization rejected by user")
534536
}

cmd/root/run_text_utils.go

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package root
22

33
import (
44
"bufio"
5+
"context"
56
"encoding/json"
67
"fmt"
78
"os"
@@ -117,13 +118,12 @@ func printToolCallResponse(toolCall tools.ToolCall, response string) {
117118
fmt.Printf("\n%s\n", white("%s response%s", bold(toolCall.Function.Name), formatToolCallResponse(response)))
118119
}
119120

120-
func promptMaxIterationsContinue(maxIterations int) ConfirmationResult {
121+
func promptMaxIterationsContinue(ctx context.Context, maxIterations int) ConfirmationResult {
121122
fmt.Printf("\n%s\n", yellow("⚠️ Maximum iterations (%d) reached. The agent may be stuck in a loop.", maxIterations))
122123
fmt.Printf("%s\n", white("This can happen with smaller or less capable models."))
123124
fmt.Printf("\n%s (y/n): ", blue("Do you want to continue for 10 more iterations?"))
124125

125-
reader := bufio.NewReader(os.Stdin)
126-
response, err := reader.ReadString('\n')
126+
response, err := readLine(ctx)
127127
if err != nil {
128128
fmt.Printf("\n%s\n", red("Failed to read input, exiting..."))
129129
return ConfirmationAbort
@@ -139,15 +139,14 @@ func promptMaxIterationsContinue(maxIterations int) ConfirmationResult {
139139
}
140140
}
141141

142-
func promptOAuthAuthorization(serverURL string) ConfirmationResult {
142+
func promptOAuthAuthorization(ctx context.Context, serverURL string) ConfirmationResult {
143143
fmt.Printf("\n%s\n", yellow("🔐 OAuth Authorization Required"))
144144
fmt.Printf("%s %s (remote)\n", white("Server:"), blue(serverURL))
145145
fmt.Printf("%s\n", white("This server requires OAuth authentication to access its tools."))
146146
fmt.Printf("%s\n", white("Your browser will open automatically to complete the authorization."))
147147
fmt.Printf("\n%s (y/n): ", blue("Do you want to authorize access?"))
148148

149-
reader := bufio.NewReader(os.Stdin)
150-
response, err := reader.ReadString('\n')
149+
response, err := readLine(ctx)
151150
if err != nil {
152151
fmt.Printf("\n%s\n", red("Failed to read input, aborting authorization..."))
153152
return ConfirmationAbort
@@ -290,3 +289,30 @@ func formatJSONValue(key string, value any) string {
290289
return fmt.Sprintf("%s: %s", bold(key), string(jsonBytes))
291290
}
292291
}
292+
293+
func readLine(ctx context.Context) (string, error) {
294+
lines := make(chan string)
295+
errs := make(chan error)
296+
297+
go func() {
298+
defer close(lines)
299+
defer close(errs)
300+
301+
reader := bufio.NewReader(os.Stdin)
302+
line, err := reader.ReadString('\n')
303+
if err != nil {
304+
errs <- err
305+
} else {
306+
lines <- line
307+
}
308+
}()
309+
310+
select {
311+
case <-ctx.Done():
312+
return "", ctx.Err()
313+
case err := <-errs:
314+
return "", err
315+
case line := <-lines:
316+
return line, nil
317+
}
318+
}

main.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
11
package main
22

33
import (
4+
"os"
5+
46
"github.com/docker/cagent/cmd/root"
57
)
68

79
func main() {
8-
root.Run()
10+
if err := root.Execute(); err != nil {
11+
os.Exit(1)
12+
}
913
}

pkg/agent/agent.go

Lines changed: 14 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ import (
55
"fmt"
66
"log/slog"
77
"math/rand"
8-
"sync"
98

109
"github.com/docker/cagent/pkg/memorymanager"
1110
"github.com/docker/cagent/pkg/model/provider"
@@ -17,9 +16,7 @@ type Agent struct {
1716
name string
1817
description string
1918
instruction string
20-
toolsets []tools.ToolSet
21-
startedToolsets map[tools.ToolSet]bool
22-
toolsetsMutex sync.RWMutex
19+
toolsets []*StartableToolSet
2320
models []provider.Provider
2421
subAgents []*Agent
2522
parents []*Agent
@@ -36,9 +33,8 @@ type Agent struct {
3633
// New creates a new agent
3734
func New(name, prompt string, opts ...Opt) *Agent {
3835
agent := &Agent{
39-
name: name,
40-
instruction: prompt,
41-
startedToolsets: make(map[tools.ToolSet]bool),
36+
name: name,
37+
instruction: prompt,
4238
}
4339

4440
for _, opt := range opts {
@@ -143,7 +139,13 @@ func (a *Agent) ToolDisplayName(ctx context.Context, toolName string) string {
143139
}
144140

145141
func (a *Agent) ToolSets() []tools.ToolSet {
146-
return a.toolsets
142+
var toolSets []tools.ToolSet
143+
144+
for _, ts := range a.toolsets {
145+
toolSets = append(toolSets, ts)
146+
}
147+
148+
return toolSets
147149
}
148150

149151
// Commands returns the named commands configured for this agent.
@@ -152,12 +154,9 @@ func (a *Agent) Commands() map[string]string {
152154
}
153155

154156
func (a *Agent) ensureToolSetsAreStarted() error {
155-
a.toolsetsMutex.Lock()
156-
defer a.toolsetsMutex.Unlock()
157-
158157
for _, toolSet := range a.toolsets {
159158
// Skip if toolset is already started
160-
if a.startedToolsets[toolSet] {
159+
if toolSet.started.Load() {
161160
continue
162161
}
163162

@@ -172,19 +171,16 @@ func (a *Agent) ensureToolSetsAreStarted() error {
172171
}
173172

174173
// Mark toolset as started
175-
a.startedToolsets[toolSet] = true
174+
toolSet.started.Store(true)
176175
}
177176

178177
return nil
179178
}
180179

181180
func (a *Agent) StopToolSets() error {
182-
a.toolsetsMutex.Lock()
183-
defer a.toolsetsMutex.Unlock()
184-
185181
for _, toolSet := range a.toolsets {
186182
// Only stop toolsets that are marked as started
187-
if !a.startedToolsets[toolSet] {
183+
if !toolSet.started.Load() {
188184
continue
189185
}
190186

@@ -193,7 +189,7 @@ func (a *Agent) StopToolSets() error {
193189
}
194190

195191
// Mark toolset as stopped
196-
a.startedToolsets[toolSet] = false
192+
toolSet.started.Store(false)
197193
}
198194

199195
return nil

pkg/agent/opts.go

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
package agent
22

33
import (
4+
"sync/atomic"
5+
46
"github.com/docker/cagent/pkg/memorymanager"
57
"github.com/docker/cagent/pkg/model/provider"
68
"github.com/docker/cagent/pkg/tools"
@@ -15,8 +17,15 @@ func WithInstruction(prompt string) Opt {
1517
}
1618

1719
func WithToolSets(toolSet ...tools.ToolSet) Opt {
20+
var startableToolSet []*StartableToolSet
21+
for _, ts := range toolSet {
22+
startableToolSet = append(startableToolSet, &StartableToolSet{
23+
ToolSet: ts,
24+
})
25+
}
26+
1827
return func(a *Agent) {
19-
a.toolsets = toolSet
28+
a.toolsets = startableToolSet
2029
}
2130
}
2231

@@ -96,3 +105,8 @@ func WithCommands(commands map[string]string) Opt {
96105
a.commands = commands
97106
}
98107
}
108+
109+
type StartableToolSet struct {
110+
tools.ToolSet
111+
started atomic.Bool
112+
}

0 commit comments

Comments
 (0)