Skip to content

Commit e6f7898

Browse files
committed
Add mid-turn message steering for running agent sessions
Addresses #2223. Allow API clients to inject user messages into an active agent session without waiting for the current turn to finish. This is a common pattern in agentic coding tools where the user can steer or provide follow-up context while the agent is executing tool calls.

New API endpoint: POST /sessions/:id/steer

Runtime changes:
- SteerQueue interface (Enqueue/Drain) so callers can provide their own storage implementation; default is an in-memory buffered queue
- WithSteerQueue option on LocalRuntime for injecting custom implementations
- Agent loop drains the queue after tool execution and before the stop-condition check; emits user_message events so clients know when the LLM actually picks them up
- Messages wrapped in <system-reminder> tags for clear LLM attribution

Server changes:
- POST /sessions/:id/steer endpoint (202 Accepted)
- SteerSession() on SessionManager with GetLocalRuntime() helper for PersistentRuntime unwrapping
- Concurrent stream guard on RunSession (rejects if already streaming)
- Proper defer ordering: streaming flag cleared before channel close

No behavioral change to the TUI — the existing client-side message queue continues to work as before. The TUI can adopt mid-turn steering in a future change by calling LocalRuntime.Steer() directly.
1 parent 3fac361 commit e6f7898

6 files changed

Lines changed: 215 additions & 8 deletions

File tree

pkg/api/types.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,13 @@ type ResumeElicitationRequest struct {
160160
Content map[string]any `json:"content"` // The submitted form data (only present when action is "accept")
161161
}
162162

163+
// SteerSessionRequest represents a request to inject user messages into a
// running agent session. The messages are picked up by the agent loop between
// tool execution and the next LLM call, so the client does not have to wait
// for the current turn to finish.
type SteerSessionRequest struct {
	// Messages are the user messages to inject, in order. The steer endpoint
	// rejects an empty list with 400 Bad Request.
	Messages []Message `json:"messages"`
}
169+
163170
// UpdateSessionTitleRequest represents a request to update a session's title
164171
type UpdateSessionTitleRequest struct {
165172
Title string `json:"title"`

pkg/runtime/loop.go

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -386,13 +386,42 @@ func (r *LocalRuntime) RunStream(ctx context.Context, sess *session.Session) <-c
386386
// Record per-toolset model override for the next LLM turn.
387387
toolModelOverride = resolveToolCallModelOverride(res.Calls, agentTools)
388388

389+
// Only compact proactively when the model will continue (has
390+
// tool calls to process on the next turn). If the model stopped
391+
// and no steered messages override that, compaction is wasteful
392+
// because no further LLM call follows.
393+
if !res.Stopped {
394+
r.compactIfNeeded(ctx, sess, a, m, contextLimit, messageCountBeforeTools, events)
395+
}
396+
397+
// Drain any steered (mid-turn) user messages that arrived while
398+
// the current iteration was in progress. Injecting them here —
399+
// after tool execution, before the stop check — ensures the LLM
400+
// sees the new messages on the next iteration via GetMessages().
401+
if steered := r.DrainSteeredMessages(); len(steered) > 0 {
402+
for _, sm := range steered {
403+
wrapped := fmt.Sprintf(
404+
"<system-reminder>\nThe user sent the following message while you were working:\n%s\n\nPlease address this in your next response while continuing with your current tasks.\n</system-reminder>",
405+
sm.Content,
406+
)
407+
userMsg := session.UserMessage(wrapped, sm.MultiContent...)
408+
sess.AddMessage(userMsg)
409+
events <- UserMessage(sm.Content, sess.ID, sm.MultiContent, len(sess.Messages)-1)
410+
}
411+
412+
// Force the loop to continue — the model must respond to
413+
// the injected messages even if it was about to stop.
414+
res.Stopped = false
415+
416+
// Now that the loop will continue, compact if needed.
417+
r.compactIfNeeded(ctx, sess, a, m, contextLimit, messageCountBeforeTools, events)
418+
}
419+
389420
if res.Stopped {
390421
slog.Debug("Conversation stopped", "agent", a.Name())
391422
r.executeStopHooks(ctx, sess, a, res.Content, events)
392423
break
393424
}
394-
395-
r.compactIfNeeded(ctx, sess, a, m, contextLimit, messageCountBeforeTools, events)
396425
}
397426
}()
398427

