Skip to content

Commit 1ffea6c

Browse files
authored
Merge pull request #1941 from dgageot/sub-sessions-usage
Improve sub-sessions usage
2 parents 800dacd + 7fdb5df commit 1ffea6c

6 files changed

Lines changed: 547 additions & 65 deletions

File tree

pkg/runtime/runtime.go

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2002,25 +2002,29 @@ func (r *LocalRuntime) handleTaskTransfer(ctx context.Context, sess *session.Ses
20022002
session.WithParentID(sess.ID),
20032003
)
20042004

2005-
for event := range r.RunStream(ctx, s) {
2005+
return r.runSubSession(ctx, sess, s, span, evts, a.Name())
2006+
}
2007+
2008+
// runSubSession runs a child session within the parent, forwarding events and
2009+
// propagating state (tool approvals, thinking) back to the parent when done.
2010+
func (r *LocalRuntime) runSubSession(ctx context.Context, parent, child *session.Session, span trace.Span, evts chan Event, agentName string) (*tools.ToolCallResult, error) {
2011+
for event := range r.RunStream(ctx, child) {
20062012
evts <- event
20072013
if errEvent, ok := event.(*ErrorEvent); ok {
20082014
span.RecordError(fmt.Errorf("%s", errEvent.Error))
2009-
span.SetStatus(codes.Error, "error in transferred session")
2015+
span.SetStatus(codes.Error, "sub-session error")
20102016
return nil, fmt.Errorf("%s", errEvent.Error)
20112017
}
20122018
}
20132019

2014-
sess.ToolsApproved = s.ToolsApproved
2015-
sess.Thinking = s.Thinking
2016-
2017-
sess.AddSubSession(s)
2018-
evts <- SubSessionCompleted(sess.ID, s, a.Name())
2020+
parent.ToolsApproved = child.ToolsApproved
2021+
parent.Thinking = child.Thinking
20192022

2020-
slog.Debug("Task transfer completed", "agent", params.Agent, "task", params.Task)
2023+
parent.AddSubSession(child)
2024+
evts <- SubSessionCompleted(parent.ID, child, agentName)
20212025

2022-
span.SetStatus(codes.Ok, "task transfer completed")
2023-
return tools.ResultSuccess(s.GetLastAssistantMessageContent()), nil
2026+
span.SetStatus(codes.Ok, "sub-session completed")
2027+
return tools.ResultSuccess(child.GetLastAssistantMessageContent()), nil
20242028
}
20252029

