From 1d9f801e9d119a6652570f850d5ffed29f947d8d Mon Sep 17 00:00:00 2001 From: Matteo Merola Date: Thu, 25 Jun 2026 09:50:29 +0200 Subject: [PATCH] fix(agent): track write-action state per action, not per request MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit A single successful, skipped, or denied write set a request-wide flag (`write_executed` / `write_decision`) that blocked every subsequent write in the same turn. Creating two jobs in one interaction, or retrying after a skip, failed with "Request already fulfilled" or "User skipped this action". Track write state per distinct action instead, keyed on a stable signature of the tool name plus its arguments: - `executed_writes`: set of completed action signatures — only an identical repeat is suppressed (still guards accidental double-sends). - `write_decisions`: per-signature approval/skip/deny decisions — reused only for the identical action, so a skip/deny never leaks onto a different write. Distinct writes (e.g. scheduling several reminders) now each run and prompt independently. Fixes #8. --- core/agent.py | 68 ++++++++++++++++++++--------- tests/test_tools.py | 101 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 150 insertions(+), 19 deletions(-) diff --git a/core/agent.py b/core/agent.py index b73665f..449c7ae 100644 --- a/core/agent.py +++ b/core/agent.py @@ -499,7 +499,7 @@ async def _process_injection( ) # Agentic loop — keep going while the LLM wants to call tools - request_state = {"write_executed": False, "write_decision": None, "approvals": {}} + request_state = self._new_request_state() tool_log: list[dict] = [] while response.tool_calls: tool_results = [] @@ -595,7 +595,7 @@ async def _process_session( # Agentic loop — keep going while the LLM wants to call tools new_messages: list[dict] = [] - request_state = {"write_executed": False, "write_decision": None, "approvals": {}} + request_state = self._new_request_state() tool_log: list[dict] = [] while response.tool_calls: tool_results = [] @@ -697,18 +697,27 @@ async def _execute_tool( params = tool_call.arguments if request_state is None: - request_state = {"write_executed": False, "write_decision": None, "approvals": {}} + request_state = self._new_request_state() is_write_action = self.permissions.is_write_action(name, params) - if is_write_action and request_state.get("write_executed"): - return {"error": "Request already fulfilled; not repeating write actions."} - if is_write_action and request_state.get("write_decision") == "denied": + # Write-state is tracked per distinct action (tool + params), so a + # failure, skip, or completion of one write never blocks a different one. + write_sig = self._write_signature(name, params) if is_write_action else None + executed_writes = request_state.setdefault("executed_writes", set()) + write_decisions = request_state.setdefault("write_decisions", {}) + if is_write_action and write_sig in executed_writes: + return { + "error": ( + "This exact action was already completed in this request; not repeating it." + ) + } + if is_write_action and write_decisions.get(write_sig) == "denied": return {"error": "Action denied by user."} - if is_write_action and request_state.get("write_decision") == "skipped": + if is_write_action and write_decisions.get(write_sig) == "skipped": return { "error": ( "User skipped this action. " - "Do not retry this action or attempt similar alternatives — " + "Do not retry this exact action — " "move on to something else." ) } @@ -723,17 +732,19 @@ async def _execute_tool( if level == PermissionLevel.ASK and channel != "system": match_key = self.permissions.match_key(name, params) approvals = request_state.get("approvals", {}) - if is_write_action and request_state.get("write_decision") is not None: - decision = request_state.get("write_decision") - elif isinstance(approvals, dict) and match_key in approvals: + if is_write_action and write_sig in write_decisions: + # Same write asked earlier in this turn — reuse that decision + # rather than prompting again, but only for the identical action. + decision = write_decisions[write_sig] + elif not is_write_action and isinstance(approvals, dict) and match_key in approvals: decision = approvals[match_key] else: decision = await self._request_approval(name, params, channel, user_id) - if isinstance(approvals, dict): + if is_write_action: + write_decisions[write_sig] = decision + elif isinstance(approvals, dict): approvals[match_key] = decision request_state["approvals"] = approvals - if is_write_action: - request_state["write_decision"] = decision if decision == "skipped": log.info("Permission SKIPPED by user: %s", name) return { @@ -761,25 +772,25 @@ async def _execute_tool( if name == "send_email": result = await self._tool_send_email(params) if is_write_action and self._is_tool_success(result): - request_state["write_executed"] = True + executed_writes.add(write_sig) return result if name == "reply_email": result = await self._tool_reply_email(params) if is_write_action and self._is_tool_success(result): - request_state["write_executed"] = True + executed_writes.add(write_sig) return result if name == "send_message": result = await self._tool_send_message(params) if is_write_action and self._is_tool_success(result): - request_state["write_executed"] = True + executed_writes.add(write_sig) return result if name == "create_calendar_event": result = await self._tool_create_calendar_event(params) if is_write_action and self._is_tool_success(result): - request_state["write_executed"] = True + executed_writes.add(write_sig) return result if name == "web_search": @@ -799,11 +810,30 @@ async def _execute_tool( log.info("Tool call: manage_jobs — %s", params.get("action", "")) result = await self._tool_manage_jobs(params) if is_write_action and self._is_tool_success(result): - request_state["write_executed"] = True + executed_writes.add(write_sig) return result return {"error": f"Unknown tool: {name}"} + @staticmethod + def _new_request_state() -> dict: + """Fresh per-turn state tracking write actions and approval decisions.""" + return {"executed_writes": set(), "write_decisions": {}, "approvals": {}} + + @staticmethod + def _write_signature(name: str, params: dict) -> str: + """Stable signature for a write action, keyed on tool name + arguments. + + Two calls share a signature only when they would perform the identical + write, so deduplication and remembered skip/deny decisions apply per + action rather than blocking every write after the first. + """ + try: + payload = json.dumps(params, sort_keys=True, default=str) + except Exception: + payload = repr(params) + return f"{name}:{payload}" + @staticmethod def _is_tool_success(result: dict) -> bool: if not isinstance(result, dict): diff --git a/tests/test_tools.py b/tests/test_tools.py index a7f4092..82635ab 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -166,3 +166,104 @@ async def fake_build(*args, **kwargs) -> str: third = await agent._session_system_prompt("telegram", "u1", "") assert third == "SYSTEM-2" assert calls["n"] == 2 + + +# --------------------------------------------------------------------------- +# Per-action write state — one write's outcome must not block a different one +# --------------------------------------------------------------------------- + + +def _job_call(call_id: str, **params): + from core.llm import LLMToolCall + + return LLMToolCall(id=call_id, name="manage_jobs", arguments={"action": "create", **params}) + + +async def _approve(name, params, channel, user_id): + return "approved" + + +async def _ok_manage_jobs(params): + return {"ok": True, "job_id": "job_" + params.get("task", ""), "task": params.get("task")} + + +@pytest.mark.asyncio +async def test_write_signature_distinguishes_distinct_actions(agent) -> None: + a = agent._write_signature("manage_jobs", {"action": "create", "task": "A"}) + b = agent._write_signature("manage_jobs", {"action": "create", "task": "B"}) + a_again = agent._write_signature("manage_jobs", {"task": "A", "action": "create"}) + assert a != b # different params → different signature + assert a == a_again # key order does not matter + + +@pytest.mark.asyncio +async def test_distinct_writes_are_independent_after_success(agent, monkeypatch) -> None: + """A completed write must not block a *different* subsequent write.""" + monkeypatch.setattr(agent, "_request_approval", _approve) + monkeypatch.setattr(agent, "_tool_manage_jobs", _ok_manage_jobs) + agent.channels = {"telegram": object()} # presence so approval path runs + + state = agent._new_request_state() + first = await agent._execute_tool( + _job_call("1", task="ping mum", run_at="2026-07-01T09:00:00"), + "telegram", + "u1", + state, + ) + second = await agent._execute_tool( + _job_call("2", task="ping dad", run_at="2026-07-02T09:00:00"), + "telegram", + "u1", + state, + ) + assert first.get("ok") is True + assert second.get("ok") is True # not blocked by "already fulfilled" + + +@pytest.mark.asyncio +async def test_identical_write_is_deduplicated(agent, monkeypatch) -> None: + """An identical repeated write within a turn is still suppressed.""" + monkeypatch.setattr(agent, "_request_approval", _approve) + monkeypatch.setattr(agent, "_tool_manage_jobs", _ok_manage_jobs) + agent.channels = {"telegram": object()} + + state = agent._new_request_state() + call = _job_call("1", task="ping mum", run_at="2026-07-01T09:00:00") + first = await agent._execute_tool(call, "telegram", "u1", state) + repeat = await agent._execute_tool( + _job_call("2", task="ping mum", run_at="2026-07-01T09:00:00"), + "telegram", + "u1", + state, + ) + assert first.get("ok") is True + assert "already completed" in repeat.get("error", "") + + +@pytest.mark.asyncio +async def test_skipping_one_write_does_not_block_a_different_one(agent, monkeypatch) -> None: + """Skipping a write blocks only that exact action, not other writes.""" + decisions = {"ping mum": "skipped", "ping dad": "approved"} + + async def fake_approval(name, params, channel, user_id): + return decisions.get(params.get("task"), "approved") + + monkeypatch.setattr(agent, "_request_approval", fake_approval) + monkeypatch.setattr(agent, "_tool_manage_jobs", _ok_manage_jobs) + agent.channels = {"telegram": object()} + + state = agent._new_request_state() + skipped = await agent._execute_tool( + _job_call("1", task="ping mum", run_at="2026-07-01T09:00:00"), + "telegram", + "u1", + state, + ) + other = await agent._execute_tool( + _job_call("2", task="ping dad", run_at="2026-07-02T09:00:00"), + "telegram", + "u1", + state, + ) + assert "skipped" in skipped.get("error", "") + assert other.get("ok") is True # the skip did not leak onto a different write