Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion src/iac_code/acp/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,20 @@ def _history_message_to_updates(msg: Message) -> list[Any]:
_PREFIX_DENY_RULE = "deny_rule:"


def _tool_supports_blanket_allow(agent_loop, tool_name: str) -> bool:
"""Return False only when the registered tool explicitly disables blanket allow."""
registry = getattr(agent_loop, "tool_registry", None)
get_tool = getattr(registry, "get", None)
if get_tool is None:
return True

tool = get_tool(tool_name)
if tool is None:
return True

return bool(getattr(tool, "supports_blanket_allow", True))


class ACPSession:
def __init__(
self,
Expand Down Expand Up @@ -457,7 +471,7 @@ async def _request_permission(self, event: PermissionRequestEvent) -> bool:
kind="allow_always",
)
)
else:
elif _tool_supports_blanket_allow(self.agent_loop, tool_name):
options.append(
acp.schema.PermissionOption(
option_id=_OPTION_ALLOW_ALWAYS,
Expand Down
43 changes: 39 additions & 4 deletions src/iac_code/acp/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import acp

from iac_code.i18n import _
from iac_code.tools.base import Tool, ToolContext, ToolResult

TERMINAL_TIMEOUT = 300 # 5 minutes default timeout
Expand Down Expand Up @@ -32,9 +33,37 @@ def input_schema(self) -> dict:
def timeout(self) -> float | None:
return self._original.timeout

@property
def supports_blanket_allow(self) -> bool:
return self._original.supports_blanket_allow

def user_facing_name(self, input: dict | None = None) -> str:
return self._original.user_facing_name(input)

def get_activity_description(self, input: dict | None = None) -> str | None:
return self._original.get_activity_description(input)

def get_tool_use_summary(self, input: dict | None = None) -> str | None:
return self._original.get_tool_use_summary(input)

def render_tool_use_message(self, input: dict, *, verbose: bool = False) -> str | None:
return self._original.render_tool_use_message(input, verbose=verbose)

def render_tool_result_message(self, output: str, *, is_error: bool = False, verbose: bool = False) -> str | None:
return self._original.render_tool_result_message(output, is_error=is_error, verbose=verbose)

def render_tool_use_error_message(self, error: str) -> str | None:
return self._original.render_tool_use_error_message(error)

def streaming_preview_fields(self) -> list[str]:
return self._original.streaming_preview_fields()

def is_read_only(self, input: dict | None = None) -> bool:
return self._original.is_read_only(input)

def is_concurrency_safe(self, tool_input: dict) -> bool:
return self._original.is_concurrency_safe(tool_input)

def is_destructive(self, input: dict | None = None) -> bool:
return self._original.is_destructive(input)

Expand All @@ -44,7 +73,7 @@ async def check_permissions(self, input: dict, context: dict | None = None):
async def execute(self, *, tool_input: dict, context: ToolContext) -> ToolResult:
command = tool_input.get("command")
if not command:
return ToolResult.error("Bash command is required.")
return ToolResult.error(_("Bash command is required."))

timeout = tool_input.get("timeout", TERMINAL_TIMEOUT)
terminal_id: str | None = None
Expand Down Expand Up @@ -74,7 +103,7 @@ async def _wait_and_fetch():
except asyncio.TimeoutError:
with suppress(Exception):
await self._conn.kill_terminal(session_id=self._session_id, terminal_id=terminal_id)
return ToolResult.error(f"Command timed out after {timeout} seconds")
return ToolResult.error(_("Command timed out after {timeout} seconds").format(timeout=timeout))

if output.exit_status:
exit_status = output.exit_status
Expand All @@ -84,9 +113,15 @@ async def _wait_and_fetch():
# at the session layer in a future phase.

if exit_status.signal:
return ToolResult.error(f"Command terminated by signal: {exit_status.signal}\n{output.output}")
return ToolResult.error(
_("Command terminated by signal: {signal}").format(signal=exit_status.signal) + "\n" + output.output
)
if exit_status.exit_code not in (None, 0):
return ToolResult.error(f"Command failed with exit code {exit_status.exit_code}\n{output.output}")
return ToolResult.error(
_("Command failed with exit code {exit_code}").format(exit_code=exit_status.exit_code)
+ "\n"
+ output.output
)
return ToolResult.success(output.output)
except asyncio.CancelledError:
if terminal_id is not None:
Expand Down
14 changes: 11 additions & 3 deletions src/iac_code/agent/agent_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ def __init__(
model_name = provider_manager.get_model_name()

self.context_manager = ContextManager(system_prompt=system_prompt, model=model_name)
self._sync_tool_definitions()
if resume_messages:
self.context_manager.load_messages(resume_messages)
self._tool_executor = ToolExecutor(registry=tool_registry)
Expand All @@ -116,6 +117,7 @@ def set_provider(self, provider_manager: Any, system_prompt: str | None = None)
if system_prompt is not None:
self.system_prompt = system_prompt
self.context_manager.set_system_prompt(system_prompt)
self._sync_tool_definitions()

def set_auto_trigger_skills(self, skill_commands: list[Any] | None) -> None:
"""Refresh skills considered for automatic trigger injection."""
Expand All @@ -136,6 +138,12 @@ def _get_tool_definitions(self):
)
return tools

def _sync_tool_definitions(self):
"""Refresh context token accounting from the current tool registry."""
tool_definitions = self._get_tool_definitions()
self.context_manager.set_tool_definitions(tool_definitions)
return tool_definitions

def _get_provider_messages(self):
"""Convert context manager messages to provider Message format."""
from iac_code.providers.base import ContentBlock
Expand Down Expand Up @@ -277,9 +285,9 @@ async def _run_streaming_inner(self, user_input: str | list[ContentBlock]) -> As
from iac_code.services.telemetry import start_span
from iac_code.services.telemetry.names import GenAiAttr, GenAiOperationName, GenAiSpanKind, Spans

tool_definitions = self._get_tool_definitions()

for _turn in range(self._max_turns):
tool_definitions = self._sync_tool_definitions()

# Auto-compact if needed
if self.context_manager.needs_compaction():
compact_event = await self._auto_compact()
Expand All @@ -303,7 +311,7 @@ async def _run_streaming_inner(self, user_input: str | list[ContentBlock]) -> As
async for event in self._provider_manager.stream(
messages=self._get_provider_messages(),
system=self.system_prompt,
tools=tool_definitions if self.tool_registry.list_tools() else None,
tools=tool_definitions if tool_definitions else None,
):
yield event # Forward all provider events to UI

Expand Down
104 changes: 79 additions & 25 deletions src/iac_code/agent/agent_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@
from __future__ import annotations

import asyncio
from contextlib import suppress
from dataclasses import dataclass
from typing import Any

from iac_code.agent.agent_types import filter_tools, get_agent_definition, get_builtin_agents
from iac_code.i18n import _
from iac_code.i18n import _, ngettext
from iac_code.tools.base import Tool, ToolContext, ToolResult


Expand All @@ -21,6 +22,14 @@ class AgentProgress:
summary: str = ""


def _format_base_exception(error: BaseException) -> str:
detail = str(error)
error_type = type(error).__name__
if not detail:
return error_type
return "{error_type}: {detail}".format(error_type=error_type, detail=detail)


async def run_sub_agent(
*,
prompt: str,
Expand All @@ -45,7 +54,7 @@ async def run_sub_agent(

defn = get_agent_definition(agent_type)
if defn is None:
raise ValueError(f"Unknown agent type: {agent_type}")
raise ValueError(_("Unknown agent type: {agent_type}").format(agent_type=agent_type))

sub_registry = filter_tools(parent_tool_registry, defn) if parent_tool_registry else parent_tool_registry
system_prompt = parent_system_prompt or build_system_prompt(cwd=cwd)
Expand All @@ -58,7 +67,7 @@ async def run_sub_agent(
permission_context=permission_context,
)

progress = AgentProgress(summary=f"Running {agent_type} agent")
progress = AgentProgress(summary=_("Running {agent_type} agent").format(agent_type=agent_type))
text_chunks: list[str] = []
# Track tool inputs: tool_use_id -> (name, input)
pending_tool_inputs: dict[str, tuple[str, dict]] = {}
Expand Down Expand Up @@ -128,7 +137,6 @@ def __init__(
self._tool_registry = tool_registry
self._system_prompt = system_prompt
self._permission_context = permission_context
self._event_queue: asyncio.Queue | None = None # Set by ToolExecutor via ToolCallRequest

@property
def name(self) -> str:
Expand All @@ -137,8 +145,13 @@ def name(self) -> str:
@property
def description(self) -> str:
agents = get_builtin_agents()
agent_list = "\n".join(f" - {a.agent_type}: {a.when_to_use}" for a in agents)
return f"Launch a sub-agent to handle complex tasks.\n\nAvailable agent types:\n{agent_list}"
agent_list = "\n".join(
" - {agent_type}: {when_to_use}".format(agent_type=agent.agent_type, when_to_use=agent.when_to_use)
for agent in agents
)
return _("Launch a sub-agent to handle complex tasks.\n\nAvailable agent types:\n{agent_list}").format(
agent_list=agent_list
)

@property
def input_schema(self) -> dict[str, Any]:
Expand All @@ -148,20 +161,20 @@ def input_schema(self) -> dict[str, Any]:
"properties": {
"prompt": {
"type": "string",
"description": "The task for the sub-agent to perform.",
"description": _("The task for the sub-agent to perform."),
},
"description": {
"type": "string",
"description": "Short (3-5 word) description of the task.",
"description": _("Short (3-5 word) description of the task."),
},
"subagent_type": {
"type": "string",
"enum": agent_types,
"description": "The type of specialized agent to use.",
"description": _("The type of specialized agent to use."),
},
"run_in_background": {
"type": "boolean",
"description": "Run agent in background, parent continues.",
"description": _("Run agent in background, parent continues."),
},
},
"required": ["prompt", "description"],
Expand All @@ -171,18 +184,30 @@ async def execute(self, *, tool_input: dict[str, Any], context: ToolContext) ->
prompt = tool_input["prompt"]
agent_type = tool_input.get("subagent_type", tool_input.get("agent_type", "general-purpose"))
run_in_background = tool_input.get("run_in_background", False)
event_queue = context.event_queue

defn = get_agent_definition(agent_type)
if defn is None:
return ToolResult.error(f"Unknown agent type: '{agent_type}'")
return ToolResult.error(_("Unknown agent type: '{agent_type}'").format(agent_type=agent_type))

if run_in_background and self._task_manager:
task_id = self._task_manager.register(
description=tool_input.get("description", "Sub-agent task"),
description=tool_input.get("description", _("Sub-agent task")),
agent_type=agent_type,
)
asyncio.create_task(self._run_background(task_id, prompt, agent_type, context))
return ToolResult.success(f"Background agent launched (task_id: {task_id}, type: {agent_type})")
background_task = asyncio.create_task(self._run_background(task_id, prompt, agent_type, context))
attach_task = getattr(self._task_manager, "attach_task", None)
if callable(attach_task):
attach_task(task_id, background_task)
background_task.add_done_callback(self._consume_background_task_exception)
if event_queue is not None:
await event_queue.put(None)
return ToolResult.success(
_("Background agent launched (task_id: {task_id}, type: {agent_type})").format(
task_id=task_id,
agent_type=agent_type,
)
)

try:
result_text, progress = await run_sub_agent(
Expand All @@ -192,18 +217,25 @@ async def execute(self, *, tool_input: dict[str, Any], context: ToolContext) ->
parent_provider_manager=self._provider_manager,
parent_tool_registry=self._tool_registry,
parent_system_prompt=self._system_prompt,
event_queue=self._event_queue,
event_queue=event_queue,
permission_context=self._permission_context,
)
if self._event_queue:
await self._event_queue.put(None)
if event_queue is not None:
await event_queue.put(None)
return ToolResult.success(
f"{result_text}\n\n[Agent stats: {progress.tool_use_count} tool calls, {progress.token_count} tokens]"
)
except Exception as e:
if self._event_queue:
await self._event_queue.put(None)
return ToolResult.error(f"Sub-agent failed: {e}")
if event_queue is not None:
await event_queue.put(None)
return ToolResult.error(_("Sub-agent failed: {error}").format(error=e))

@staticmethod
def _consume_background_task_exception(task: asyncio.Task) -> None:
if task.cancelled():
return
with suppress(asyncio.CancelledError):
task.exception()

async def _run_background(
self,
Expand Down Expand Up @@ -231,15 +263,37 @@ async def _run_background(
if self._notification_queue:
self._notification_queue.enqueue(
task_id=task_id,
message=f"Agent completed: {progress.tool_use_count} tool calls",
message=ngettext(
"Agent completed: {tool_count} tool call",
"Agent completed: {tool_count} tool calls",
progress.tool_use_count,
).format(tool_count=progress.tool_use_count),
)
except asyncio.CancelledError:
self._task_manager.stop(task_id)
if self._notification_queue:
self._notification_queue.enqueue(
task_id=task_id,
message=_("Agent stopped"),
)
raise
except Exception as e:
self._task_manager.fail(task_id, error=str(e))
error = str(e) or type(e).__name__
self._task_manager.fail(task_id, error=error)
if self._notification_queue:
self._notification_queue.enqueue(
task_id=task_id,
message=_("Agent failed: {error}").format(error=error),
)
except BaseException as e:
error = _format_base_exception(e)
self._task_manager.fail(task_id, error=error)
if self._notification_queue:
self._notification_queue.enqueue(
task_id=task_id,
message=f"Agent failed: {e}",
message=_("Agent failed: {error}").format(error=error),
)
raise

def is_read_only(self, input: dict | None = None) -> bool:
return False
Expand All @@ -252,7 +306,7 @@ def render_tool_use_message(self, input: dict, *, verbose: bool = False) -> str

def render_tool_result_message(self, output: str, *, is_error: bool = False, verbose: bool = False) -> str | None:
if is_error:
return f"Agent error: {output[:200]}"
return _("Agent error: {error}").format(error=output[:200])
if verbose:
return output
# Extract stats from the end of the output
Expand Down Expand Up @@ -282,4 +336,4 @@ def user_facing_name(self, input: dict | None = None) -> str:
def get_activity_description(self, input: dict | None = None) -> str | None:
if input is None:
return None
return f"Running agent: {input.get('description', 'sub-agent')}"
return _("Running agent: {description}").format(description=input.get("description", _("sub-agent")))
Loading
Loading