Skip to content
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions pkg/api/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,13 @@ type ResumeElicitationRequest struct {
Content map[string]any `json:"content"` // The submitted form data (only present when action is "accept")
}

// SteerSessionRequest is the JSON request body used to hand user messages to
// an agent session that is already running. Per the API description, the
// agent loop picks these messages up after tool execution and before the next
// LLM call.
type SteerSessionRequest struct {
	// Messages holds the user messages to inject; serialized as "messages".
	Messages []Message `json:"messages"`
}

// UpdateSessionTitleRequest represents a request to update a session's title
type UpdateSessionTitleRequest struct {
Title string `json:"title"`
Expand Down
12 changes: 12 additions & 0 deletions pkg/app/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -532,6 +532,18 @@ func (a *App) SubscribeWith(ctx context.Context, send func(tea.Msg)) {
}
}

// Steer forwards msg to the underlying runtime's steer queue so it can be
// injected into the agent loop while a turn is still in progress. Any error
// reported by the runtime is returned unchanged.
func (a *App) Steer(msg runtime.QueuedMessage) error {
	err := a.runtime.Steer(msg)
	return err
}

// FollowUp forwards msg to the underlying runtime's follow-up queue, which is
// meant for messages handled once the current turn has finished (each one is
// given its own agent turn). Any error reported by the runtime is returned
// unchanged.
func (a *App) FollowUp(msg runtime.QueuedMessage) error {
	err := a.runtime.FollowUp(msg)
	return err
}

