Skip to content

Commit 0905f99

Browse files
authored
Merge pull request #13 from tombee/spec/SPEC-5
Add tool result streaming for long-running operations
2 parents 1623084 + 6e638f3 commit 0905f99

11 files changed

Lines changed: 3224 additions & 11 deletions

File tree

pkg/agent/agent.go

Lines changed: 92 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -171,6 +171,37 @@ type ToolExecution struct {
171171

172172
// DurationMs is the duration in milliseconds (for spec compliance)
173173
DurationMs int
174+
175+
// OutputChunks contains streaming output chunks from the tool execution
176+
OutputChunks []ToolOutputChunk
177+
}
178+
179+
// ToolOutputChunk is a single piece of streamed output produced by a tool
// execution.
type ToolOutputChunk struct {
	// ToolCallID is the ID of the tool call this chunk belongs to.
	ToolCallID string

	// ToolName is the name of the tool that produced the chunk.
	ToolName string

	// Stream names the output stream: "stdout", "stderr", or "".
	Stream string

	// Data holds the content of the chunk.
	Data string

	// IsFinal marks the last chunk of the stream.
	IsFinal bool

	// Metadata holds optional metadata attached to the chunk.
	Metadata map[string]interface{}
}
199+
200+
// StepContext contains contextual information available during agent execution steps.
201+
// This context accumulates data across iterations and is available for reasoning.
202+
type StepContext struct {
203+
// ToolOutputChunks contains all streaming output chunks from tool executions
204+
ToolOutputChunks []ToolOutputChunk
174205
}
175206

176207
// NewAgent creates a new agent.
@@ -216,6 +247,11 @@ func (a *Agent) Run(ctx context.Context, systemPrompt string, userPrompt string)
216247
ToolExecutions: []ToolExecution{},
217248
}
218249

