Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 65 additions & 5 deletions core/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,6 +503,7 @@ async def _process_injection(
request_state = self._new_request_state()
tool_log: list[dict] = []
while response.tool_calls:
await self._batch_approve_writes(response.tool_calls, channel, user_id, request_state)
tool_results = []
for call in response.tool_calls:
result = await self._execute_tool(call, channel, user_id, request_state)
Expand Down Expand Up @@ -599,6 +600,7 @@ async def _process_session(
request_state = self._new_request_state()
tool_log: list[dict] = []
while response.tool_calls:
await self._batch_approve_writes(response.tool_calls, channel, user_id, request_state)
tool_results = []
for call in response.tool_calls:
result = await self._execute_tool(call, channel, user_id, request_state)
Expand Down Expand Up @@ -1107,21 +1109,79 @@ async def _tool_web_search(self, params: dict) -> dict:
async def _request_approval(
self, tool_name: str, params: dict, channel: str, user_id: str
) -> str:
"""Ask the user for approval via their channel (e.g. Telegram inline keyboard).
"""Ask the user to approve a single tool call via their channel.

Creates a pending approval future, sends the prompt to the channel,
and waits for the user to respond.
Returns one of ``"approved"``, ``"denied"``, or ``"skipped"``.
"""
return await self._await_approval(
format_approval_message(tool_name, params),
channel,
user_id,
tool_name,
params,
)

async def _batch_approve_writes(
self,
tool_calls: list,
channel: str,
user_id: str,
request_state: dict,
) -> None:
"""Approve a turn's pending write actions with a single prompt.

The LLM can emit several write tool calls in one response (e.g. "set
reminders for the next 5 days"). Prompting for each separately forces
the user to approve one-at-a-time. Instead, collect every write that
still needs a decision, ask once, and record the decision per action
so :meth:`_execute_tool` reuses it instead of prompting again.

A lone write is left to the per-call path — batching only helps when
there are two or more. The decision is all-or-nothing across the batch.
"""
if channel == "system":
return
write_decisions = request_state.setdefault("write_decisions", {})
pending: list[tuple[str, str]] = [] # (signature, description)
seen: set[str] = set()
for call in tool_calls:
if not self.permissions.is_write_action(call.name, call.arguments):
continue
if self.permissions.check(call.name, call.arguments) != PermissionLevel.ASK:
continue
sig = self._write_signature(call.name, call.arguments)
if sig in write_decisions or sig in seen:
continue
seen.add(sig)
pending.append((sig, format_approval_message(call.name, call.arguments)))
if len(pending) < 2:
return
lines = "\n\n".join(f"{i}. {desc}" for i, (_, desc) in enumerate(pending, 1))
description = f"Approve these {len(pending)} actions?\n\n{lines}"
decision = await self._await_approval(description, channel, user_id)
for sig, _ in pending:
write_decisions[sig] = decision

async def _await_approval(
self,
description: str,
channel: str,
user_id: str,
tool_name: str | None = None,
params: dict | None = None,
) -> str:
"""Send an approval prompt to the channel and wait for the response.

Creates a pending approval future, sends the prompt, and waits.
Returns one of ``"approved"``, ``"denied"``, or ``"skipped"``.
"""
ch = self.channels.get(channel)
if not ch:
# No channel available to ask — auto-approve (e.g. admin API)
log.warning("No channel %r for approval, auto-approving %s", channel, tool_name)
log.warning("No channel %r for approval, auto-approving", channel)
return "approved"

request_id, future = self.permissions.create_approval_request(tool_name, params)
description = format_approval_message(tool_name, params)

# Send the approval prompt via the channel
try:
Expand Down
76 changes: 76 additions & 0 deletions tests/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,3 +362,79 @@ async def fake_approval(name, params, channel, user_id):
)
assert "skipped" in skipped.get("error", "")
assert other.get("ok") is True # the skip did not leak onto a different write


# ---------------------------------------------------------------------------
# Batch approval — multiple writes in one turn share a single prompt (#12)
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_batch_approval_asks_once_for_multiple_writes(agent, monkeypatch) -> None:
"""Several writes in one turn must trigger exactly one approval prompt."""
prompts = {"n": 0}

async def fake_await(description, channel, user_id, tool_name=None, params=None):
prompts["n"] += 1
return "approved"

monkeypatch.setattr(agent, "_await_approval", fake_await)
monkeypatch.setattr(agent, "_tool_manage_jobs", _ok_manage_jobs)
agent.channels = {"telegram": object()}

state = agent._new_request_state()
c1 = _job_call("1", task="ping mum", run_at="2026-07-01T09:00:00")
c2 = _job_call("2", task="ping dad", run_at="2026-07-02T09:00:00")

await agent._batch_approve_writes([c1, c2], "telegram", "u1", state)
assert prompts["n"] == 1 # one prompt covered both writes

r1 = await agent._execute_tool(c1, "telegram", "u1", state)
r2 = await agent._execute_tool(c2, "telegram", "u1", state)
assert prompts["n"] == 1 # execution reused the batch decision, no re-prompt
assert r1.get("ok") is True and r2.get("ok") is True


@pytest.mark.asyncio
async def test_batch_approval_denied_blocks_every_write(agent, monkeypatch) -> None:
"""Denying the batch blocks all of its writes, not just one."""

async def deny(description, channel, user_id, tool_name=None, params=None):
return "denied"

monkeypatch.setattr(agent, "_await_approval", deny)
monkeypatch.setattr(agent, "_tool_manage_jobs", _ok_manage_jobs)
agent.channels = {"telegram": object()}

state = agent._new_request_state()
c1 = _job_call("1", task="ping mum", run_at="2026-07-01T09:00:00")
c2 = _job_call("2", task="ping dad", run_at="2026-07-02T09:00:00")

await agent._batch_approve_writes([c1, c2], "telegram", "u1", state)
r1 = await agent._execute_tool(c1, "telegram", "u1", state)
r2 = await agent._execute_tool(c2, "telegram", "u1", state)
assert "denied" in r1.get("error", "")
assert "denied" in r2.get("error", "")


@pytest.mark.asyncio
async def test_single_write_is_not_batched(agent, monkeypatch) -> None:
"""A lone write is left to the per-call path, not the batch prompt."""
prompts = {"n": 0}

async def fake_await(description, channel, user_id, tool_name=None, params=None):
prompts["n"] += 1
return "approved"

monkeypatch.setattr(agent, "_await_approval", fake_await)
agent.channels = {"telegram": object()}

state = agent._new_request_state()
await agent._batch_approve_writes(
[_job_call("1", task="ping mum", run_at="2026-07-01T09:00:00")],
"telegram",
"u1",
state,
)
assert prompts["n"] == 0 # nothing to batch for a single write
assert state["write_decisions"] == {}
Loading