20262030
func (r *LocalRuntime) handleHandoff(_ context.Context, _ *session.Session, toolCall tools.ToolCall, _ chan Event) (*tools.ToolCallResult, error) {

pkg/tui/components/sidebar/sidebar.go

Lines changed: 75 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,7 @@ type model struct {
130130
toolsLoading bool // true when more tools may still be loading
131131
sessionState *service.SessionState
132132
workingAgent string // Name of the agent currently working (empty if none)
133+
currentSessionID string // Session ID of the currently active stream
133134
scrollview *scrollview.Model
134135
workingDirectory string
135136
queuedMessages []string // Truncated preview of queued messages
@@ -484,17 +485,22 @@ func formatCost(cost float64) string {
484485
return fmt.Sprintf("%.2f", cost)
485486
}
486487

487-
// contextPercent returns a context usage percentage string for the current agent's session.
488-
// It looks up the session belonging to the current agent; if none is found, it falls back
489-
// to returning the percentage when there is exactly one session.
490-
func (m *model) contextPercent() string {
491-
// Try to find the session belonging to the current agent.
488+
// currentSessionUsage returns the usage snapshot for the current agent's session.
489+
// It uses a 3-tier lookup: session ID (most reliable) → agent name → single-session fallback.
490+
func (m *model) currentSessionUsage() (*runtime.Usage, bool) {
491+
// Direct lookup by current session ID (most reliable, no map iteration ambiguity).
492+
if m.currentSessionID != "" {
493+
if usage, ok := m.sessionUsage[m.currentSessionID]; ok {
494+
return usage, true
495+
}
496+
}
497+
498+
// Fallback: search by current agent name.
492499
if m.currentAgent != "" {
493500
for sessionID, agentName := range m.sessionAgent {
494501
if agentName == m.currentAgent {
495-
if usage, ok := m.sessionUsage[sessionID]; ok && usage.ContextLimit > 0 {
496-
percent := (float64(usage.ContextLength) / float64(usage.ContextLimit)) * 100
497-
return fmt.Sprintf("%.0f%%", percent)
502+
if usage, ok := m.sessionUsage[sessionID]; ok {
503+
return usage, true
498504
}
499505
}
500506
}
@@ -503,12 +509,26 @@ func (m *model) contextPercent() string {
503509
// Fallback: if there's exactly one session, use it.
504510
if len(m.sessionUsage) == 1 {
505511
for _, usage := range m.sessionUsage {
506-
if usage.ContextLimit > 0 {
507-
percent := (float64(usage.ContextLength) / float64(usage.ContextLimit)) * 100
508-
return fmt.Sprintf("%.0f%%", percent)
509-
}
512+
return usage, true
510513
}
511514
}
515+
return nil, false
516+
}
517+
518+
// currentSessionTokens returns the token count for the current agent's session.
519+
func (m *model) currentSessionTokens() (tokens int64, found bool) {
520+
if usage, ok := m.currentSessionUsage(); ok {
521+
return usage.InputTokens + usage.OutputTokens, true
522+
}
523+
return 0, false
524+
}
525+
526+
// contextPercent returns a context usage percentage string for the current agent's session.
527+
func (m *model) contextPercent() string {
528+
if usage, ok := m.currentSessionUsage(); ok && usage.ContextLimit > 0 {
529+
percent := (float64(usage.ContextLength) / float64(usage.ContextLimit)) * 100
530+
return fmt.Sprintf("%.0f%%", percent)
531+
}
512532
return ""
513533
}
514534

@@ -626,6 +646,7 @@ func (m *model) Update(msg tea.Msg) (layout.Model, tea.Cmd) {
626646
// New stream starting - reset cancelled flag and enable spinner
627647
m.streamCancelled = false
628648
m.workingAgent = msg.AgentName
649+
m.currentSessionID = msg.SessionID
629650
// If title hasn't been generated yet, show the title generation spinner
630651
if !m.titleGenerated {
631652
m.titleRegenerating = true
@@ -987,22 +1008,40 @@ func (m *model) formatProgress(state *ragIndexingState) string {
9871008
return ""
9881009
}
9891010

990-
func (m *model) tokenUsage(contentWidth int) string {
991-
var totalTokens int64
992-
var totalCost float64
1011+
// usageStats holds aggregated usage statistics across all sessions, computed
1012+
// once so both tokenUsage (vertical) and tokenUsageSummary (collapsed) can
1013+
// reuse the values without duplicating the computation logic.
1014+
type usageStats struct {
1015+
tokens int64
1016+
contextPct string
1017+
totalCost float64
1018+
sessionCount int
1019+
}
1020+
1021+
func (m *model) computeUsageStats() usageStats {
1022+
var s usageStats
9931023
for _, usage := range m.sessionUsage {
994-
totalTokens += usage.InputTokens + usage.OutputTokens
995-
totalCost += usage.Cost
1024+
s.totalCost += usage.Cost
1025+
s.sessionCount++
9961026
}
1027+
s.tokens, _ = m.currentSessionTokens()
1028+
s.contextPct = m.contextPercent()
1029+
return s
1030+
}
1031+
1032+
func (m *model) tokenUsage(contentWidth int) string {
1033+
s := m.computeUsageStats()
9971034

998-
var tokenUsage strings.Builder
999-
fmt.Fprintf(&tokenUsage, "%s", formatTokenCount(totalTokens))
1000-
if ctxText := m.contextPercent(); ctxText != "" {
1001-
fmt.Fprintf(&tokenUsage, " (%s)", ctxText)
1035+
line := formatTokenCount(s.tokens)
1036+
if s.contextPct != "" {
1037+
line += " (" + s.contextPct + ")"
1038+
}
1039+
line += " " + styles.TabAccentStyle.Render("$"+formatCost(s.totalCost))
1040+
if s.sessionCount > 1 {
1041+
line += " " + styles.MutedStyle.Render(fmt.Sprintf("(%d sub-sessions)", s.sessionCount-1))
10021042
}
1003-
fmt.Fprintf(&tokenUsage, " %s", styles.TabAccentStyle.Render("$"+formatCost(totalCost)))
10041043

1005-
return m.renderTab("Token Usage", tokenUsage.String(), contentWidth)
1044+
return m.renderTab("Token Usage", line, contentWidth)
10061045
}
10071046

10081047
// tokenUsageSummary returns a single-line summary for horizontal layout.
@@ -1011,18 +1050,22 @@ func (m *model) tokenUsageSummary() string {
10111050
return ""
10121051
}
10131052

1014-
var totalTokens int64
1015-
var totalCost float64
1016-
for _, usage := range m.sessionUsage {
1017-
totalTokens += usage.InputTokens + usage.OutputTokens
1018-
totalCost += usage.Cost
1019-
}
1053+
s := m.computeUsageStats()
10201054

1021-
if ctxText := m.contextPercent(); ctxText != "" {
1022-
return fmt.Sprintf("Tokens: %s | Cost: $%s | Context: %s", formatTokenCount(totalTokens), formatCost(totalCost), ctxText)
1055+
parts := []string{fmt.Sprintf("Tokens: %s", formatTokenCount(s.tokens))}
1056+
if s.sessionCount > 1 {
1057+
if s.contextPct != "" {
1058+
parts = append(parts, fmt.Sprintf("Context: %s", s.contextPct))
1059+
}
1060+
parts = append(parts, fmt.Sprintf("Cost: $%s", formatCost(s.totalCost)), fmt.Sprintf("%d sub-sessions", s.sessionCount-1))
1061+
} else {
1062+
parts = append(parts, fmt.Sprintf("Cost: $%s", formatCost(s.totalCost)))
1063+
if s.contextPct != "" {
1064+
parts = append(parts, fmt.Sprintf("Context: %s", s.contextPct))
1065+
}
10231066
}
10241067

1025-
return fmt.Sprintf("Tokens: %s | Cost: $%s", formatTokenCount(totalTokens), formatCost(totalCost))
1068+
return strings.Join(parts, " | ")
10261069
}
10271070

10281071
func (m *model) sessionInfo(contentWidth int) string {

0 commit comments

Comments
 (0)