// Resume resumes the runtime with the given confirmation request
func (a *App) Resume(req runtime.ResumeRequest) {
a.runtime.Resume(context.Background(), req)
Expand Down
2 changes: 2 additions & 0 deletions pkg/app/app_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@ func (m *mockRuntime) UpdateSessionTitle(_ context.Context, sess *session.Sessio
// TitleGenerator always reports that the mock has no title generator.
func (*mockRuntime) TitleGenerator() *sessiontitle.Generator { return nil }

// Close is a no-op that always succeeds.
func (*mockRuntime) Close() error { return nil }

// Stop is a no-op.
func (*mockRuntime) Stop() {}

// Steer accepts and discards the message, always succeeding.
func (*mockRuntime) Steer(runtime.QueuedMessage) error { return nil }

// FollowUp accepts and discards the message, always succeeding.
func (*mockRuntime) FollowUp(runtime.QueuedMessage) error { return nil }

// Compile-time check that *mockRuntime satisfies runtime.Runtime.
var _ runtime.Runtime = (*mockRuntime)(nil)
Expand Down
2 changes: 2 additions & 0 deletions pkg/cli/runner_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ func (m *mockRuntime) ExecuteMCPPrompt(context.Context, string, map[string]strin
// UpdateSessionTitle ignores its arguments and always succeeds.
func (*mockRuntime) UpdateSessionTitle(context.Context, *session.Session, string) error {
	return nil
}

// TitleGenerator always reports that the mock has no title generator.
func (*mockRuntime) TitleGenerator() *sessiontitle.Generator { return nil }

// Close is a no-op that always succeeds.
func (*mockRuntime) Close() error { return nil }

// Steer accepts and discards the message, always succeeding.
func (*mockRuntime) Steer(runtime.QueuedMessage) error { return nil }

// FollowUp accepts and discards the message, always succeeding.
func (*mockRuntime) FollowUp(runtime.QueuedMessage) error { return nil }

// RegenerateTitle is a no-op.
func (*mockRuntime) RegenerateTitle(context.Context, *session.Session, chan runtime.Event) {}

func (m *mockRuntime) Resume(_ context.Context, req runtime.ResumeRequest) {
Expand Down
12 changes: 12 additions & 0 deletions pkg/runtime/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,18 @@ func (c *Client) ResumeSession(ctx context.Context, id, confirmation, reason, to
return c.doRequest(ctx, http.MethodPost, "/api/sessions/"+id+"/resume", req, nil)
}

// SteerSession POSTs messages to the session's "/steer" endpoint, asking the
// server to inject them into the running session mid-turn. The response body
// is not decoded; only the error from the request is returned.
func (c *Client) SteerSession(ctx context.Context, sessionID string, messages []api.Message) error {
	endpoint := "/api/sessions/" + sessionID + "/steer"
	payload := api.SteerSessionRequest{Messages: messages}
	return c.doRequest(ctx, http.MethodPost, endpoint, payload, nil)
}

// FollowUpSession POSTs messages to the session's "/followup" endpoint so they
// are queued for end-of-turn processing. It reuses the SteerSessionRequest
// payload shape. The response body is not decoded; only the error from the
// request is returned.
func (c *Client) FollowUpSession(ctx context.Context, sessionID string, messages []api.Message) error {
	endpoint := "/api/sessions/" + sessionID + "/followup"
	payload := api.SteerSessionRequest{Messages: messages}
	return c.doRequest(ctx, http.MethodPost, endpoint, payload, nil)
}

// DeleteSession deletes a session by ID
func (c *Client) DeleteSession(ctx context.Context, id string) error {
return c.doRequest(ctx, "DELETE", "/api/sessions/"+id, nil, nil)
Expand Down
2 changes: 2 additions & 0 deletions pkg/runtime/commands_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ func (m *mockRuntime) UpdateSessionTitle(context.Context, *session.Session, stri
}
// TitleGenerator always reports that the mock has no title generator.
func (*mockRuntime) TitleGenerator() *sessiontitle.Generator { return nil }

// Close is a no-op that always succeeds.
func (*mockRuntime) Close() error { return nil }

// Steer accepts and discards the message, always succeeding.
func (*mockRuntime) Steer(QueuedMessage) error { return nil }

// FollowUp accepts and discards the message, always succeeding.
func (*mockRuntime) FollowUp(QueuedMessage) error { return nil }

// RegenerateTitle is a no-op.
func (*mockRuntime) RegenerateTitle(context.Context, *session.Session, chan Event) {}
Expand Down
48 changes: 46 additions & 2 deletions pkg/runtime/loop.go
Original file line number Diff line number Diff line change
Expand Up @@ -386,13 +386,57 @@ func (r *LocalRuntime) RunStream(ctx context.Context, sess *session.Session) <-c
// Record per-toolset model override for the next LLM turn.
toolModelOverride = resolveToolCallModelOverride(res.Calls, agentTools)

// Only compact proactively when the model will continue (has
// tool calls to process on the next turn). If the model stopped
// and no steered messages override that, compaction is wasteful
// because no further LLM call follows.
if !res.Stopped {
r.compactIfNeeded(ctx, sess, a, m, contextLimit, messageCountBeforeTools, events)
}

Comment thread
trungutt marked this conversation as resolved.
Outdated
// --- STEERING: mid-turn injection ---
// Drain ALL pending steer messages. These are urgent course-
// corrections that the model should see on the very next
// iteration, wrapped in <system-reminder> tags.
if steered := r.DrainSteeredMessages(ctx); len(steered) > 0 {
for _, sm := range steered {
wrapped := fmt.Sprintf(
"<system-reminder>\nThe user sent the following message while you were working:\n%s\n\nPlease address this in your next response while continuing with your current tasks.\n</system-reminder>",
sm.Content,
)
userMsg := session.UserMessage(wrapped, sm.MultiContent...)
sess.AddMessage(userMsg)
events <- UserMessage(sm.Content, sess.ID, sm.MultiContent, len(sess.Messages)-1)
}

// Force the loop to continue — the model must respond to
// the injected messages even if it was about to stop.
res.Stopped = false
Comment thread
trungutt marked this conversation as resolved.
Outdated

// Now that the loop will continue, compact if needed.
r.compactIfNeeded(ctx, sess, a, m, contextLimit, messageCountBeforeTools, events)
}

if res.Stopped {
slog.Debug("Conversation stopped", "agent", a.Name())
r.executeStopHooks(ctx, sess, a, res.Content, events)

// --- FOLLOW-UP: end-of-turn injection ---
// Pop exactly one follow-up message. Unlike steered
// messages, follow-ups are plain user messages that start
// a new turn — the model sees them as fresh input, not a
// mid-stream interruption. Each follow-up gets a full
// undivided agent turn.
if followUp, ok := r.DequeueFollowUp(ctx); ok {
userMsg := session.UserMessage(followUp.Content, followUp.MultiContent...)
sess.AddMessage(userMsg)
events <- UserMessage(followUp.Content, sess.ID, followUp.MultiContent, len(sess.Messages)-1)
r.compactIfNeeded(ctx, sess, a, m, contextLimit, messageCountBeforeTools, events)
continue // re-enter the loop for a new turn
}

break
}

r.compactIfNeeded(ctx, sess, a, m, contextLimit, messageCountBeforeTools, events)
}
}()

Expand Down
121 changes: 121 additions & 0 deletions pkg/runtime/message_queue.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
package runtime

import (
"context"

"github.com/docker/docker-agent/pkg/chat"
)

// QueuedMessage is a pending user message destined for the agent loop. It is
// delivered either mid-turn through the steer queue or at end-of-turn through
// the follow-up queue.
type QueuedMessage struct {
	// Content is the plain-text body of the message.
	Content string
	// MultiContent carries optional structured message parts alongside Content.
	MultiContent []chat.MessagePart
}

// MessageQueue is the interface for storing messages that are injected into
// the agent loop. Implementations must be safe for concurrent use: Enqueue
// is called from API handlers while Dequeue/Drain are called from the agent
// loop goroutine.
//
// Dequeue uses a Lock + Confirm/Cancel pattern: Dequeue locks the next
// message (making it invisible to subsequent Dequeue calls), Confirm
// permanently removes it after the message has been successfully processed,
// and Cancel releases it back to the queue if processing fails. This
// prevents message loss in persistent queue implementations where the
// session store is also durable.
//
// Note: for the default in-memory queue, Confirm and Cancel are no-ops
// because the message is consumed from the channel on Dequeue and the
// session is also in-memory. The pattern exists so that persistent
// implementations (with a durable session store) can guarantee
// exactly-once delivery.
//
// The default implementation is NewInMemoryMessageQueue. Callers that need
// durable or distributed storage can provide their own implementation
// via the WithSteerQueue or WithFollowUpQueue options.
type MessageQueue interface {
// Enqueue adds a message to the queue. Returns false if the queue is
// full or the context is cancelled.
Enqueue(ctx context.Context, msg QueuedMessage) bool
// Dequeue locks and returns the next message from the queue. The
// message is invisible to subsequent Dequeue calls until Confirm or
// Cancel is called. Returns the message and true, or a zero value
// and false if the queue is empty. Must not block.
Dequeue(ctx context.Context) (QueuedMessage, bool)
// Confirm permanently removes the most recently dequeued message.
// Must be called after the message has been successfully persisted
// to the session. For in-memory queues this is a no-op.
Confirm(ctx context.Context) error
Comment thread
trungutt marked this conversation as resolved.
Outdated
// Cancel releases the most recently dequeued message back to the
// queue. For in-memory queues this is a no-op (the message was
// already consumed from the channel).
Cancel(ctx context.Context) error
// Drain locks, returns, and auto-confirms all pending messages.
// Must not block — if the queue is empty it returns nil.
Drain(ctx context.Context) []QueuedMessage
// Len returns the current number of messages in the queue.
Len(ctx context.Context) int
Comment thread
trungutt marked this conversation as resolved.
Outdated
}

// inMemoryMessageQueue is the default, non-persistent MessageQueue. All
// pending messages live in a single buffered channel.
type inMemoryMessageQueue struct {
	// ch buffers the pending messages; its capacity bounds the queue size.
	ch chan QueuedMessage
}

// defaultSteerQueueCapacity is the buffer size for the default in-memory
// steer queue.
const defaultSteerQueueCapacity = 5

// defaultFollowUpQueueCapacity is the buffer size for the default in-memory
// follow-up queue. It is larger than the steer capacity because follow-ups
// accumulate while waiting for the current turn to end.
const defaultFollowUpQueueCapacity = 20

// NewInMemoryMessageQueue returns a MessageQueue whose storage is a buffered
// channel able to hold up to capacity messages.
func NewInMemoryMessageQueue(capacity int) MessageQueue {
	buffer := make(chan QueuedMessage, capacity)
	return &inMemoryMessageQueue{ch: buffer}
}

// Enqueue tries to append msg to the queue without blocking.
//
// It returns true when the message was stored, and false when either the
// context has already been cancelled (or its deadline has passed) or the
// channel buffer is full. Honouring the context matches the MessageQueue
// contract, which says Enqueue reports false on cancellation.
func (q *inMemoryMessageQueue) Enqueue(ctx context.Context, msg QueuedMessage) bool {
	// A nil ctx is tolerated so callers that relied on ctx being ignored
	// keep working.
	if ctx != nil && ctx.Err() != nil {
		return false
	}

	// Non-blocking send: a full buffer rejects the message instead of
	// stalling the API handler.
	select {
	case q.ch <- msg:
		return true
	default:
		return false
	}
}

// Dequeue takes the oldest buffered message off the channel without blocking.
// It returns the message and true, or the zero QueuedMessage and false when
// nothing is ready to be received. The context is not consulted.
func (q *inMemoryMessageQueue) Dequeue(_ context.Context) (QueuedMessage, bool) {
	select {
	case next := <-q.ch:
		return next, true
	default:
	}
	return QueuedMessage{}, false
}

// Confirm does nothing and always returns nil: Dequeue has already taken the
// message off the channel, so there is nothing left to remove.
func (q *inMemoryMessageQueue) Confirm(_ context.Context) error {
	return nil
}

// Cancel does nothing and always returns nil. The in-memory queue keeps no
// record of a dequeued message, so a cancelled message is not put back on the
// channel.
func (q *inMemoryMessageQueue) Cancel(_ context.Context) error {
	return nil
}

// Drain removes every message currently buffered in the channel and returns
// them in arrival order. It never blocks: once no message is immediately
// available it returns what it has collected, which is nil when the queue was
// empty. The context is not consulted.
func (q *inMemoryMessageQueue) Drain(_ context.Context) []QueuedMessage {
	var drained []QueuedMessage
	for {
		select {
		case next := <-q.ch:
			drained = append(drained, next)
			continue
		default:
		}
		return drained
	}
}

// Len reports how many messages are currently buffered in the channel. The
// context is not consulted.
func (q *inMemoryMessageQueue) Len(_ context.Context) int {
	pending := len(q.ch)
	return pending
}
6 changes: 6 additions & 0 deletions pkg/runtime/remote_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,12 @@ type RemoteClient interface {
// RunAgentWithAgentName executes an agent with a specific agent name
RunAgentWithAgentName(ctx context.Context, sessionID, agent, agentName string, messages []api.Message) (<-chan Event, error)

// SteerSession injects user messages into a running session mid-turn
SteerSession(ctx context.Context, sessionID string, messages []api.Message) error

// FollowUpSession queues messages for end-of-turn processing
FollowUpSession(ctx context.Context, sessionID string, messages []api.Message) error

// UpdateSessionTitle updates the title of a session
UpdateSessionTitle(ctx context.Context, sessionID, title string) error

Expand Down
21 changes: 21 additions & 0 deletions pkg/runtime/remote_runtime.go
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,27 @@ func (r *RemoteRuntime) Run(ctx context.Context, sess *session.Session) ([]sessi
return sess.GetAllMessages(), nil
}

// Steer sends msg to the remote server for mid-turn injection into the
// running agent loop. It returns an error when no session ID has been
// recorded yet; otherwise it returns whatever the client's SteerSession call
// reports. The request is issued with a background context because the
// method signature carries none.
func (r *RemoteRuntime) Steer(msg QueuedMessage) error {
	if r.sessionID == "" {
		return errors.New("no active session")
	}

	payload := []api.Message{{
		Content:      msg.Content,
		MultiContent: msg.MultiContent,
	}}
	return r.client.SteerSession(context.Background(), r.sessionID, payload)
}

// FollowUp sends msg to the remote server to be queued for end-of-turn
// processing. It returns an error when no session ID has been recorded yet;
// otherwise it returns whatever the client's FollowUpSession call reports.
// The request is issued with a background context because the method
// signature carries none.
func (r *RemoteRuntime) FollowUp(msg QueuedMessage) error {
	if r.sessionID == "" {
		return errors.New("no active session")
	}

	payload := []api.Message{{
		Content:      msg.Content,
		MultiContent: msg.MultiContent,
	}}
	return r.client.FollowUpSession(context.Background(), r.sessionID, payload)
}

// Resume allows resuming execution after user confirmation
func (r *RemoteRuntime) Resume(ctx context.Context, req ResumeRequest) {
slog.Debug("Resuming remote runtime", "agent", r.currentAgent, "type", req.Type, "reason", req.Reason, "tool_name", req.ToolName, "session_id", r.sessionID)
Expand Down
Loading
Loading