Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
152 changes: 98 additions & 54 deletions internal/agent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,12 @@ type Agent struct {
// callers (e.g. Kit.applyComposedSystemPrompt invoked from multiple
// goroutines) don't race on a.systemPrompt / a.fantasyAgent.
promptMu sync.Mutex

// toolsMu guards extraTools so the live tool set can be re-read each
// step (via composeAllTools) from inside an in-flight Stream while a
// concurrent SetExtraTools mutates it. Without this, mid-turn AddTools/
// RemoveTools would race on the extraTools slice header.
toolsMu sync.RWMutex
}

// GenerateWithLoopResult contains the result and conversation history from an agent interaction.
Expand Down Expand Up @@ -430,17 +436,7 @@ func (a *Agent) ensureMCPTools() {
// tool set (core + MCP + extension tools). Used after MCP tools arrive
// asynchronously and by SetModel.
func (a *Agent) rebuildFantasyAgent() {
allTools := make([]fantasy.AgentTool, len(a.coreTools))
copy(allTools, a.coreTools)
if a.toolManager != nil {
allTools = append(allTools, mcpToolsToAgentTools(a.toolManager.GetTools(), a.toolManager)...)
}
if len(a.extraTools) > 0 {
allTools = append(allTools, a.extraTools...)
}
if a.toolWrapper != nil {
allTools = a.toolWrapper(allTools)
}
allTools := a.composeAllTools()

providerResult := &models.ProviderResult{
Model: a.model,
Expand All @@ -456,6 +452,29 @@ func (a *Agent) rebuildFantasyAgent() {
a.fantasyAgent = fantasy.NewAgent(a.model, agentOpts...)
}

// composeAllTools builds the full live tool set (core + MCP + extra tools)
// with the tool wrapper applied, matching the tools baked into the fantasy
// agent by rebuildFantasyAgent. It re-reads the current extraTools under
// toolsMu so callers (notably the per-step PrepareStep callback) observe
// runtime AddTools/RemoveTools changes mid-turn.
func (a *Agent) composeAllTools() []fantasy.AgentTool {
a.toolsMu.RLock()
defer a.toolsMu.RUnlock()

allTools := make([]fantasy.AgentTool, len(a.coreTools))
copy(allTools, a.coreTools)
if a.toolManager != nil {
allTools = append(allTools, mcpToolsToAgentTools(a.toolManager.GetTools(), a.toolManager)...)
}
if len(a.extraTools) > 0 {
allTools = append(allTools, a.extraTools...)
}
if a.toolWrapper != nil {
allTools = a.toolWrapper(allTools)
}
return allTools
}

// buildAgentOptions constructs the fantasy.AgentOption slice from config,
// provider result, and the combined tool list. Shared by NewAgent,
// rebuildFantasyAgent, and SetModel.
Expand Down Expand Up @@ -825,60 +844,71 @@ func (a *Agent) GenerateWithCallbacks(ctx context.Context, messages []fantasy.Me
},
}

// Always wire up PrepareStep to handle both steering and the
// OnPrepareStep hook. Steering drains its channel first, then
// OnPrepareStep hooks run against the (possibly already steered)
// messages.
// Always wire up PrepareStep. It serves three purposes:
// 1. Re-read the live tool set each step so runtime AddTools/
// RemoveTools (and MCP server changes) take effect at the next
// LLM step of the *current* turn, as documented. The entire
// multi-step loop runs inside a single fantasy Stream call that
// otherwise captures the tool snapshot taken when Stream began;
// populating PrepareStepResult.Tools makes fantasy re-read tools
// per step instead.
// 2. Steering: drain queued steer messages.
// 3. The OnPrepareStep hook.
// Steering drains its channel first, then OnPrepareStep hooks run
// against the (possibly already steered) messages.
steerCh := steerChFromContext(ctx)
onConsumed := steerConsumedFromContext(ctx)
hasSteering := steerCh != nil
hasPrepareStepHook := cb.OnPrepareStep != nil

if hasSteering || hasPrepareStepHook {
streamCall.PrepareStep = func(
stepCtx context.Context,
opts fantasy.PrepareStepFunctionOptions,
) (context.Context, fantasy.PrepareStepResult, error) {
result := fantasy.PrepareStepResult{
Model: opts.Model,
Messages: opts.Messages,
}
streamCall.PrepareStep = func(
stepCtx context.Context,
opts fantasy.PrepareStepFunctionOptions,
) (context.Context, fantasy.PrepareStepResult, error) {
result := fantasy.PrepareStepResult{
Model: opts.Model,
Messages: opts.Messages,
// Re-read the live tool set so mid-turn tool changes are
// honored. composeAllTools matches the composition baked
// into the fantasy agent, so in the steady state this is
// identical to the snapshot fantasy would have used.
Tools: a.composeAllTools(),
}

// Phase 1: Drain steering channel (if present).
if hasSteering {
var steered []SteerMessage
for {
select {
case msg := <-steerCh:
steered = append(steered, msg)
default:
goto done
}
// Phase 1: Drain steering channel (if present).
if hasSteering {
var steered []SteerMessage
for {
select {
case msg := <-steerCh:
steered = append(steered, msg)
default:
goto done
}
done:
if len(steered) > 0 {
for _, sm := range steered {
result.Messages = append(result.Messages,
fantasy.NewUserMessage(sm.Text, sm.Files...))
}
if onConsumed != nil {
onConsumed(len(steered))
}
}
done:
if len(steered) > 0 {
for _, sm := range steered {
result.Messages = append(result.Messages,
fantasy.NewUserMessage(sm.Text, sm.Files...))
}
if onConsumed != nil {
onConsumed(len(steered))
}
}
}

// Phase 2: Run OnPrepareStep hook (if registered).
if hasPrepareStepHook {
if replacement := cb.OnPrepareStep(opts.StepNumber, result.Messages); replacement != nil {
result.Messages = replacement
}
// Phase 2: Run OnPrepareStep hook (if registered).
if hasPrepareStepHook {
if replacement := cb.OnPrepareStep(opts.StepNumber, result.Messages); replacement != nil {
result.Messages = replacement
}
}

// Apply message-level cache control for Anthropic models.
result.Messages = applyCacheControlToMessages(result.Messages)
// Apply message-level cache control for Anthropic models.
result.Messages = applyCacheControlToMessages(result.Messages)

return stepCtx, result, nil
}
return stepCtx, result, nil
}

// Wire OnRetry callback if provided.
Expand Down Expand Up @@ -1076,6 +1106,8 @@ func extractMCPContentText(result string) string {
// GetTools returns the list of available tools loaded in the agent,
// including core tools, MCP tools, and extension-registered tools.
func (a *Agent) GetTools() []fantasy.AgentTool {
a.toolsMu.RLock()
defer a.toolsMu.RUnlock()
allTools := make([]fantasy.AgentTool, len(a.coreTools))
copy(allTools, a.coreTools)
if a.toolManager != nil {
Expand All @@ -1102,13 +1134,17 @@ func (a *Agent) GetMCPToolCount() int {

// GetExtensionToolCount returns the number of tools registered by extensions.
func (a *Agent) GetExtensionToolCount() int {
a.toolsMu.RLock()
defer a.toolsMu.RUnlock()
return len(a.extraTools)
}

// GetExtraTools returns the agent's current extra tools (e.g.
// extension-registered tools). The returned slice is a copy so callers can
// snapshot and later restore it via SetExtraTools.
func (a *Agent) GetExtraTools() []fantasy.AgentTool {
a.toolsMu.RLock()
defer a.toolsMu.RUnlock()
if len(a.extraTools) == 0 {
return nil
}
Expand All @@ -1119,9 +1155,17 @@ func (a *Agent) GetExtraTools() []fantasy.AgentTool {

// SetExtraTools replaces the agent's extra tools (e.g. extension-registered
// tools) and rebuilds the internal agent with the updated tool list. The
// model, system prompt, and all other configuration are preserved.
// model, system prompt, and all other configuration are preserved. The
// incoming slice is cloned so later caller mutations cannot bypass toolsMu
// and race composeAllTools. The rebuild is serialized under promptMu so it
// can't race other fantasyAgent rebuilds (SetSystemPrompt, SetModel, MCP).
func (a *Agent) SetExtraTools(extraTools []fantasy.AgentTool) {
a.extraTools = extraTools
a.toolsMu.Lock()
a.extraTools = append([]fantasy.AgentTool(nil), extraTools...)
a.toolsMu.Unlock()
Comment thread
coderabbitai[bot] marked this conversation as resolved.

a.promptMu.Lock()
defer a.promptMu.Unlock()
a.rebuildFantasyAgent()
Comment thread
coderabbitai[bot] marked this conversation as resolved.
}

Expand Down
168 changes: 168 additions & 0 deletions internal/agent/agent_midturn_tools_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
package agent

import (
"context"
"testing"

"charm.land/fantasy"
)

// toolNamesFromStep extracts tool names from a PrepareStep result for assertions.
func toolNamesFromStep(tools []fantasy.AgentTool) map[string]struct{} {
out := make(map[string]struct{}, len(tools))
for _, t := range tools {
out[t.Info().Name] = struct{}{}
}
return out
}

// newTestTool builds a minimal AgentTool with the given name for use in tests.
func newTestTool(name string) fantasy.AgentTool {
type emptyInput struct{}
return fantasy.NewAgentTool(name, "test tool "+name,
func(_ context.Context, _ emptyInput, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
return fantasy.NewTextResponse("ok"), nil
},
)
}

// midTurnToolAgent is a fake fantasy.Agent that simulates the multi-step
// agentic loop fantasy runs inside a single Stream call. Between step 0 and
// step 1 it invokes mutate(), emulating a tool handler that calls AddTools/
// RemoveTools mid-turn. It records the tool list each step's PrepareStep
// callback yields so the test can assert that mid-turn changes are observed.
type midTurnToolAgent struct {
mutate func()
stepTools []map[string]struct{}
prepareErr error
}

func (f *midTurnToolAgent) Generate(_ context.Context, _ fantasy.AgentCall) (*fantasy.AgentResult, error) {
return &fantasy.AgentResult{}, nil
}

func (f *midTurnToolAgent) Stream(ctx context.Context, opts fantasy.AgentStreamCall) (*fantasy.AgentResult, error) {
// Step 0: capture the tools fantasy would use for the first step.
for step := range 2 {
if step == 1 && f.mutate != nil {
// Emulate a tool handler that mutates the live tool set
// mid-turn (e.g. enable_toolset calling host.AddTools).
f.mutate()
}
if opts.PrepareStep != nil {
_, prepared, err := opts.PrepareStep(ctx, fantasy.PrepareStepFunctionOptions{
StepNumber: step,
Model: nil,
Messages: nil,
})
if err != nil {
f.prepareErr = err
return nil, err
}
f.stepTools = append(f.stepTools, toolNamesFromStep(prepared.Tools))
}
}
return &fantasy.AgentResult{}, nil
}

// TestPrepareStepReflectsMidTurnTools is the regression test for #76.
// AddTools/RemoveTools must take effect at the next LLM step of the current
// turn. Because the whole agentic loop runs inside a single fantasy Stream
// call, this only works if Kit's PrepareStep callback re-reads the live tool
// set each step and populates PrepareStepResult.Tools. Before the fix the
// per-step tool list was the snapshot captured when Stream began, so a tool
// added mid-turn never appeared until the next turn.
func TestPrepareStepReflectsMidTurnTools(t *testing.T) {
t.Parallel()

core := newTestTool("read")
loadMore := newTestTool("load_more")
foo := newTestTool("foo")

fake := &midTurnToolAgent{}

a := &Agent{
fantasyAgent: fake,
streamingEnabled: true,
coreTools: []fantasy.AgentTool{core},
extraTools: []fantasy.AgentTool{loadMore},
}

// Mid-turn, add a brand new tool "foo" (additive to load_more).
fake.mutate = func() {
a.SetExtraTools([]fantasy.AgentTool{loadMore, foo})
}

msgs := []fantasy.Message{fantasy.NewUserMessage("go")}
if _, err := a.GenerateWithCallbacks(context.Background(), msgs, GenerateCallbacks{}); err != nil {
t.Fatalf("GenerateWithCallbacks returned error: %v", err)
}

if len(fake.stepTools) != 2 {
t.Fatalf("expected 2 prepared steps, got %d", len(fake.stepTools))
}

// Step 0: foo must NOT be present yet; load_more and read are.
step0 := fake.stepTools[0]
if _, ok := step0["read"]; !ok {
t.Errorf("step 0: expected core tool 'read' to be present")
}
if _, ok := step0["load_more"]; !ok {
t.Errorf("step 0: expected 'load_more' to be present")
}
if _, ok := step0["foo"]; ok {
t.Errorf("step 0: 'foo' should not be present before it was added")
}

// Step 1: after the mid-turn AddTools, foo MUST be visible to the step.
step1 := fake.stepTools[1]
if _, ok := step1["foo"]; !ok {
t.Errorf("step 1: expected mid-turn-added 'foo' to be present (regression #76)")
}
if _, ok := step1["read"]; !ok {
t.Errorf("step 1: expected core tool 'read' to remain present")
}
if _, ok := step1["load_more"]; !ok {
t.Errorf("step 1: expected 'load_more' to remain present")
}
}

// TestPrepareStepReflectsMidTurnToolRemoval verifies the RemoveTools side of
// the contract: a tool removed mid-turn disappears from the next step.
func TestPrepareStepReflectsMidTurnToolRemoval(t *testing.T) {
t.Parallel()

core := newTestTool("read")
loadMore := newTestTool("load_more")
temp := newTestTool("temp")

fake := &midTurnToolAgent{}

a := &Agent{
fantasyAgent: fake,
streamingEnabled: true,
coreTools: []fantasy.AgentTool{core},
extraTools: []fantasy.AgentTool{loadMore, temp},
}

// Mid-turn, remove "temp".
fake.mutate = func() {
a.SetExtraTools([]fantasy.AgentTool{loadMore})
}

msgs := []fantasy.Message{fantasy.NewUserMessage("go")}
if _, err := a.GenerateWithCallbacks(context.Background(), msgs, GenerateCallbacks{}); err != nil {
t.Fatalf("GenerateWithCallbacks returned error: %v", err)
}

if len(fake.stepTools) != 2 {
t.Fatalf("expected 2 prepared steps, got %d", len(fake.stepTools))
}

if _, ok := fake.stepTools[0]["temp"]; !ok {
t.Errorf("step 0: expected 'temp' to be present before removal")
}
if _, ok := fake.stepTools[1]["temp"]; ok {
t.Errorf("step 1: 'temp' should be gone after mid-turn RemoveTools (regression #76)")
}
}
Loading
Loading