Skip to content

Commit 5796f2c

Browse files
committed
Fix when compaction occurs
This avoids breaking message sequencing when compaction occurs before all tool calls of a given assistant message have been processed. Signed-off-by: Christopher Petito <chrisjpetito@gmail.com>
1 parent 7b74cda commit 5796f2c

2 files changed

Lines changed: 153 additions & 2 deletions

File tree

pkg/runtime/runtime.go

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -369,6 +369,23 @@ func (r *runtime) RunStream(ctx context.Context, sess *session.Session) <-chan E
369369
events <- TokenUsage(sess.InputTokens, sess.OutputTokens, sess.InputTokens+sess.OutputTokens, contextLimit, sess.Cost)
370370

371371
if m != nil && r.sessionCompaction {
372+
if sess.InputTokens+sess.OutputTokens > int(float64(contextLimit)*0.9) {
373+
// Avoid inserting a summary between assistant tool_use and tool_result messages.
374+
// Defer compaction until after tool calls are processed in this iteration.
375+
if len(res.Calls) == 0 {
376+
events <- SessionCompaction(sess.ID, "start", r.currentAgent)
377+
r.Summarize(ctx, sess, events)
378+
events <- TokenUsage(sess.InputTokens, sess.OutputTokens, sess.InputTokens+sess.OutputTokens, contextLimit, sess.Cost)
379+
events <- SessionCompaction(sess.ID, "completed", r.currentAgent)
380+
}
381+
}
382+
}
383+
384+
r.processToolCalls(ctx, sess, res.Calls, agentTools, events)
385+
386+
// If tool_use occurred, perform compaction after tool results are appended
387+
// to avoid splitting assistant tool_use and user tool_result adjacency.
388+
if m != nil && r.sessionCompaction && len(res.Calls) > 0 {
372389
if sess.InputTokens+sess.OutputTokens > int(float64(contextLimit)*0.9) {
373390
events <- SessionCompaction(sess.ID, "start", r.currentAgent)
374391
r.Summarize(ctx, sess, events)
@@ -381,8 +398,6 @@ func (r *runtime) RunStream(ctx context.Context, sess *session.Session) <-chan E
381398
slog.Debug("Conversation stopped", "agent", a.Name())
382399
break
383400
}
384-
385-
r.processToolCalls(ctx, sess, res.Calls, agentTools, events)
386401
}
387402
}()
388403

pkg/runtime/runtime_test.go

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -402,6 +402,101 @@ func TestToolCallVariations(t *testing.T) {
402402
}
403403
}
404404