pkg/runtime/persistent_runtime.go

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,21 @@ type streamingState struct {
2525
messageID int64 // ID of the current streaming message (0 if none)
2626
}
2727

28+
// GetLocalRuntime extracts the underlying *LocalRuntime from a Runtime
29+
// implementation. It handles both *LocalRuntime and *PersistentRuntime
30+
// (which embeds *LocalRuntime). Returns nil if the runtime type is not
31+
// supported (e.g. RemoteRuntime).
32+
func GetLocalRuntime(rt Runtime) *LocalRuntime {
33+
switch r := rt.(type) {
34+
case *LocalRuntime:
35+
return r
36+
case *PersistentRuntime:
37+
return r.LocalRuntime
38+
default:
39+
return nil
40+
}
41+
}
42+
2843
// New creates a new runtime for an agent and its team.
2944
// The runtime automatically persists session changes to the configured store.
3045
// Returns a Runtime interface which wraps LocalRuntime with persistence handling.

pkg/runtime/runtime.go

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,66 @@ func ResumeReject(reason string) ResumeRequest {
8080
return ResumeRequest{Type: ResumeTypeReject, Reason: reason}
8181
}
8282

83+
// SteeredMessage is a user message injected mid-turn while the agent loop is
// running. It is enqueued via a SteerQueue and drained inside the loop between
// tool execution and the stop-condition check.
type SteeredMessage struct {
	// Content is the plain-text body of the message.
	Content string
	// MultiContent carries any structured message parts accompanying Content;
	// semantics are defined by the chat package. May be empty.
	MultiContent []chat.MessagePart
}
90+
91+
// SteerQueue is the interface for storing steered messages that are injected
// into a running agent loop mid-turn. Implementations must be safe for
// concurrent use: Enqueue is called from API handlers while Drain is called
// from the agent loop goroutine.
//
// The default implementation is the in-memory queue created by
// NewInMemorySteerQueue. Callers that need durable or distributed storage can
// provide their own implementation via the WithSteerQueue option.
type SteerQueue interface {
	// Enqueue adds a message to the queue. Returns false if the queue is
	// full and the message was not accepted. It must not block the caller.
	Enqueue(msg SteeredMessage) bool
	// Drain returns all pending messages and removes them from the queue.
	// It must not block — if the queue is empty it returns nil.
	Drain() []SteeredMessage
}
107+
108+
// inMemorySteerQueue is the default SteerQueue backed by a buffered channel.
// The channel provides the required concurrency safety without an explicit
// mutex: Enqueue performs a non-blocking send and Drain performs
// non-blocking receives.
type inMemorySteerQueue struct {
	ch chan SteeredMessage
}

// defaultSteerQueueCapacity is the buffer size for the default in-memory
// queue. Enqueue reports failure once this many messages sit undrained.
const defaultSteerQueueCapacity = 5
115+
116+
// NewInMemorySteerQueue creates a SteerQueue backed by a buffered channel
117+
// with the given capacity.
118+
func NewInMemorySteerQueue(capacity int) SteerQueue {
119+
return &inMemorySteerQueue{ch: make(chan SteeredMessage, capacity)}
120+
}
121+
122+
// Enqueue adds a message to the queue without blocking. It returns false when
// the channel buffer is full, in which case the message is dropped by this
// queue and the caller is responsible for handling the rejection.
func (q *inMemorySteerQueue) Enqueue(msg SteeredMessage) bool {
	select {
	case q.ch <- msg:
		return true
	default:
		// Buffer full: reject rather than block the API handler.
		return false
	}
}
130+
131+
func (q *inMemorySteerQueue) Drain() []SteeredMessage {
132+
var msgs []SteeredMessage
133+
for {
134+
select {
135+
case m := <-q.ch:
136+
msgs = append(msgs, m)
137+
default:
138+
return msgs
139+
}
140+
}
141+
}
142+
83143
// ToolHandlerFunc is a function type for handling tool calls
84144
type ToolHandlerFunc func(ctx context.Context, sess *session.Session, toolCall tools.ToolCall, events chan Event) (*tools.ToolCallResult, error)
85145

