@@ -11,19 +11,18 @@ import (
1111 "net/http"
1212 "net/url"
1313 "path"
14- "strings"
1514 "time"
1615
1716 "github.com/docker/cagent/pkg/api"
1817 v2 "github.com/docker/cagent/pkg/config/v2"
1918 "github.com/docker/cagent/pkg/session"
20- "github.com/docker/cagent/pkg/tools"
2119)
2220
// Client is an HTTP client for the cagent server API
type Client struct {
	// baseURL is the server's base URL.
	// NOTE(review): presumably parsed from the string handed to NewClient;
	// the parsing itself is outside this view — confirm.
	baseURL *url.URL
	// httpClient performs the requests. NewClient initializes it with a
	// 30-second timeout before applying any ClientOption.
	httpClient *http.Client
	// registry maps the "type" field of a streamed event to a constructor
	// returning an empty Event of the matching concrete type. The stream
	// reader looks the type up here and unmarshals the event payload into
	// the value the constructor returns; unknown types are skipped.
	registry map[string]func() Event
}
2827
2928// ClientOption is a function for configuring the Client
@@ -58,6 +57,25 @@ func NewClient(baseURL string, opts ...ClientOption) (*Client, error) {
5857 httpClient : & http.Client {
5958 Timeout : 30 * time .Second ,
6059 },
60+ registry : map [string ]func () Event {
61+ "user_message" : func () Event { return & UserMessageEvent {} },
62+ "tool_call_confirmation" : func () Event { return & ToolCallConfirmationEvent {} },
63+ "partial_tool_call" : func () Event { return & PartialToolCallEvent {} },
64+ "tool_call" : func () Event { return & ToolCallEvent {} },
65+ "tool_call_response" : func () Event { return & ToolCallResponseEvent {} },
66+ "agent_choice_reasoning" : func () Event { return & AgentChoiceReasoningEvent {} },
67+ "agent_choice" : func () Event { return & AgentChoiceEvent {} },
68+ "stream_started" : func () Event { return & StreamStartedEvent {} },
69+ "stream_stopped" : func () Event { return & StreamStoppedEvent {} },
70+ "authorization_required" : func () Event { return & AuthorizationRequiredEvent {} },
71+ "session_compaction" : func () Event { return & SessionCompactionEvent {} },
72+ "token_usage" : func () Event { return & TokenUsageEvent {} },
73+ "max_iterations_reached" : func () Event { return & MaxIterationsReachedEvent {} },
74+ "session_title" : func () Event { return & SessionTitleEvent {} },
75+ "session_summary" : func () Event { return & SessionSummaryEvent {} },
76+ "shell" : func () Event { return & ShellOutputEvent {} },
77+ "error" : func () Event { return & ErrorEvent {} },
78+ },
6179 }
6280
6381 for _ , opt := range opts {
@@ -72,20 +90,6 @@ type ErrorResponse struct {
7290 Error string `json:"error"`
7391}
7492
75- // parseToolCall safely converts an any to tools.ToolCall
76- func parseToolCall (data , toolDefinition []byte ) (tools.ToolCall , tools.Tool , error ) {
77- var toolCall tools.ToolCall
78- if err := json .Unmarshal (data , & toolCall ); err != nil {
79- return tools.ToolCall {}, tools.Tool {}, fmt .Errorf ("failed to unmarshal tool call: %w" , err )
80- }
81- var toolDef tools.Tool
82- if err := json .Unmarshal (toolDefinition , & toolDef ); err != nil {
83- return tools.ToolCall {}, tools.Tool {}, fmt .Errorf ("failed to unmarshal tool definition: %w" , err )
84- }
85-
86- return toolCall , toolDef , nil
87- }
88-
8993// doRequest performs an HTTP request and handles common response patterns
9094func (c * Client ) doRequest (ctx context.Context , method , endpoint string , body , result any ) error {
9195 var reqBody io.Reader
@@ -321,73 +325,41 @@ func (c *Client) runAgentWithAgentName(ctx context.Context, sessionID, agent, ag
321325
322326 scanner := bufio .NewScanner (resp .Body )
323327 for scanner .Scan () {
324- line := scanner .Text ()
328+ line := scanner .Bytes ()
329+ if len (line ) == 0 || line [0 ] == ':' {
330+ continue
331+ }
332+
333+ after , ok := bytes .CutPrefix (line , []byte ("data: " ))
334+ if ! ok {
335+ continue
336+ }
337+
338+ slog .Debug ("event" , "event" , string (after ))
325339
326- if line == "" || strings .HasPrefix (line , ":" ) {
340+ // First unmarshal to get the type
341+ var baseEvent struct {
342+ Type string `json:"type"`
343+ }
344+ if err := json .Unmarshal (after , & baseEvent ); err != nil {
345+ slog .Debug ("event" , "error" , err )
327346 continue
328347 }
329348
330- if after , ok := strings .CutPrefix (line , "data: " ); ok {
331- var event map [string ]any
332- if err := json .Unmarshal ([]byte (after ), & event ); err != nil {
333- continue
334- }
335-
336- slog .Debug ("event" , "event" , after )
337-
338- switch event ["type" ] {
339- case "user_message" :
340- eventChan <- UserMessage (event ["message" ].(string ))
341- case "tool_call_confirmation" :
342- if toolCall , toolDef , err := parseToolCall (event ["tool_call" ].([]byte ), event ["tool_definition" ].([]byte )); err == nil {
343- eventChan <- ToolCallConfirmation (toolCall , toolDef , event ["agent_name" ].(string ))
344- }
345- case "partial_tool_call" :
346- if toolCall , toolDef , err := parseToolCall (event ["tool_call" ].([]byte ), event ["tool_definition" ].([]byte )); err == nil {
347- eventChan <- PartialToolCall (toolCall , toolDef , event ["agent_name" ].(string ))
348- }
349- case "tool_call" :
350- if toolCall , toolDef , err := parseToolCall (event ["tool_call" ].([]byte ), event ["tool_definition" ].([]byte )); err == nil {
351- eventChan <- ToolCall (toolCall , toolDef , event ["agent_name" ].(string ))
352- }
353- case "tool_call_response" :
354- if toolCall , _ , err := parseToolCall (event ["tool_call" ].([]byte ), event ["tool_definition" ].([]byte )); err == nil {
355- eventChan <- ToolCallResponse (toolCall , event ["response" ].(string ), event ["agent_name" ].(string ))
356- }
357- case "agent_choice_reasoning" :
358- eventChan <- AgentChoiceReasoning (event ["agent_name" ].(string ), event ["content" ].(string ))
359- case "agent_choice" :
360- eventChan <- AgentChoice (event ["agent_name" ].(string ), event ["content" ].(string ))
361- case "stream_started" :
362- eventChan <- StreamStarted (sessionID , event ["agent_name" ].(string ))
363- case "stream_stopped" :
364- eventChan <- StreamStopped (sessionID , event ["agent_name" ].(string ))
365- case "authorization_required" :
366- eventChan <- AuthorizationRequired (event ["server_url" ].(string ), event ["server_type" ].(string ), event ["confirmation" ].(string ), event ["agent_name" ].(string ))
367- case "session_compaction" :
368- eventChan <- SessionCompaction (event ["session_id" ].(string ), event ["status" ].(string ), event ["agent_name" ].(string ))
369- case "token_usage" :
370- usage := event ["usage" ].(map [string ]any )
371- inputTokens , _ := usage ["input_tokens" ].(float64 )
372- outputTokens , _ := usage ["output_tokens" ].(float64 )
373- contextLength , _ := usage ["context_length" ].(float64 )
374- contextLimit , _ := usage ["context_limit" ].(float64 )
375- cost , _ := usage ["cost" ].(float64 )
376-
377- eventChan <- TokenUsage (int (inputTokens ), int (outputTokens ), int (contextLength ), int (contextLimit ), cost )
378- case "max_iterations_reached" :
379- maxIterations , _ := event ["max_iterations" ].(float64 )
380- eventChan <- MaxIterationsReached (int (maxIterations ))
381- case "session_title" :
382- eventChan <- SessionTitle (event ["session_id" ].(string ), event ["title" ].(string ), event ["agent_name" ].(string ))
383- case "session_summary" :
384- eventChan <- SessionSummary (event ["session_id" ].(string ), event ["summary" ].(string ), event ["agent_name" ].(string ))
385- case "shell" :
386- eventChan <- ShellOutput (event ["output" ].(string ))
387- case "error" :
388- eventChan <- Error (event ["error" ].(string ))
389- }
349+ // Then unmarshal the full event
350+ createEvent , found := c .registry [baseEvent .Type ]
351+ if ! found {
352+ slog .Debug ("event" , "invalid_type" , baseEvent .Type )
353+ continue
390354 }
355+
356+ e := createEvent ()
357+ if err := json .Unmarshal (after , & e ); err != nil {
358+ slog .Debug ("event" , "error" , err )
359+ continue
360+ }
361+
362+ eventChan <- e
391363 }
392364
393365 if err := scanner .Err (); err != nil {
0 commit comments