405+
// queueProvider is a test model provider that returns a different stream on
// each CreateChatCompletionStream call, handing out the queued streams in order.
406+
type queueProvider struct {
407+
id string // value reported by ID()
408+
streams []chat.MessageStream // pending streams; the head is popped on each call
409+
}
410+
411+
// ID returns the provider's id field.
func (p *queueProvider) ID() string { return p.id }
412+
413+
// CreateChatCompletionStream pops and returns the next queued stream.
// All arguments are ignored. Once the queue is exhausted it returns a
// zero-value mockStream instead of an error.
func (p *queueProvider) CreateChatCompletionStream(context.Context, []chat.Message, []tools.Tool) (chat.MessageStream, error) {
414+
// Queue drained: fall back to a zero-value mockStream.
if len(p.streams) == 0 {
415+
return &mockStream{}, nil
416+
}
417+
// Take the head of the queue and advance past it.
s := p.streams[0]
418+
p.streams = p.streams[1:]
419+
return s, nil
420+
}
421+
422+
// Options returns zero-value model options.
func (p *queueProvider) Options() options.ModelOptions { return options.ModelOptions{} }
423+
424+
// mockModelStoreWithLimit is a model store stub whose GetModel reports limit
// as the model's context limit.
type mockModelStoreWithLimit struct{ limit int }
425+
426+
// GetModel ignores its arguments and returns a model whose context limit is
// m.limit, with a non-nil zero-value Cost. It never returns an error.
func (m mockModelStoreWithLimit) GetModel(context.Context, string) (*modelsdev.Model, error) {
427+
return &modelsdev.Model{Limit: modelsdev.Limit{Context: m.limit}, Cost: &modelsdev.Cost{}}, nil
428+
}
429+
430+
// TestCompactionOccursAfterToolResultsWhenToolUsePresent runs RunStream with
// session compaction enabled against a model response that contains a tool
// call and reports usage of 95 with a context limit of 100. It asserts that
// the first SessionCompaction "start" event appears after the first
// ToolCallResponseEvent in the emitted event sequence.
func TestCompactionOccursAfterToolResultsWhenToolUsePresent(t *testing.T) {
431+
// First stream: assistant issues a tool call and usage exceeds 90% threshold
432+
mainStream := newStreamBuilder().
433+
AddToolCallName("call_1", "test_tool").
434+
AddToolCallArguments("call_1", "{}").
435+
AddStopWithUsage(95, 0). // Context limit will be 100
436+
Build()
437+
438+
// Second stream: summary generation (simple content)
439+
summaryStream := newStreamBuilder().
440+
AddContent("summary").
441+
AddStopWithUsage(1, 1).
442+
Build()
443+
444+
// The provider serves mainStream on the first call and summaryStream on the second.
prov := &queueProvider{id: "test/mock-model", streams: []chat.MessageStream{mainStream, summaryStream}}
445+
446+
// Provide an agent tool that will satisfy the tool call without requiring approvals
447+
testTool := tools.Tool{
448+
Name: "test_tool",
449+
Description: "test",
450+
Parameters: map[string]any{},
451+
Annotations: tools.ToolAnnotations{ReadOnlyHint: true},
452+
Handler: func(ctx context.Context, call tools.ToolCall) (*tools.ToolCallResult, error) {
453+
return &tools.ToolCallResult{Output: "ok"}, nil
454+
},
455+
}
456+
457+
root := agent.New("root", "You are a test agent",
458+
agent.WithModel(prov),
459+
agent.WithTools(testTool),
460+
)
461+
tm := team.New(team.WithAgents(root))
462+
463+
// Enable compaction and provide a model store with context limit = 100
464+
rt, err := New(tm, WithSessionCompaction(true), WithModelStore(mockModelStoreWithLimit{limit: 100}))
465+
require.NoError(t, err)
466+
467+
sess := session.New(session.WithUserMessage("", "Start"))
468+
events := rt.RunStream(t.Context(), sess)
469+
470+
// Collect every event until the channel is closed
471+
var seen []Event
472+
for ev := range events {
473+
seen = append(seen, ev)
474+
}
475+
476+
// Find indices of the first ToolCallResponse and the first compaction start (from RunStream)
477+
toolRespIdx := -1
478+
compactionStartIdx := -1
479+
for i, ev := range seen {
480+
switch e := ev.(type) {
481+
case *ToolCallResponseEvent:
482+
if toolRespIdx == -1 {
483+
toolRespIdx = i
484+
}
485+
case *SessionCompactionEvent:
486+
// We only want the RunStream-level "start" status (not Summarize's "started")
487+
if e.Status == "start" && compactionStartIdx == -1 {
488+
compactionStartIdx = i
489+
}
490+
}
491+
}
492+
493+
// Both events must have been observed before their order can be compared
require.NotEqual(t, -1, toolRespIdx, "expected a ToolCallResponseEvent")
494+
require.NotEqual(t, -1, compactionStartIdx, "expected a SessionCompaction start event")
495+
496+
// Assert compaction is triggered only after tool results have been appended
497+
require.Greater(t, compactionStartIdx, toolRespIdx, "compaction should occur after tool results when tool_use is present")
498+
}
499+
405500
func TestSessionWithoutUserMessage(t *testing.T) {
406501
stream := newStreamBuilder().AddContent("OK").AddStopWithUsage(1, 1).Build()
407502

@@ -434,3 +529,44 @@ func TestNewRuntime_InvalidCurrentAgentError(t *testing.T) {
434529
require.Contains(t, err.Error(), "agent \"other\" not found")
435530
require.Contains(t, err.Error(), "root") // available agents listed in error
436531
}
532+
533+
// TestProcessToolCalls_UnknownTool_NoToolResultMessage calls processToolCalls
// directly with a tool call whose name is not registered and nil agentTools.
// It asserts that no tool-role message carrying that call's ID is appended to
// the session.
func TestProcessToolCalls_UnknownTool_NoToolResultMessage(t *testing.T) {
534+
// Build a runtime with a simple agent but no tools registered matching the call
535+
root := agent.New("root", "You are a test agent")
536+
tm := team.New(team.WithAgents(root))
537+
538+
rt, err := New(tm, WithSessionCompaction(false), WithModelStore(mockModelStore{}))
539+
require.NoError(t, err)
540+
541+
// Register default tools (contains only transfer_task) to ensure unknown tool isn't matched
542+
rt.(*runtime).registerDefaultTools()
543+
544+
sess := session.New(session.WithUserMessage("", "Start"))
545+
546+
// Simulate a model-issued tool call to a non-existent tool
547+
calls := []tools.ToolCall{{
548+
ID: "tool-unknown-1",
549+
Type: "function",
550+
Function: tools.FunctionCall{Name: "non_existent_tool", Arguments: "{}"},
551+
}}
552+
553+
// Buffered channel for the events emitted by the direct processToolCalls call below
events := make(chan Event, 10)
554+
555+
// No agentTools provided and runtime toolMap doesn't have this tool name
556+
rt.(*runtime).processToolCalls(t.Context(), sess, calls, nil, events)
557+
558+
// Close the channel, then drain any buffered events; their content is not inspected
559+
close(events)
560+
for range events {
561+
}
562+
563+
// Verify no tool result message was added for the unknown tool
564+
var sawToolMsg bool
565+
for _, it := range sess.Messages {
566+
if it.IsMessage() && it.Message.Message.Role == chat.MessageRoleTool && it.Message.Message.ToolCallID == "tool-unknown-1" {
567+
sawToolMsg = true
568+
break
569+
}
570+
}
571+
require.False(t, sawToolMsg, "no tool result should be added for unknown tool; this reproduces invalid sequencing state")
572+
}

0 commit comments

Comments
 (0)