@@ -201,6 +261,11 @@ type LocalRuntime struct {
201261

202262
currentAgentMu sync.RWMutex
203263

264+
// steerQueue stores user messages injected mid-turn. The agent loop
265+
// drains this queue after tool execution, before checking the stop
266+
// condition, so the LLM sees the new messages on its next iteration.
267+
steerQueue SteerQueue
268+
204269
// onToolsChanged is called when an MCP toolset reports a tool list change.
205270
onToolsChanged func(Event)
206271

@@ -228,6 +293,14 @@ func WithTracer(t trace.Tracer) Opt {
228293
}
229294
}
230295

296+
// WithSteerQueue sets a custom SteerQueue implementation for mid-turn message
297+
// injection. If not provided, an in-memory buffered queue is used.
298+
func WithSteerQueue(q SteerQueue) Opt {
299+
return func(r *LocalRuntime) {
300+
r.steerQueue = q
301+
}
302+
}
303+
231304
func WithSessionCompaction(sessionCompaction bool) Opt {
232305
return func(r *LocalRuntime) {
233306
r.sessionCompaction = sessionCompaction
@@ -291,6 +364,7 @@ func NewLocalRuntime(agents *team.Team, opts ...Opt) (*LocalRuntime, error) {
291364
currentAgent: defaultAgent.Name(),
292365
resumeChan: make(chan ResumeRequest),
293366
elicitationRequestCh: make(chan ElicitationResult),
367+
steerQueue: NewInMemorySteerQueue(defaultSteerQueueCapacity),
294368
sessionCompaction: true,
295369
managedOAuth: true,
296370
sessionStore: session.NewInMemorySessionStore(),
@@ -1015,6 +1089,21 @@ func (r *LocalRuntime) ResumeElicitation(ctx context.Context, action tools.Elici
10151089
}
10161090
}
10171091

1092+
// Steer enqueues a user message for mid-turn injection into the running
// agent loop. The message will be picked up after the current batch of tool
// calls finishes but before the loop checks whether to stop. Returns false
// if the queue is full and the message was not enqueued; the caller should
// surface that rejection rather than silently dropping the message.
func (r *LocalRuntime) Steer(msg SteeredMessage) bool {
	return r.steerQueue.Enqueue(msg)
}
1099+
1100+
// DrainSteeredMessages returns all pending steered messages without blocking
// (nil when none are pending). It is called inside the agent loop to
// batch-inject any messages that arrived while the current iteration was in
// progress.
func (r *LocalRuntime) DrainSteeredMessages() []SteeredMessage {
	return r.steerQueue.Drain()
}
1106+
10181107
// Run starts the agent's interaction loop
10191108

10201109
func (r *LocalRuntime) startSpan(ctx context.Context, name string, opts ...trace.SpanStartOption) (context.Context, trace.Span) {

pkg/server/server.go

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,8 @@ func New(ctx context.Context, sessionStore session.Store, runConfig *config.Runt
6262
group.POST("/sessions/:id/agent/:agent", s.runAgent)
6363
group.POST("/sessions/:id/agent/:agent/:agent_name", s.runAgent)
6464
group.POST("/sessions/:id/elicitation", s.elicitation)
65+
// Steer: inject user messages into a running agent session mid-turn
66+
group.POST("/sessions/:id/steer", s.steerSession)
6567

6668
// Agent tool count
6769
group.GET("/agents/:id/:agent_name/tools/count", s.getAgentToolCount)
@@ -317,3 +319,21 @@ func (s *Server) elicitation(c echo.Context) error {
317319

318320
return c.JSON(http.StatusOK, nil)
319321
}
322+
323+
func (s *Server) steerSession(c echo.Context) error {
324+
sessionID := c.Param("id")
325+
var req api.SteerSessionRequest
326+
if err := c.Bind(&req); err != nil {
327+
return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("invalid request body: %v", err))
328+
}
329+
330+
if len(req.Messages) == 0 {
331+
return echo.NewHTTPError(http.StatusBadRequest, "at least one message is required")
332+
}
333+
334+
if err := s.sm.SteerSession(c.Request().Context(), sessionID, req.Messages); err != nil {
335+
return echo.NewHTTPError(http.StatusConflict, fmt.Sprintf("failed to steer session: %v", err))
336+
}
337+
338+
return c.JSON(http.StatusAccepted, map[string]string{"status": "queued"})
339+
}

pkg/server/session_manager.go

Lines changed: 53 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,11 @@ import (
2323
)
2424

2525
// activeRuntimes bundles everything the SessionManager tracks for one live
// session's runtime.
type activeRuntimes struct {
	runtime   runtime.Runtime         // Runtime executing this session's agent loop
	cancel    context.CancelFunc      // Cancels the active stream context
	session   *session.Session        // The actual session object used by the runtime
	titleGen  *sessiontitle.Generator // Title generator (includes fallback models)
	streaming bool                    // True while RunStream is active; prevents concurrent runs
}
3132

3233
// SessionManager manages sessions for HTTP and Connect-RPC servers.
@@ -160,6 +161,14 @@ func (sm *SessionManager) RunSession(ctx context.Context, sessionID, agentFilena
160161
}
161162

162163
runtimeSession, exists := sm.runtimeSessions.Load(sessionID)
164+
165+
// Reject if a stream is already active for this session. The caller
166+
// should use POST /sessions/:id/steer to inject follow-up messages
167+
// into a running session instead of starting a second concurrent stream.
168+
if exists && runtimeSession.streaming {
169+
return nil, errors.New("session is already streaming; use /steer to send follow-up messages")
170+
}
171+
163172
streamCtx, cancel := context.WithCancel(ctx)
164173
var titleGen *sessiontitle.Generator
165174
if !exists {
@@ -182,6 +191,8 @@ func (sm *SessionManager) RunSession(ctx context.Context, sessionID, agentFilena
182191
titleGen = runtimeSession.titleGen
183192
}
184193

194+
runtimeSession.streaming = true
195+
185196
streamChan := make(chan runtime.Event)
186197

187198
// Check if we need to generate a title
@@ -194,8 +205,17 @@ func (sm *SessionManager) RunSession(ctx context.Context, sessionID, agentFilena
194205
}
195206

196207
stream := runtimeSession.runtime.RunStream(streamCtx, sess)
197-
defer cancel()
198-
defer close(streamChan)
208+
// Single defer to control ordering: clear the streaming flag
209+
// BEFORE closing streamChan. When the client sees the channel
210+
// close it may immediately call RunSession for the next queued
211+
// message; streaming must already be false by then.
212+
defer func() {
213+
sm.mux.Lock()
214+
runtimeSession.streaming = false
215+
sm.mux.Unlock()
216+
close(streamChan)
217+
cancel()
218+
}()
199219
for event := range stream {
200220
if streamCtx.Err() != nil {
201221
return
@@ -230,6 +250,33 @@ func (sm *SessionManager) ResumeSession(ctx context.Context, sessionID, confirma
230250
return nil
231251
}
232252

253+
// SteerSession enqueues user messages for mid-turn injection into a running
254+
// session. The messages are picked up by the agent loop after the current tool
255+
// calls finish but before the next LLM call. Returns an error if the session
256+
// is not actively running or if the steer buffer is full.
257+
func (sm *SessionManager) SteerSession(_ context.Context, sessionID string, messages []api.Message) error {
258+
rt, exists := sm.runtimeSessions.Load(sessionID)
259+
if !exists {
260+
return errors.New("session not found or not running")
261+
}
262+
263+
localRT := runtime.GetLocalRuntime(rt.runtime)
264+
if localRT == nil {
265+
return errors.New("steering not supported for this runtime type")
266+
}
267+
268+
for _, msg := range messages {
269+
if !localRT.Steer(runtime.SteeredMessage{
270+
Content: msg.Content,
271+
MultiContent: msg.MultiContent,
272+
}) {
273+
return errors.New("steer queue full")
274+
}
275+
}
276+
277+
return nil
278+
}
279+
233280
// ResumeElicitation resumes an elicitation request.
234281
func (sm *SessionManager) ResumeElicitation(ctx context.Context, sessionID, action string, content map[string]any) error {
235282
sm.mux.Lock()

0 commit comments

Comments
 (0)