Skip to content

Commit 9a75703

Browse files
authored
Merge pull request #2118 from dgageot/board/the-showing-of-which-model-is-currently-fa66a995
Fix model name display in TUI sidebar for all model types
2 parents d871092 + 43976c0 commit 9a75703

6 files changed

Lines changed: 49 additions & 48 deletions

File tree

pkg/model/provider/rulebased/client.go

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,10 @@ type ProviderFactory func(ctx context.Context, modelSpec string, models map[stri
4141
// Client implements the Provider interface for rule-based model routing.
4242
type Client struct {
4343
base.Config
44-
routes []Provider
45-
fallback Provider
46-
index bleve.Index
44+
routes []Provider
45+
fallback Provider
46+
index bleve.Index
47+
lastSelectedID string // ID of the provider selected by the most recent call
4748
}
4849

4950
// NewClient creates a new rule-based routing client.
@@ -152,6 +153,7 @@ func filterOutMaxTokens(opts []options.Opt) []options.Opt {
152153
}
153154

154155
// CreateChatCompletionStream selects a provider based on input and delegates the call.
156+
// The selected provider's ID is recorded in LastSelectedModelID.
155157
func (c *Client) CreateChatCompletionStream(
156158
ctx context.Context,
157159
messages []chat.Message,
@@ -162,15 +164,23 @@ func (c *Client) CreateChatCompletionStream(
162164
return nil, errors.New("no provider available for routing")
163165
}
164166

167+
c.lastSelectedID = provider.ID()
165168
slog.Debug("Rule-based router selected model",
166169
"router", c.ID(),
167-
"selected_model", provider.ID(),
170+
"selected_model", c.lastSelectedID,
168171
"message_count", len(messages),
169172
)
170173

171174
return provider.CreateChatCompletionStream(ctx, messages, availableTools)
172175
}
173176

177+
// LastSelectedModelID returns the ID of the provider selected by the most
178+
// recent CreateChatCompletionStream call. This allows callers to display
179+
// the YAML-configured sub-model name for rule-based routing.
180+
func (c *Client) LastSelectedModelID() string {
181+
return c.lastSelectedID
182+
}
183+
174184
// selectProvider finds the best matching provider for the messages.
175185
// Bleve returns hits sorted by score, so the top hit determines the route.
176186
func (c *Client) selectProvider(messages []chat.Message) Provider {

pkg/runtime/fallback.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -283,6 +283,15 @@ func (r *LocalRuntime) tryModelWithFallback(
283283

284284
// Stream created successfully, now handle it
285285
slog.Debug("Processing stream", "agent", a.Name(), "model", modelEntry.provider.ID())
286+
287+
// If the provider is a rule-based router, notify the sidebar
288+
// of the selected sub-model's YAML-configured name.
289+
if rp, ok := modelEntry.provider.(interface{ LastSelectedModelID() string }); ok {
290+
if selected := rp.LastSelectedModelID(); selected != "" {
291+
events <- AgentInfo(a.Name(), selected, a.Description(), a.WelcomeMessage())
292+
}
293+
}
294+
286295
res, err := r.handleStream(ctx, stream, a, agentTools, sess, m, events)
287296
if err != nil {
288297
lastErr = err

pkg/runtime/loop.go

Lines changed: 5 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
package runtime
22

33
import (
4-
"cmp"
54
"context"
65
"errors"
76
"fmt"
@@ -86,10 +85,6 @@ func (r *LocalRuntime) RunStream(ctx context.Context, sess *session.Session) <-c
8685

8786
a := r.resolveSessionAgent(sess)
8887

89-
// Emit agent information for sidebar display
90-
// Use getEffectiveModelID to account for active fallback cooldowns
91-
events <- AgentInfo(a.Name(), r.getEffectiveModelID(a), a.Description(), a.WelcomeMessage())
92-
9388
// Emit team information
9489
events <- TeamInfo(r.agentDetailsFromTeam(), a.Name())
9590

@@ -210,7 +205,6 @@ func (r *LocalRuntime) RunStream(ctx context.Context, sess *session.Session) <-c
210205
))
211206

212207
model := a.Model()
213-
defaultModelID := r.getEffectiveModelID(a)
214208

215209
// Per-tool model routing: use a cheaper model for this turn
216210
// if the previous tool calls specified one, then reset.
@@ -236,10 +230,10 @@ func (r *LocalRuntime) RunStream(ctx context.Context, sess *session.Session) <-c
236230

237231
modelID := model.ID()
238232

239-
// Notify sidebar when this turn uses a different model (per-tool override).
240-
if modelID != defaultModelID {
241-
events <- AgentInfo(a.Name(), modelID, a.Description(), a.WelcomeMessage())
242-
}
233+
// Notify sidebar of the model for this turn. For rule-based
234+
// routing, the actual routed model is emitted from within the
235+
// stream once the first chunk arrives.
236+
events <- AgentInfo(a.Name(), modelID, a.Description(), a.WelcomeMessage())
243237

244238
slog.Debug("Using agent", "agent", a.Name(), "model", modelID)
245239
slog.Debug("Getting model definition", "model_id", modelID)
@@ -311,16 +305,9 @@ func (r *LocalRuntime) RunStream(ctx context.Context, sess *session.Session) <-c
311305
return
312306
}
313307

314-
// Update sidebar model info to reflect what was actually used this turn.
315-
// Fallback models are sticky (cooldown system persists them), so we only
316-
// emit once. Per-tool model overrides are temporary (one turn), so we
317-
// emit the override and then revert to the agent's default.
318308
if usedModel != nil && usedModel.ID() != model.ID() {
319309
slog.Info("Used fallback model", "agent", a.Name(), "primary", model.ID(), "used", usedModel.ID())
320310
events <- AgentInfo(a.Name(), usedModel.ID(), a.Description(), a.WelcomeMessage())
321-
} else if model.ID() != defaultModelID {
322-
// Per-tool override was active: revert sidebar to the agent's default model.
323-
events <- AgentInfo(a.Name(), defaultModelID, a.Description(), a.WelcomeMessage())
324311
}
325312
streamSpan.SetAttributes(
326313
attribute.Int("tool.calls", len(res.Calls)),
@@ -410,7 +397,7 @@ func (r *LocalRuntime) recordAssistantMessage(
410397
float64(res.Usage.CacheWriteTokens)*m.Cost.CacheWrite) / 1e6
411398
}
412399

413-
messageModel := cmp.Or(res.ActualModel, modelID)
400+
messageModel := modelID
414401

415402
assistantMessage := chat.Message{
416403
Role: chat.MessageRoleAssistant,

pkg/runtime/runtime_test.go

Lines changed: 15 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -276,12 +276,12 @@ func TestSimple(t *testing.T) {
276276
require.Equal(t, chat.MessageRoleAssistant, msgAdded.Message.Message.Role)
277277

278278
expectedEvents := []Event{
279-
AgentInfo("root", "test/mock-model", "", ""),
280279
TeamInfo([]AgentDetails{{Name: "root", Provider: "test", Model: "mock-model"}}, "root"),
281280
ToolsetInfo(0, false, "root"),
282281
UserMessage("Hi", sess.ID, nil, 0),
283282
StreamStarted(sess.ID, "root"),
284283
ToolsetInfo(0, false, "root"),
284+
AgentInfo("root", "test/mock-model", "", ""),
285285
AgentChoice("root", sess.ID, "Hello"),
286286
MessageAdded(sess.ID, msgAdded.Message, "root"),
287287
NewTokenUsageEvent(sess.ID, "root", &Usage{InputTokens: 3, OutputTokens: 2, ContextLength: 5, LastMessage: &MessageUsage{
@@ -315,12 +315,12 @@ func TestMultipleContentChunks(t *testing.T) {
315315
require.NotNil(t, msgAdded.Message)
316316

317317
expectedEvents := []Event{
318-
AgentInfo("root", "test/mock-model", "", ""),
319318
TeamInfo([]AgentDetails{{Name: "root", Provider: "test", Model: "mock-model"}}, "root"),
320319
ToolsetInfo(0, false, "root"),
321320
UserMessage("Please greet me", sess.ID, nil, 0),
322321
StreamStarted(sess.ID, "root"),
323322
ToolsetInfo(0, false, "root"),
323+
AgentInfo("root", "test/mock-model", "", ""),
324324
AgentChoice("root", sess.ID, "Hello "),
325325
AgentChoice("root", sess.ID, "there, "),
326326
AgentChoice("root", sess.ID, "how "),
@@ -356,12 +356,12 @@ func TestWithReasoning(t *testing.T) {
356356
require.NotNil(t, msgAdded.Message)
357357

358358
expectedEvents := []Event{
359-
AgentInfo("root", "test/mock-model", "", ""),
360359
TeamInfo([]AgentDetails{{Name: "root", Provider: "test", Model: "mock-model"}}, "root"),
361360
ToolsetInfo(0, false, "root"),
362361
UserMessage("Hi", sess.ID, nil, 0),
363362
StreamStarted(sess.ID, "root"),
364363
ToolsetInfo(0, false, "root"),
364+
AgentInfo("root", "test/mock-model", "", ""),
365365
AgentChoiceReasoning("root", sess.ID, "Let me think about this..."),
366366
AgentChoiceReasoning("root", sess.ID, " I should respond politely."),
367367
AgentChoice("root", sess.ID, "Hello, how can I help you?"),
@@ -396,12 +396,12 @@ func TestMixedContentAndReasoning(t *testing.T) {
396396
require.NotNil(t, msgAdded.Message)
397397

398398
expectedEvents := []Event{
399-
AgentInfo("root", "test/mock-model", "", ""),
400399
TeamInfo([]AgentDetails{{Name: "root", Provider: "test", Model: "mock-model"}}, "root"),
401400
ToolsetInfo(0, false, "root"),
402401
UserMessage("Hi there", sess.ID, nil, 0),
403402
StreamStarted(sess.ID, "root"),
404403
ToolsetInfo(0, false, "root"),
404+
AgentInfo("root", "test/mock-model", "", ""),
405405
AgentChoiceReasoning("root", sess.ID, "The user wants a greeting"),
406406
AgentChoice("root", sess.ID, "Hello!"),
407407
AgentChoiceReasoning("root", sess.ID, " I should be friendly"),
@@ -454,12 +454,12 @@ func TestErrorEvent(t *testing.T) {
454454
}
455455

456456
require.Len(t, events, 8)
457-
require.IsType(t, &AgentInfoEvent{}, events[0])
458-
require.IsType(t, &TeamInfoEvent{}, events[1])
459-
require.IsType(t, &ToolsetInfoEvent{}, events[2])
460-
require.IsType(t, &UserMessageEvent{}, events[3])
461-
require.IsType(t, &StreamStartedEvent{}, events[4])
462-
require.IsType(t, &ToolsetInfoEvent{}, events[5])
457+
require.IsType(t, &TeamInfoEvent{}, events[0])
458+
require.IsType(t, &ToolsetInfoEvent{}, events[1])
459+
require.IsType(t, &UserMessageEvent{}, events[2])
460+
require.IsType(t, &StreamStartedEvent{}, events[3])
461+
require.IsType(t, &ToolsetInfoEvent{}, events[4])
462+
require.IsType(t, &AgentInfoEvent{}, events[5])
463463
require.IsType(t, &ErrorEvent{}, events[6])
464464
require.IsType(t, &StreamStoppedEvent{}, events[7])
465465

@@ -493,12 +493,11 @@ func TestContextCancellation(t *testing.T) {
493493
events = append(events, ev)
494494
}
495495

496-
require.GreaterOrEqual(t, len(events), 5)
497-
require.IsType(t, &AgentInfoEvent{}, events[0])
498-
require.IsType(t, &TeamInfoEvent{}, events[1])
499-
require.IsType(t, &ToolsetInfoEvent{}, events[2])
500-
require.IsType(t, &UserMessageEvent{}, events[3])
501-
require.IsType(t, &StreamStartedEvent{}, events[4])
496+
require.GreaterOrEqual(t, len(events), 4)
497+
require.IsType(t, &TeamInfoEvent{}, events[0])
498+
require.IsType(t, &ToolsetInfoEvent{}, events[1])
499+
require.IsType(t, &UserMessageEvent{}, events[2])
500+
require.IsType(t, &StreamStartedEvent{}, events[3])
502501
require.IsType(t, &StreamStoppedEvent{}, events[len(events)-1])
503502
}
504503

pkg/runtime/streaming.go

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@ type streamResult struct {
2626
ThinkingSignature string
2727
ThoughtSignature []byte
2828
Stopped bool
29-
ActualModel string
3029
Usage *chat.Usage
3130
RateLimit *chat.RateLimit
3231
}
@@ -43,7 +42,6 @@ func (r *LocalRuntime) handleStream(ctx context.Context, stream chat.MessageStre
4342
var thinkingSignature string
4443
var thoughtSignature []byte
4544
var toolCalls []tools.ToolCall
46-
var actualModel string
4745
var messageUsage *chat.Usage
4846
var messageRateLimit *chat.RateLimit
4947

@@ -102,11 +100,6 @@ func (r *LocalRuntime) handleStream(ctx context.Context, stream chat.MessageStre
102100
thoughtSignature = choice.Delta.ThoughtSignature
103101
}
104102

105-
// Capture the actual model from the stream response (useful for model routing)
106-
if actualModel == "" && response.Model != "" {
107-
actualModel = response.Model
108-
}
109-
110103
if choice.FinishReason == chat.FinishReasonStop || choice.FinishReason == chat.FinishReasonLength {
111104
recordUsage()
112105
return streamResult{
@@ -116,7 +109,6 @@ func (r *LocalRuntime) handleStream(ctx context.Context, stream chat.MessageStre
116109
ThinkingSignature: thinkingSignature,
117110
ThoughtSignature: thoughtSignature,
118111
Stopped: true,
119-
ActualModel: actualModel,
120112
Usage: messageUsage,
121113
RateLimit: messageRateLimit,
122114
}, nil
@@ -191,7 +183,6 @@ func (r *LocalRuntime) handleStream(ctx context.Context, stream chat.MessageStre
191183
ThinkingSignature: thinkingSignature,
192184
ThoughtSignature: thoughtSignature,
193185
Stopped: stoppedDueToNoOutput,
194-
ActualModel: actualModel,
195186
Usage: messageUsage,
196187
RateLimit: messageRateLimit,
197188
}, nil

pkg/tui/components/sidebar/sidebar.go

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -268,8 +268,13 @@ func checkReasoningSupportCmd(ctx context.Context, modelID string) tea.Cmd {
268268
}
269269
}
270270

271-
// SetAgentInfo sets the current agent information and updates the model in availableAgents
271+
// SetAgentInfo sets the current agent information and updates the model in availableAgents.
272+
// It no-ops when the values are unchanged to avoid unnecessary cache invalidation and re-renders.
272273
func (m *model) SetAgentInfo(agentName, modelID, description string) tea.Cmd {
274+
if m.currentAgent == agentName && m.agentModel == modelID && m.agentDescription == description {
275+
return nil
276+
}
277+
273278
m.currentAgent = agentName
274279
m.agentModel = modelID
275280
m.agentDescription = description

0 commit comments

Comments
 (0)