diff --git a/core/agent.py b/core/agent.py
index f5590b8..7bd5884 100644
--- a/core/agent.py
+++ b/core/agent.py
@@ -503,6 +503,7 @@ async def _process_injection(
         request_state = self._new_request_state()
         tool_log: list[dict] = []
         while response.tool_calls:
+            await self._batch_approve_writes(response.tool_calls, channel, user_id, request_state)
             tool_results = []
             for call in response.tool_calls:
                 result = await self._execute_tool(call, channel, user_id, request_state)
@@ -599,6 +600,7 @@ async def _process_session(
         request_state = self._new_request_state()
         tool_log: list[dict] = []
         while response.tool_calls:
+            await self._batch_approve_writes(response.tool_calls, channel, user_id, request_state)
             tool_results = []
             for call in response.tool_calls:
                 result = await self._execute_tool(call, channel, user_id, request_state)
@@ -1107,21 +1109,86 @@ async def _tool_web_search(self, params: dict) -> dict:
     async def _request_approval(
         self, tool_name: str, params: dict, channel: str, user_id: str
     ) -> str:
-        """Ask the user for approval via their channel (e.g. Telegram inline keyboard).
+        """Ask the user to approve a single tool call via their channel.
 
-        Creates a pending approval future, sends the prompt to the channel,
-        and waits for the user to respond.
+        Returns one of ``"approved"``, ``"denied"``, or ``"skipped"``.
+        """
+        return await self._await_approval(
+            format_approval_message(tool_name, params),
+            channel,
+            user_id,
+            tool_name,
+            params,
+        )
+
+    async def _batch_approve_writes(
+        self,
+        tool_calls: list,
+        channel: str,
+        user_id: str,
+        request_state: dict,
+    ) -> None:
+        """Approve a turn's pending write actions with a single prompt.
+
+        The LLM can emit several write tool calls in one response (e.g. "set
+        reminders for the next 5 days"). Prompting for each separately forces
+        the user to approve one-at-a-time. Instead, collect every write that
+        still needs a decision, ask once, and record the decision per action
+        so :meth:`_execute_tool` reuses it instead of prompting again.
+
+        A lone write is left to the per-call path — batching only helps when
+        there are two or more. The decision is all-or-nothing across the batch.
+
+        Nothing is asked on the ``"system"`` channel, and a write that already
+        has a recorded decision is not asked about again.
+        """
+        if channel == "system":
+            return
+        write_decisions = request_state.setdefault("write_decisions", {})
+        pending: list[tuple[str, str]] = []  # (signature, description)
+        seen: set[str] = set()
+        for call in tool_calls:
+            if not self.permissions.is_write_action(call.name, call.arguments):
+                continue
+            if self.permissions.check(call.name, call.arguments) != PermissionLevel.ASK:
+                continue
+            sig = self._write_signature(call.name, call.arguments)
+            if sig in write_decisions or sig in seen:
+                continue
+            seen.add(sig)
+            pending.append((sig, format_approval_message(call.name, call.arguments)))
+        if len(pending) < 2:
+            return
+        lines = "\n\n".join(f"{i}. {desc}" for i, (_, desc) in enumerate(pending, 1))
+        description = f"Approve these {len(pending)} actions?\n\n{lines}"
+        decision = await self._await_approval(description, channel, user_id)
+        for sig, _ in pending:
+            write_decisions[sig] = decision
+
+    async def _await_approval(
+        self,
+        description: str,
+        channel: str,
+        user_id: str,
+        tool_name: str | None = None,
+        params: dict | None = None,
+    ) -> str:
+        """Send an approval prompt to the channel and wait for the response.
 
+        Creates a pending approval future, sends the prompt, and waits.
+        ``tool_name`` and ``params`` stay ``None`` for a batch prompt; an
+        unregistered channel is auto-approved without asking.
         Returns one of ``"approved"``, ``"denied"``, or ``"skipped"``.
         """
         ch = self.channels.get(channel)
         if not ch:
             # No channel available to ask — auto-approve (e.g. admin API)
-            log.warning("No channel %r for approval, auto-approving %s", channel, tool_name)
+            log.warning(
+                "No channel %r for approval, auto-approving %s", channel, tool_name or "batch"
+            )
             return "approved"
 
         request_id, future = self.permissions.create_approval_request(tool_name, params)
-        description = format_approval_message(tool_name, params)
 
         # Send the approval prompt via the channel
         try:
diff --git a/tests/test_tools.py b/tests/test_tools.py
index bd35e2a..5dd3649 100644
--- a/tests/test_tools.py
+++ b/tests/test_tools.py
@@ -362,3 +362,79 @@ async def fake_approval(name, params, channel, user_id):
     )
     assert "skipped" in skipped.get("error", "")
     assert other.get("ok") is True  # the skip did not leak onto a different write
+
+
+# ---------------------------------------------------------------------------
+# Batch approval — multiple writes in one turn share a single prompt (#12)
+# ---------------------------------------------------------------------------
+
+
+@pytest.mark.asyncio
+async def test_batch_approval_asks_once_for_multiple_writes(agent, monkeypatch) -> None:
+    """Several writes in one turn must trigger exactly one approval prompt."""
+    prompts = {"n": 0}
+
+    async def fake_await(description, channel, user_id, tool_name=None, params=None):
+        prompts["n"] += 1
+        return "approved"
+
+    monkeypatch.setattr(agent, "_await_approval", fake_await)
+    monkeypatch.setattr(agent, "_tool_manage_jobs", _ok_manage_jobs)
+    agent.channels = {"telegram": object()}
+
+    state = agent._new_request_state()
+    c1 = _job_call("1", task="ping mum", run_at="2026-07-01T09:00:00")
+    c2 = _job_call("2", task="ping dad", run_at="2026-07-02T09:00:00")
+
+    await agent._batch_approve_writes([c1, c2], "telegram", "u1", state)
+    assert prompts["n"] == 1  # one prompt covered both writes
+
+    r1 = await agent._execute_tool(c1, "telegram", "u1", state)
+    r2 = await agent._execute_tool(c2, "telegram", "u1", state)
+    assert prompts["n"] == 1  # execution reused the batch decision, no re-prompt
+    assert r1.get("ok") is True and r2.get("ok") is True
+
+
+@pytest.mark.asyncio
+async def test_batch_approval_denied_blocks_every_write(agent, monkeypatch) -> None:
+    """Denying the batch blocks all of its writes, not just one."""
+
+    async def deny(description, channel, user_id, tool_name=None, params=None):
+        return "denied"
+
+    monkeypatch.setattr(agent, "_await_approval", deny)
+    monkeypatch.setattr(agent, "_tool_manage_jobs", _ok_manage_jobs)
+    agent.channels = {"telegram": object()}
+
+    state = agent._new_request_state()
+    c1 = _job_call("1", task="ping mum", run_at="2026-07-01T09:00:00")
+    c2 = _job_call("2", task="ping dad", run_at="2026-07-02T09:00:00")
+
+    await agent._batch_approve_writes([c1, c2], "telegram", "u1", state)
+    r1 = await agent._execute_tool(c1, "telegram", "u1", state)
+    r2 = await agent._execute_tool(c2, "telegram", "u1", state)
+    assert "denied" in r1.get("error", "")
+    assert "denied" in r2.get("error", "")
+
+
+@pytest.mark.asyncio
+async def test_single_write_is_not_batched(agent, monkeypatch) -> None:
+    """A lone write is left to the per-call path, not the batch prompt."""
+    prompts = {"n": 0}
+
+    async def fake_await(description, channel, user_id, tool_name=None, params=None):
+        prompts["n"] += 1
+        return "approved"
+
+    monkeypatch.setattr(agent, "_await_approval", fake_await)
+    agent.channels = {"telegram": object()}
+
+    state = agent._new_request_state()
+    await agent._batch_approve_writes(
+        [_job_call("1", task="ping mum", run_at="2026-07-01T09:00:00")],
+        "telegram",
+        "u1",
+        state,
+    )
+    assert prompts["n"] == 0  # nothing to batch for a single write
+    assert state["write_decisions"] == {}