250+
// Initialize step context
251+
stepContext := &StepContext{
252+
ToolOutputChunks: []ToolOutputChunk{},
253+
}
254+
219255
// Initialize conversation with system and user messages
220256
messages := []Message{
221257
{Role: "system", Content: systemPrompt},
@@ -274,7 +310,7 @@ func (a *Agent) Run(ctx context.Context, systemPrompt string, userPrompt string)
274310
// Execute tool calls if any
275311
if len(response.ToolCalls) > 0 {
276312
for _, toolCall := range response.ToolCalls {
277-
execution := a.executeTool(ctx, toolCall)
313+
execution := a.executeTool(ctx, toolCall, stepContext)
278314
result.ToolExecutions = append(result.ToolExecutions, execution)
279315

280316
// Add tool result to conversation
@@ -313,11 +349,12 @@ func (a *Agent) Run(ctx context.Context, systemPrompt string, userPrompt string)
313349
return result, fmt.Errorf("max iterations reached")
314350
}
315351

316-
// executeTool executes a single tool call.
317-
func (a *Agent) executeTool(ctx context.Context, toolCall ToolCall) ToolExecution {
352+
// executeTool executes a single tool call using streaming execution.
353+
func (a *Agent) executeTool(ctx context.Context, toolCall ToolCall, stepContext *StepContext) ToolExecution {
318354
startTime := time.Now()
319355
execution := ToolExecution{
320-
ToolName: toolCall.Name,
356+
ToolName: toolCall.Name,
357+
OutputChunks: []ToolOutputChunk{},
321358
}
322359

323360
// Parse arguments
@@ -340,15 +377,62 @@ func (a *Agent) executeTool(ctx context.Context, toolCall ToolCall) ToolExecutio
340377

341378
execution.Inputs = inputs
342379

343-
// Execute tool
344-
outputs, err := a.registry.Execute(ctx, toolCall.Name, inputs)
380+
// Execute tool with streaming support
381+
chunks, err := a.registry.ExecuteStream(ctx, toolCall.Name, inputs, toolCall.ID)
382+
if err != nil {
383+
execution.Success = false
384+
execution.Status = "error"
385+
execution.Error = err.Error()
386+
execution.Duration = time.Since(startTime)
387+
execution.DurationMs = int(execution.Duration.Milliseconds())
388+
return execution
389+
}
390+
391+
// Process streaming chunks
392+
var outputs map[string]interface{}
393+
var execError error
394+
395+
for chunk := range chunks {
396+
// Create output chunk for this execution
397+
outputChunk := ToolOutputChunk{
398+
ToolCallID: toolCall.ID,
399+
ToolName: toolCall.Name,
400+
Stream: chunk.Stream,
401+
Data: chunk.Data,
402+
IsFinal: chunk.IsFinal,
403+
Metadata: chunk.Metadata,
404+
}
405+
406+
// Store chunk in execution and step context
407+
execution.OutputChunks = append(execution.OutputChunks, outputChunk)
408+
stepContext.ToolOutputChunks = append(stepContext.ToolOutputChunks, outputChunk)
409+
410+
// Emit event via callback if configured
411+
if a.eventCallback != nil {
412+
a.eventCallback("tool.output", map[string]interface{}{
413+
"tool_call_id": toolCall.ID,
414+
"tool_name": toolCall.Name,
415+
"stream": chunk.Stream,
416+
"data": chunk.Data,
417+
"is_final": chunk.IsFinal,
418+
"metadata": chunk.Metadata,
419+
})
420+
}
421+
422+
// Extract final result
423+
if chunk.IsFinal {
424+
outputs = chunk.Result
425+
execError = chunk.Error
426+
}
427+
}
428+
345429
execution.Duration = time.Since(startTime)
346430
execution.DurationMs = int(execution.Duration.Milliseconds())
347431

348-
if err != nil {
432+
if execError != nil {
349433
execution.Success = false
350434
execution.Status = "error"
351-
execution.Error = err.Error()
435+
execution.Error = execError.Error()
352436
return execution
353437
}
354438

pkg/agent/agent_test.go

Lines changed: 190 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -560,3 +560,193 @@ func (m *mockStreamingLLMProvider) Stream(ctx context.Context, messages []Messag
560560
close(ch)
561561
return ch, nil
562562
}
563+
564+
// mockStreamingTool implements StreamingTool for testing
565+
type mockStreamingTool struct {
566+
name string
567+
chunks []tools.ToolChunk
568+
}
569+
570+
func (m *mockStreamingTool) Name() string {
571+
return m.name
572+
}
573+
574+
func (m *mockStreamingTool) Description() string {
575+
return "A mock streaming tool"
576+
}
577+
578+
func (m *mockStreamingTool) Schema() *tools.Schema {
579+
return &tools.Schema{
580+
Inputs: &tools.ParameterSchema{
581+
Type: "object",
582+
},
583+
Outputs: &tools.ParameterSchema{
584+
Type: "object",
585+
},
586+
}
587+
}
588+
589+
func (m *mockStreamingTool) Execute(ctx context.Context, inputs map[string]interface{}) (map[string]interface{}, error) {
590+
// For non-streaming execution, collect all chunks and return the final result
591+
ch, err := m.ExecuteStream(ctx, inputs)
592+
if err != nil {
593+
return nil, err
594+
}
595+
596+
var result map[string]interface{}
597+
var execError error
598+
for chunk := range ch {
599+
if chunk.IsFinal {
600+
result = chunk.Result
601+
execError = chunk.Error
602+
}
603+
}
604+
605+
if execError != nil {
606+
return nil, execError
607+
}
608+
return result, nil
609+
}
610+
611+
func (m *mockStreamingTool) ExecuteStream(ctx context.Context, inputs map[string]interface{}) (<-chan tools.ToolChunk, error) {
612+
ch := make(chan tools.ToolChunk, len(m.chunks))
613+
go func() {
614+
defer close(ch)
615+
for _, chunk := range m.chunks {
616+
ch <- chunk
617+
}
618+
}()
619+
return ch, nil
620+
}
621+
622+
func TestAgent_ToolStreamingExecution(t *testing.T) {
623+
// Create a streaming tool that emits chunks
624+
streamingTool := &mockStreamingTool{
625+
name: "streaming-tool",
626+
chunks: []tools.ToolChunk{
627+
{
628+
Data: "Line 1\n",
629+
Stream: "stdout",
630+
},
631+
{
632+
Data: "Line 2\n",
633+
Stream: "stdout",
634+
},
635+
{
636+
Data: "Error message\n",
637+
Stream: "stderr",
638+
},
639+
{
640+
IsFinal: true,
641+
Result: map[string]interface{}{
642+
"exit_code": 0,
643+
"duration": 100,
644+
},
645+
},
646+
},
647+
}
648+
649+
registry := tools.NewRegistry()
650+
if err := registry.Register(streamingTool); err != nil {
651+
t.Fatalf("Failed to register tool: %v", err)
652+
}
653+
654+
llm := &mockLLMProvider{
655+
responses: []Response{
656+
{
657+
Content: "Using streaming tool",
658+
FinishReason: "tool_calls",
659+
ToolCalls: []ToolCall{
660+
{
661+
ID: "call-1",
662+
Name: "streaming-tool",
663+
Arguments: map[string]interface{}{},
664+
},
665+
},
666+
Usage: TokenUsage{TotalTokens: 10},
667+
},
668+
{
669+
Content: "Completed with streaming output",
670+
FinishReason: "stop",
671+
Usage: TokenUsage{TotalTokens: 10},
672+
},
673+
},
674+
}
675+
676+
// Track events emitted via callback
677+
var capturedEvents []map[string]interface{}
678+
agent := NewAgent(llm, registry).WithEventCallback(func(eventType string, data interface{}) {
679+
if eventType == "tool.output" {
680+
if eventData, ok := data.(map[string]interface{}); ok {
681+
capturedEvents = append(capturedEvents, eventData)
682+
}
683+
}
684+
})
685+
686+
ctx := context.Background()
687+
result, err := agent.Run(ctx, "System", "Task")
688+
if err != nil {
689+
t.Fatalf("Run() error = %v", err)
690+
}
691+
692+
// Verify tool execution
693+
if len(result.ToolExecutions) != 1 {
694+
t.Fatalf("ToolExecutions count = %d, want 1", len(result.ToolExecutions))
695+
}
696+
697+
execution := result.ToolExecutions[0]
698+
699+
// Verify output chunks are captured in execution
700+
if len(execution.OutputChunks) != 4 {
701+
t.Errorf("OutputChunks count = %d, want 4", len(execution.OutputChunks))
702+
}
703+
704+
// Verify chunk content
705+
if execution.OutputChunks[0].Data != "Line 1\n" {
706+
t.Errorf("Chunk 0 data = %q, want %q", execution.OutputChunks[0].Data, "Line 1\n")
707+
}
708+
if execution.OutputChunks[0].Stream != "stdout" {
709+
t.Errorf("Chunk 0 stream = %q, want %q", execution.OutputChunks[0].Stream, "stdout")
710+
}
711+
712+
if execution.OutputChunks[2].Stream != "stderr" {
713+
t.Errorf("Chunk 2 stream = %q, want %q", execution.OutputChunks[2].Stream, "stderr")
714+
}
715+
716+
// Verify final chunk
717+
if !execution.OutputChunks[3].IsFinal {
718+
t.Error("Last chunk should have IsFinal=true")
719+
}
720+
721+
// Verify events were emitted
722+
if len(capturedEvents) != 4 {
723+
t.Errorf("Captured %d events, want 4", len(capturedEvents))
724+
}
725+
726+
// Verify event structure
727+
if len(capturedEvents) > 0 {
728+
firstEvent := capturedEvents[0]
729+
if firstEvent["tool_name"] != "streaming-tool" {
730+
t.Errorf("Event tool_name = %q, want %q", firstEvent["tool_name"], "streaming-tool")
731+
}
732+
if firstEvent["tool_call_id"] != "call-1" {
733+
t.Errorf("Event tool_call_id = %q, want %q", firstEvent["tool_call_id"], "call-1")
734+
}
735+
if firstEvent["data"] != "Line 1\n" {
736+
t.Errorf("Event data = %q, want %q", firstEvent["data"], "Line 1\n")
737+
}
738+
}
739+
740+
// Verify execution succeeded
741+
if !execution.Success {
742+
t.Error("Tool execution should have succeeded")
743+
}
744+
745+
// Verify final result
746+
if execution.Outputs == nil {
747+
t.Error("Execution outputs should not be nil")
748+
}
749+
if exitCode, ok := execution.Outputs["exit_code"].(int); !ok || exitCode != 0 {
750+
t.Errorf("Exit code = %v, want 0", execution.Outputs["exit_code"])
751+
}
752+
}

0 commit comments

Comments
 (0)