Skip to content

Commit 3ab4dfe

Browse files
authored
Merge pull request #1037 from rumpl/extract-title-gen
Extract the title generation from the runtime
2 parents 10fc93a + 8bdfa76 commit 3ab4dfe

8 files changed

Lines changed: 290 additions & 266 deletions

File tree

e2e/runtime_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,8 @@ func TestRuntime_OpenAI_Basic(t *testing.T) {
3131
require.NoError(t, err)
3232

3333
response := sess.GetLastAssistantMessageContent()
34-
assert.Equal(t, "2 + 2 is equal to 4.", response)
35-
assert.Equal(t, "Simple Math Calculation", sess.Title)
34+
assert.Equal(t, "2 + 2 equals 4.", response)
35+
assert.Equal(t, "Basic Math Question", sess.Title)
3636
}
3737

3838
func TestRuntime_Mistral_Basic(t *testing.T) {
@@ -55,5 +55,5 @@ func TestRuntime_Mistral_Basic(t *testing.T) {
5555

5656
response := sess.GetLastAssistantMessageContent()
5757
assert.Equal(t, "The sum of 2 + 2 is 4.", response)
58-
assert.Equal(t, "Basic Arithmetic: Sum of 2 and 2", sess.Title)
58+
assert.Equal(t, "Math Basics: Simple Addition", sess.Title)
5959
}

e2e/testdata/cassettes/TestRuntime_Mistral_Basic.yaml

Lines changed: 88 additions & 80 deletions
Large diffs are not rendered by default.

e2e/testdata/cassettes/TestRuntime_OpenAI_Basic.yaml

Lines changed: 78 additions & 40 deletions
Large diffs are not rendered by default.

pkg/runtime/event.go

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -219,12 +219,11 @@ type SessionTitleEvent struct {
219219
AgentContext
220220
}
221221

222-
func SessionTitle(sessionID, title, agentName string) Event {
222+
func SessionTitle(sessionID, title string) Event {
223223
return &SessionTitleEvent{
224-
Type: "session_title",
225-
SessionID: sessionID,
226-
Title: title,
227-
AgentContext: AgentContext{AgentName: agentName},
224+
Type: "session_title",
225+
SessionID: sessionID,
226+
Title: title,
228227
}
229228
}
230229

pkg/runtime/runtime.go

Lines changed: 10 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ type LocalRuntime struct {
132132
elicitationEventsChannel chan Event // Current events channel for sending elicitation requests
133133
elicitationEventsChannelMux sync.RWMutex // Protects elicitationEventsChannel
134134
ragInitialized atomic.Bool
135-
titleGenerationWg sync.WaitGroup // Wait group for title generation
135+
titleGen *titleGenerator
136136
}
137137

138138
type streamResult struct {
@@ -210,6 +210,13 @@ func New(agents *team.Team, opts ...Opt) (*LocalRuntime, error) {
210210
return nil, err
211211
}
212212

213+
model := agents.Model()
214+
if model == nil {
215+
return nil, errors.New("no model found for the team; ensure at least one agent has a valid model")
216+
}
217+
218+
r.titleGen = newTitleGenerator(model)
219+
213220
slog.Debug("Creating new runtime", "agent", r.currentAgent, "available_agents", agents.Size())
214221

215222
return r, nil
@@ -488,8 +495,7 @@ func (r *LocalRuntime) finalizeEventChannel(ctx context.Context, sess *session.S
488495

489496
telemetry.RecordSessionEnd(ctx)
490497

491-
// Wait for title generation if it's in progress
492-
r.titleGenerationWg.Wait()
498+
r.titleGen.Wait()
493499
}
494500

495501
// RunStream starts the agent's interaction loop and returns a channel of events
@@ -543,7 +549,6 @@ func (r *LocalRuntime) RunStream(ctx context.Context, sess *session.Session) <-c
543549
return
544550
}
545551

546-
// Emit toolset information
547552
events <- ToolsetInfo(len(agentTools), r.currentAgent)
548553

549554
messages := sess.GetMessages(a)
@@ -558,9 +563,7 @@ func (r *LocalRuntime) RunStream(ctx context.Context, sess *session.Session) <-c
558563
r.registerDefaultTools()
559564

560565
if sess.Title == "" {
561-
r.titleGenerationWg.Go(func() {
562-
r.generateSessionTitle(ctx, sess, events)
563-
})
566+
r.titleGen.Generate(ctx, sess, events)
564567
}
565568

566569
iteration := 0
@@ -1353,72 +1356,6 @@ func (r *LocalRuntime) handleHandoff(_ context.Context, _ *session.Session, tool
13531356
}, nil
13541357
}
13551358

1356-
// truncateTitle truncates a title to maxLength characters, adding an ellipsis if needed
1357-
func truncateTitle(title string, maxLength int) string {
1358-
if len(title) <= maxLength {
1359-
return title
1360-
}
1361-
// Ensure we have room for the ellipsis
1362-
if maxLength < 3 {
1363-
return "..."
1364-
}
1365-
return title[:maxLength-3] + "..."
1366-
}
1367-
1368-
// generateSessionTitle generates a title for the session based on the first user message
1369-
func (r *LocalRuntime) generateSessionTitle(ctx context.Context, sess *session.Session, events chan Event) {
1370-
slog.Debug("Generating title for session", "session_id", sess.ID)
1371-
1372-
firstUserMessage := sess.GetLastUserMessageContent()
1373-
if firstUserMessage == "" {
1374-
slog.Error("Failed generating session title: no user message found in session", "session_id", sess.ID)
1375-
events <- SessionTitle(sess.ID, "Untitled", r.currentAgent)
1376-
return
1377-
}
1378-
1379-
systemPrompt := "You are a helpful AI assistant that generates concise, descriptive titles for conversations. You will be given a conversation history and asked to create a title that captures the main topic."
1380-
userPrompt := fmt.Sprintf("Based on the following message a user sent to an AI assistant, generate a short, descriptive title (maximum 50 characters) that captures the main topic or purpose of the conversation. Return ONLY the title text, nothing else.\n\nUser message: %s\n\n", firstUserMessage)
1381-
1382-
titleModel := provider.CloneWithOptions(
1383-
ctx,
1384-
r.CurrentAgent().Model(),
1385-
options.WithStructuredOutput(nil),
1386-
options.WithMaxTokens(100),
1387-
options.WithGeneratingTitle(),
1388-
)
1389-
newTeam := team.New(
1390-
team.WithAgents(agent.New("root", systemPrompt, agent.WithModel(titleModel))),
1391-
)
1392-
titleSession := session.New(
1393-
session.WithUserMessage(userPrompt),
1394-
session.WithTitle("Generating title..."),
1395-
)
1396-
1397-
titleRuntime, err := New(newTeam, WithSessionCompaction(false))
1398-
if err != nil {
1399-
slog.Error("Failed to create title generator runtime", "error", err)
1400-
return
1401-
}
1402-
1403-
// Run the title generation (this will be a simple back-and-forth)
1404-
_, err = titleRuntime.Run(ctx, titleSession)
1405-
if err != nil {
1406-
slog.Error("Failed to generate session title", "session_id", sess.ID, "error", err)
1407-
return
1408-
}
1409-
1410-
// Get the generated title from the last assistant message
1411-
title := titleSession.GetLastAssistantMessageContent()
1412-
if title == "" {
1413-
return
1414-
}
1415-
// Truncate title to 50 characters with ellipsis if needed
1416-
title = truncateTitle(title, 50)
1417-
sess.Title = title
1418-
slog.Debug("Generated session title", "session_id", sess.ID, "title", title)
1419-
events <- SessionTitle(sess.ID, title, r.currentAgent)
1420-
}
1421-
14221359
// Summarize generates a summary for the session based on the conversation history
14231360
func (r *LocalRuntime) Summarize(ctx context.Context, sess *session.Session, events chan Event) {
14241361
slog.Debug("Generating summary for session", "session_id", sess.ID)

pkg/runtime/runtime_test.go

Lines changed: 2 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -699,7 +699,7 @@ func TestGetTools_WarningHandling(t *testing.T) {
699699

700700
for _, tt := range tests {
701701
t.Run(tt.name, func(t *testing.T) {
702-
root := agent.New("root", "test", agent.WithToolSets(tt.toolsets...))
702+
root := agent.New("root", "test", agent.WithToolSets(tt.toolsets...), agent.WithModel(&mockProvider{}))
703703
tm := team.New(team.WithAgents(root))
704704
rt, err := New(tm, WithModelStore(mockModelStore{}))
705705
require.NoError(t, err)
@@ -769,7 +769,7 @@ func TestSummarize_EmptySession(t *testing.T) {
769769

770770
func TestProcessToolCalls_UnknownTool_NoToolResultMessage(t *testing.T) {
771771
// Build a runtime with a simple agent but no tools registered matching the call
772-
root := agent.New("root", "You are a test agent")
772+
root := agent.New("root", "You are a test agent", agent.WithModel(&mockProvider{}))
773773
tm := team.New(team.WithAgents(root))
774774

775775
rt, err := New(tm, WithSessionCompaction(false), WithModelStore(mockModelStore{}))
@@ -856,66 +856,3 @@ func TestEmitStartupInfo(t *testing.T) {
856856
// Should be empty due to deduplication
857857
require.Empty(t, collectedEvents2, "EmitStartupInfo should not emit duplicate events")
858858
}
859-
860-
func TestTruncateTitle(t *testing.T) {
861-
tests := []struct {
862-
name string
863-
title string
864-
maxLength int
865-
expected string
866-
}{
867-
{
868-
name: "title shorter than max length",
869-
title: "Short title",
870-
maxLength: 50,
871-
expected: "Short title",
872-
},
873-
{
874-
name: "title exactly at max length",
875-
title: "This is exactly fifty characters in length now.",
876-
maxLength: 50,
877-
expected: "This is exactly fifty characters in length now.",
878-
},
879-
{
880-
name: "title longer than max length",
881-
title: "This is a very long title that exceeds the maximum character limit",
882-
maxLength: 50,
883-
expected: "This is a very long title that exceeds the maxi...",
884-
},
885-
{
886-
name: "very short max length",
887-
title: "Any title",
888-
maxLength: 5,
889-
expected: "An...",
890-
},
891-
{
892-
name: "max length less than 3",
893-
title: "Any title",
894-
maxLength: 2,
895-
expected: "...",
896-
},
897-
{
898-
name: "empty title",
899-
title: "",
900-
maxLength: 50,
901-
expected: "",
902-
},
903-
{
904-
name: "title with unicode characters",
905-
title: "こんにちは、これは日本語のタイトルです。とても長いタイトルなので切り捨てられるはずです。",
906-
maxLength: 50,
907-
expected: "こんにちは、これは日本語のタイトルです。とても長いタイトルなので切り捨てられるはずです。"[:47] + "...",
908-
},
909-
}
910-
911-
for _, tt := range tests {
912-
t.Run(tt.name, func(t *testing.T) {
913-
result := truncateTitle(tt.title, tt.maxLength)
914-
require.Equal(t, tt.expected, result)
915-
// Only check length constraint if maxLength >= 3 (otherwise ellipsis alone is 3 chars)
916-
if tt.maxLength >= 3 {
917-
require.LessOrEqual(t, len(result), tt.maxLength)
918-
}
919-
})
920-
}
921-
}

pkg/runtime/title_generator.go

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
package runtime
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"log/slog"
7+
"sync"
8+
9+
"github.com/docker/cagent/pkg/agent"
10+
"github.com/docker/cagent/pkg/model/provider"
11+
"github.com/docker/cagent/pkg/model/provider/options"
12+
"github.com/docker/cagent/pkg/session"
13+
"github.com/docker/cagent/pkg/team"
14+
)
15+
16+
const (
17+
titleSystemPrompt = "You are a helpful AI assistant that generates concise, descriptive titles for conversations. You will be given a conversation history and asked to create a title that captures the main topic."
18+
titleUserPromptFormat = "Based on the following message a user sent to an AI assistant, generate a short, descriptive title (maximum 50 characters) that captures the main topic or purpose of the conversation. Return ONLY the title text, nothing else.\n\nUser message: %s\n\n"
19+
)
20+
21+
type titleGenerator struct {
22+
wg sync.WaitGroup
23+
model provider.Provider
24+
}
25+
26+
func newTitleGenerator(model provider.Provider) *titleGenerator {
27+
return &titleGenerator{
28+
model: model,
29+
}
30+
}
31+
32+
func (t *titleGenerator) Generate(ctx context.Context, sess *session.Session, events chan<- Event) {
33+
t.wg.Go(func() {
34+
t.generate(ctx, sess, events)
35+
})
36+
}
37+
38+
func (t *titleGenerator) Wait() {
39+
t.wg.Wait()
40+
}
41+
42+
func (t *titleGenerator) generate(ctx context.Context, sess *session.Session, events chan<- Event) {
43+
slog.Debug("Generating title for session", "session_id", sess.ID)
44+
45+
firstUserMessage := sess.GetLastUserMessageContent()
46+
if firstUserMessage == "" {
47+
return
48+
}
49+
50+
userPrompt := fmt.Sprintf(titleUserPromptFormat, firstUserMessage)
51+
52+
titleModel := provider.CloneWithOptions(
53+
ctx,
54+
t.model,
55+
options.WithStructuredOutput(nil),
56+
options.WithMaxTokens(20),
57+
options.WithGeneratingTitle(),
58+
)
59+
60+
newTeam := team.New(
61+
team.WithAgents(agent.New("root", titleSystemPrompt, agent.WithModel(titleModel))),
62+
)
63+
64+
titleSession := session.New(
65+
session.WithUserMessage(userPrompt),
66+
session.WithTitle("Generating title..."),
67+
)
68+
69+
titleRuntime, err := New(newTeam, WithSessionCompaction(false))
70+
if err != nil {
71+
slog.Error("Failed to create title generator runtime", "error", err)
72+
return
73+
}
74+
75+
_, err = titleRuntime.Run(ctx, titleSession)
76+
if err != nil {
77+
slog.Error("Failed to generate session title", "session_id", sess.ID, "error", err)
78+
return
79+
}
80+
81+
title := titleSession.GetLastAssistantMessageContent()
82+
if title == "" {
83+
return
84+
}
85+
86+
sess.Title = title
87+
slog.Debug("Generated session title", "session_id", sess.ID, "title", title)
88+
events <- SessionTitle(sess.ID, title)
89+
}

pkg/team/team.go

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"strings"
1010

1111
"github.com/docker/cagent/pkg/agent"
12+
"github.com/docker/cagent/pkg/model/provider"
1213
"github.com/docker/cagent/pkg/rag"
1314
)
1415

@@ -66,6 +67,21 @@ func (t *Team) Agent(name string) (*agent.Agent, error) {
6667
return found, nil
6768
}
6869

70+
func (t *Team) Model() provider.Provider {
71+
root, err := t.Agent("root")
72+
if err == nil {
73+
return root.Model()
74+
}
75+
76+
for _, agentName := range t.AgentNames() {
77+
a, err := t.Agent(agentName)
78+
if err == nil {
79+
return a.Model()
80+
}
81+
}
82+
return nil
83+
}
84+
6985
func (t *Team) Size() int {
7086
return len(t.agents)
7187
}

0 commit comments

Comments
 (0)