Skip to content

Commit b81dab1

Browse files
committed
Merge branch 'main' into fix-session-deletion
2 parents 2c2bd82 + f09b589 commit b81dab1

4 files changed

Lines changed: 99 additions & 174 deletions

File tree

cmd/root/run.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -361,17 +361,18 @@ func runWithoutTUI(ctx context.Context, agentFilename string, rt runtime.Runtime
361361
llmIsTyping := false
362362
var lastConfirmedToolCallID string
363363
for event := range rt.RunStream(loopCtx, sess) {
364-
if event.GetAgentName() != "" && (firstLoop || lastAgent != event.GetAgentName()) {
364+
agentName := event.GetAgentName()
365+
if agentName != "" && (firstLoop || lastAgent != agentName) {
365366
if !firstLoop {
366367
if llmIsTyping {
367368
fmt.Println()
368369
llmIsTyping = false
369370
}
370371
fmt.Println()
371372
}
372-
printAgentName(event.GetAgentName())
373+
printAgentName(agentName)
373374
firstLoop = false
374-
lastAgent = event.GetAgentName()
375+
lastAgent = agentName
375376
}
376377
switch e := event.(type) {
377378
case *runtime.AgentChoiceEvent:

pkg/runtime/client.go

Lines changed: 50 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -11,19 +11,18 @@ import (
1111
"net/http"
1212
"net/url"
1313
"path"
14-
"strings"
1514
"time"
1615

1716
"github.com/docker/cagent/pkg/api"
1817
v2 "github.com/docker/cagent/pkg/config/v2"
1918
"github.com/docker/cagent/pkg/session"
20-
"github.com/docker/cagent/pkg/tools"
2119
)
2220

2321
// Client is an HTTP client for the cagent server API
2422
type Client struct {
2523
baseURL *url.URL
2624
httpClient *http.Client
25+
registry map[string]func() Event
2726
}
2827

2928
// ClientOption is a function for configuring the Client
@@ -58,6 +57,25 @@ func NewClient(baseURL string, opts ...ClientOption) (*Client, error) {
5857
httpClient: &http.Client{
5958
Timeout: 30 * time.Second,
6059
},
60+
registry: map[string]func() Event{
61+
"user_message": func() Event { return &UserMessageEvent{} },
62+
"tool_call_confirmation": func() Event { return &ToolCallConfirmationEvent{} },
63+
"partial_tool_call": func() Event { return &PartialToolCallEvent{} },
64+
"tool_call": func() Event { return &ToolCallEvent{} },
65+
"tool_call_response": func() Event { return &ToolCallResponseEvent{} },
66+
"agent_choice_reasoning": func() Event { return &AgentChoiceReasoningEvent{} },
67+
"agent_choice": func() Event { return &AgentChoiceEvent{} },
68+
"stream_started": func() Event { return &StreamStartedEvent{} },
69+
"stream_stopped": func() Event { return &StreamStoppedEvent{} },
70+
"authorization_required": func() Event { return &AuthorizationRequiredEvent{} },
71+
"session_compaction": func() Event { return &SessionCompactionEvent{} },
72+
"token_usage": func() Event { return &TokenUsageEvent{} },
73+
"max_iterations_reached": func() Event { return &MaxIterationsReachedEvent{} },
74+
"session_title": func() Event { return &SessionTitleEvent{} },
75+
"session_summary": func() Event { return &SessionSummaryEvent{} },
76+
"shell": func() Event { return &ShellOutputEvent{} },
77+
"error": func() Event { return &ErrorEvent{} },
78+
},
6179
}
6280

6381
for _, opt := range opts {
@@ -72,20 +90,6 @@ type ErrorResponse struct {
7290
Error string `json:"error"`
7391
}
7492

75-
// parseToolCall safely converts an any to tools.ToolCall
76-
func parseToolCall(data, toolDefinition []byte) (tools.ToolCall, tools.Tool, error) {
77-
var toolCall tools.ToolCall
78-
if err := json.Unmarshal(data, &toolCall); err != nil {
79-
return tools.ToolCall{}, tools.Tool{}, fmt.Errorf("failed to unmarshal tool call: %w", err)
80-
}
81-
var toolDef tools.Tool
82-
if err := json.Unmarshal(toolDefinition, &toolDef); err != nil {
83-
return tools.ToolCall{}, tools.Tool{}, fmt.Errorf("failed to unmarshal tool definition: %w", err)
84-
}
85-
86-
return toolCall, toolDef, nil
87-
}
88-
8993
// doRequest performs an HTTP request and handles common response patterns
9094
func (c *Client) doRequest(ctx context.Context, method, endpoint string, body, result any) error {
9195
var reqBody io.Reader
@@ -321,73 +325,41 @@ func (c *Client) runAgentWithAgentName(ctx context.Context, sessionID, agent, ag
321325

322326
scanner := bufio.NewScanner(resp.Body)
323327
for scanner.Scan() {
324-
line := scanner.Text()
328+
line := scanner.Bytes()
329+
if len(line) == 0 || line[0] == ':' {
330+
continue
331+
}
332+
333+
after, ok := bytes.CutPrefix(line, []byte("data: "))
334+
if !ok {
335+
continue
336+
}
337+
338+
slog.Debug("event", "event", string(after))
325339

326-
if line == "" || strings.HasPrefix(line, ":") {
340+
// First unmarshal to get the type
341+
var baseEvent struct {
342+
Type string `json:"type"`
343+
}
344+
if err := json.Unmarshal(after, &baseEvent); err != nil {
345+
slog.Debug("event", "error", err)
327346
continue
328347
}
329348

330-
if after, ok := strings.CutPrefix(line, "data: "); ok {
331-
var event map[string]any
332-
if err := json.Unmarshal([]byte(after), &event); err != nil {
333-
continue
334-
}
335-
336-
slog.Debug("event", "event", after)
337-
338-
switch event["type"] {
339-
case "user_message":
340-
eventChan <- UserMessage(event["message"].(string))
341-
case "tool_call_confirmation":
342-
if toolCall, toolDef, err := parseToolCall(event["tool_call"].([]byte), event["tool_definition"].([]byte)); err == nil {
343-
eventChan <- ToolCallConfirmation(toolCall, toolDef, event["agent_name"].(string))
344-
}
345-
case "partial_tool_call":
346-
if toolCall, toolDef, err := parseToolCall(event["tool_call"].([]byte), event["tool_definition"].([]byte)); err == nil {
347-
eventChan <- PartialToolCall(toolCall, toolDef, event["agent_name"].(string))
348-
}
349-
case "tool_call":
350-
if toolCall, toolDef, err := parseToolCall(event["tool_call"].([]byte), event["tool_definition"].([]byte)); err == nil {
351-
eventChan <- ToolCall(toolCall, toolDef, event["agent_name"].(string))
352-
}
353-
case "tool_call_response":
354-
if toolCall, _, err := parseToolCall(event["tool_call"].([]byte), event["tool_definition"].([]byte)); err == nil {
355-
eventChan <- ToolCallResponse(toolCall, event["response"].(string), event["agent_name"].(string))
356-
}
357-
case "agent_choice_reasoning":
358-
eventChan <- AgentChoiceReasoning(event["agent_name"].(string), event["content"].(string))
359-
case "agent_choice":
360-
eventChan <- AgentChoice(event["agent_name"].(string), event["content"].(string))
361-
case "stream_started":
362-
eventChan <- StreamStarted(sessionID, event["agent_name"].(string))
363-
case "stream_stopped":
364-
eventChan <- StreamStopped(sessionID, event["agent_name"].(string))
365-
case "authorization_required":
366-
eventChan <- AuthorizationRequired(event["server_url"].(string), event["server_type"].(string), event["confirmation"].(string), event["agent_name"].(string))
367-
case "session_compaction":
368-
eventChan <- SessionCompaction(event["session_id"].(string), event["status"].(string), event["agent_name"].(string))
369-
case "token_usage":
370-
usage := event["usage"].(map[string]any)
371-
inputTokens, _ := usage["input_tokens"].(float64)
372-
outputTokens, _ := usage["output_tokens"].(float64)
373-
contextLength, _ := usage["context_length"].(float64)
374-
contextLimit, _ := usage["context_limit"].(float64)
375-
cost, _ := usage["cost"].(float64)
376-
377-
eventChan <- TokenUsage(int(inputTokens), int(outputTokens), int(contextLength), int(contextLimit), cost)
378-
case "max_iterations_reached":
379-
maxIterations, _ := event["max_iterations"].(float64)
380-
eventChan <- MaxIterationsReached(int(maxIterations))
381-
case "session_title":
382-
eventChan <- SessionTitle(event["session_id"].(string), event["title"].(string), event["agent_name"].(string))
383-
case "session_summary":
384-
eventChan <- SessionSummary(event["session_id"].(string), event["summary"].(string), event["agent_name"].(string))
385-
case "shell":
386-
eventChan <- ShellOutput(event["output"].(string))
387-
case "error":
388-
eventChan <- Error(event["error"].(string))
389-
}
349+
// Then unmarshal the full event
350+
createEvent, found := c.registry[baseEvent.Type]
351+
if !found {
352+
slog.Debug("event", "invalid_type", baseEvent.Type)
353+
continue
390354
}
355+
356+
e := createEvent()
357+
if err := json.Unmarshal(after, &e); err != nil {
358+
slog.Debug("event", "error", err)
359+
continue
360+
}
361+
362+
eventChan <- e
391363
}
392364

393365
if err := scanner.Err(); err != nil {

0 commit comments

Comments
 (0)