From dc065efda51d975b3873464078cdb904525b07e0 Mon Sep 17 00:00:00 2001 From: agent-of-mkmeral Date: Wed, 1 Apr 2026 17:41:44 +0000 Subject: [PATCH] feat: add CodeAct plugin for code-based agent interaction Implements the CodeAct paradigm as a vended Plugin that replaces standard JSON tool calling with code-based orchestration. The model generates Python code that calls tools as async functions, and the plugin executes the code locally and feeds results back via AfterInvocationEvent.resume. Key features: - Plugin-based architecture using @hook decorators - BeforeInvocationEvent: injects CodeAct instructions + tool signatures - AfterInvocationEvent: parses code blocks, executes, sets resume - Persistent namespace across turns for state accumulation - Tool wrappers that route through agent.tool.X() (record_direct_tool_call=False) - final_answer() termination mechanism - AST-level code validation (blocks exec/eval/compile/dangerous dunders) - Import restrictions with configurable allowed modules - Max iteration safety limit (default: 10) - Error feedback loop for self-correction No sandbox dependency - executes locally in host process (Phase 1). Sandbox integration can be added later. 64 new tests, all passing. 1828 existing tests, 0 regressions. --- .../vended_plugins/codeact/__init__.py | 23 + .../vended_plugins/codeact/codeact_plugin.py | 616 ++++++++++++++++ .../vended_plugins/codeact/__init__.py | 0 .../codeact/test_codeact_plugin.py | 671 ++++++++++++++++++ 4 files changed, 1310 insertions(+) create mode 100644 src/strands/vended_plugins/codeact/__init__.py create mode 100644 src/strands/vended_plugins/codeact/codeact_plugin.py create mode 100644 tests/strands/vended_plugins/codeact/__init__.py create mode 100644 tests/strands/vended_plugins/codeact/test_codeact_plugin.py diff --git a/src/strands/vended_plugins/codeact/__init__.py b/src/strands/vended_plugins/codeact/__init__.py new file mode 100644 index 000000000..f5876cdf2 --- /dev/null +++ b/src/strands/vended_plugins/codeact/__init__.py @@ -0,0 +1,23 @@ +"""CodeAct plugin for code-based agent interaction. + +CodeAct replaces standard tool calling with code-based orchestration. +The model generates Python code that calls tools as async functions, +and the plugin executes this code and feeds results back to the model. + +Example: + ```python + from strands import Agent + from strands.vended_plugins.codeact import CodeActPlugin + + agent = Agent( + tools=[shell, calculator], + plugins=[CodeActPlugin()], + ) + + result = agent("Calculate squares of 1-10 and sum them") + ``` +""" + +from .codeact_plugin import CodeActPlugin + +__all__ = ["CodeActPlugin"] diff --git a/src/strands/vended_plugins/codeact/codeact_plugin.py b/src/strands/vended_plugins/codeact/codeact_plugin.py new file mode 100644 index 000000000..ddfa6b592 --- /dev/null +++ b/src/strands/vended_plugins/codeact/codeact_plugin.py @@ -0,0 +1,616 @@ +"""CodeAct plugin implementation. + +Implements the CodeAct paradigm where the agent responds with Python code +instead of JSON tool calls. The plugin hooks into the agent lifecycle to: + +1. Modify the system prompt to instruct the model to respond with code +2. Parse code blocks from model responses +3. Execute code locally with tool wrappers injected +4. Feed execution results back via ``AfterInvocationEvent.resume`` + +The loop terminates when the model calls ``final_answer()`` or responds +without a code block. + +References: + - Apple ML Research: CodeAct (https://machinelearning.apple.com/research/codeact) + - HuggingFace smolagents (https://huggingface.co/docs/smolagents/en/index) + - Anthropic advanced tool use (https://www.anthropic.com/engineering/advanced-tool-use) +""" + +from __future__ import annotations + +import ast +import asyncio +import io +import logging +import re +import textwrap +from contextlib import redirect_stdout +from typing import TYPE_CHECKING, Any + +from ...hooks.events import AfterInvocationEvent, BeforeInvocationEvent +from ...plugins import Plugin, hook + +if TYPE_CHECKING: + from ...agent.agent import Agent + +logger = logging.getLogger(__name__) + +_DEFAULT_MAX_ITERATIONS = 10 + +_CODEACT_SYSTEM_PROMPT_PREFIX = """You are a CodeAct agent. Instead of calling tools via JSON, you write Python code to accomplish tasks. + +## How to respond + +When you need to take action, write Python code inside a ```python code block. Your code can: +- Call available tools as async functions (e.g., `result = await shell(command="ls -la")`) +- Use loops, conditionals, and data transformations +- Store intermediate results in variables +- Use `print()` to output information (printed output will be shown back to you) + +When you have the final answer, call `final_answer("your answer here")` in your code. + +If you can answer directly without tools (e.g., for simple questions), respond in plain text without a code block. + +## Available tools + +The following tools are available as async functions in your code: + +""" + +_CODEACT_OBSERVATION_PREFIX = "**Observation:**\n```\n" +_CODEACT_OBSERVATION_SUFFIX = "\n```" + + +class CodeActPlugin(Plugin): + """Plugin that implements the CodeAct paradigm. + + CodeAct replaces standard tool calling with code-based orchestration. + The model generates Python code that calls tools as async functions, + and the plugin executes this code and feeds results back to the model. + + The plugin maintains a persistent namespace across turns, so variables + set in one code block are available in the next. This enables multi-step + reasoning with state accumulation. + + Code executes locally in the host process. Tool calls within the code + are routed through the agent's tool caller (``agent.tool.X()``). + + Args: + max_iterations: Maximum number of code execution rounds before + forcing termination. Defaults to 10. + allowed_modules: Optional set of module names that can be imported + in generated code. If None, a default safe set is used. + + Example: + ```python + from strands import Agent + from strands.vended_plugins.codeact import CodeActPlugin + + agent = Agent( + tools=[shell, calculator, http_request], + plugins=[CodeActPlugin(max_iterations=15)], + ) + + result = agent("Fetch HN front page, extract titles, save to file") + ``` + """ + + name = "codeact" + + _DEFAULT_ALLOWED_MODULES = frozenset({ + "math", + "json", + "re", + "collections", + "itertools", + "functools", + "datetime", + "os.path", + "pathlib", + "textwrap", + "statistics", + "string", + "random", + "hashlib", + "base64", + "urllib.parse", + "csv", + "io", + }) + + def __init__( + self, + max_iterations: int = _DEFAULT_MAX_ITERATIONS, + allowed_modules: set[str] | None = None, + ) -> None: + """Initialize the CodeAct plugin. + + Args: + max_iterations: Maximum number of code execution rounds before + forcing termination. + allowed_modules: Optional set of module names that can be imported + in generated code. If None, a default safe set is used. + """ + self._max_iterations = max_iterations + self._allowed_modules = ( + frozenset(allowed_modules) if allowed_modules is not None else self._DEFAULT_ALLOWED_MODULES + ) + super().__init__() + + def init_agent(self, agent: Agent) -> None: + """Initialize the plugin with an agent instance. + + Args: + agent: The agent instance to extend with CodeAct support. + """ + logger.debug( + "max_iterations=<%d>, allowed_modules=<%d> | codeact plugin initialized", + self._max_iterations, + len(self._allowed_modules), + ) + + @hook + def _on_before_invocation(self, event: BeforeInvocationEvent) -> None: + """Inject CodeAct instructions and tool signatures into the system prompt. + + Modifies the system prompt to instruct the model to respond with Python + code blocks instead of JSON tool calls. Injects function signatures for + all available tools so the model knows how to call them. + + Also initializes per-invocation state (namespace, iteration counter). + + Args: + event: The before-invocation event containing the agent reference. + """ + agent = event.agent + + # Initialize CodeAct state in invocation_state + event.invocation_state.setdefault("codeact_namespace", self._build_initial_namespace(agent)) + event.invocation_state.setdefault("codeact_iteration", 0) + + # Build tool signatures + tool_signatures = self._build_tool_signatures(agent) + + # Inject CodeAct instructions into system prompt + current_prompt = agent.system_prompt or "" + + # Remove previously injected CodeAct block if present + state_data = agent.state.get("codeact") + last_injected = state_data.get("last_injected_prompt") if isinstance(state_data, dict) else None + if last_injected and last_injected in current_prompt: + current_prompt = current_prompt.replace(last_injected, "") + + codeact_block = _CODEACT_SYSTEM_PROMPT_PREFIX + tool_signatures + injection = f"\n\n{codeact_block}" + new_prompt = f"{current_prompt}{injection}" if current_prompt else codeact_block + + # Track what we injected for cleanup + new_injected = injection if current_prompt else codeact_block + self._set_state_field(agent, "last_injected_prompt", new_injected) + + agent.system_prompt = new_prompt + + @hook + def _on_after_invocation(self, event: AfterInvocationEvent) -> None: + """Parse code from model response, execute it, and resume with results. + + Extracts Python code blocks from the model's response, executes them + in a persistent namespace with tool wrappers available, and sets + ``event.resume`` with the execution output so the agent loops back + for another turn. + + The loop terminates when: + - The model calls ``final_answer()`` + - The model responds without a code block + - Maximum iterations are reached + + Args: + event: The after-invocation event containing the agent result. + """ + if event.result is None: + return + + # Get iteration count + iteration = event.invocation_state.get("codeact_iteration", 0) + if iteration >= self._max_iterations: + logger.warning( + "iteration=<%d>, max=<%d> | codeact max iterations reached, stopping", + iteration, + self._max_iterations, + ) + return + + # Get the model's response text + response_text = self._extract_response_text(event.result) + if not response_text: + return + + # Parse code block from response + code = self._parse_code_block(response_text) + if not code: + # No code block — model is responding in plain text, we're done + logger.debug("iteration=<%d> | no code block found, codeact loop complete", iteration) + return + + # Get or create namespace + namespace = event.invocation_state.get("codeact_namespace", {}) + + # Check if code has already set final_answer before execution + # (in case of re-entry) + if namespace.get("__final_answer__") is not None: + logger.debug("iteration=<%d> | final_answer already set, stopping", iteration) + return + + # Validate the code before execution + validation_error = self._validate_code(code) + if validation_error: + output = f"Code validation error: {validation_error}" + event.invocation_state["codeact_iteration"] = iteration + 1 + event.resume = f"{_CODEACT_OBSERVATION_PREFIX}{output}{_CODEACT_OBSERVATION_SUFFIX}" + return + + # Execute the code + logger.debug("iteration=<%d>, code_length=<%d> | executing codeact code block", iteration, len(code)) + output = self._execute_code(code, namespace) + + # Check if final_answer was called + if namespace.get("__final_answer__") is not None: + final = namespace["__final_answer__"] + logger.debug("iteration=<%d> | final_answer called", iteration) + # Set resume with final answer so the model can produce a clean response + event.resume = ( + f"{_CODEACT_OBSERVATION_PREFIX}{output}\n\n" + f"final_answer was called with: {final}{_CODEACT_OBSERVATION_SUFFIX}\n\n" + f"Provide the final answer to the user based on the above result. " + f"Do NOT write any more code." + ) + # Clear namespace to prevent re-triggering + namespace["__final_answer_delivered__"] = True + return + + # Check if this is a post-final-answer turn (model should just respond) + if namespace.get("__final_answer_delivered__"): + return + + # Resume with execution output for next iteration + event.invocation_state["codeact_iteration"] = iteration + 1 + event.resume = f"{_CODEACT_OBSERVATION_PREFIX}{output}{_CODEACT_OBSERVATION_SUFFIX}" + + def _build_initial_namespace(self, agent: Agent) -> dict[str, Any]: + """Build the initial execution namespace with tool wrappers and builtins. + + Creates async wrapper functions for each agent tool and injects them + into a namespace dict. Also adds ``final_answer()`` and a restricted + ``__import__`` that only allows safe modules. + + Args: + agent: The agent whose tools to wrap. + + Returns: + Namespace dict ready for code execution. + """ + namespace: dict[str, Any] = {} + + # Add restricted builtins + safe_builtins = { + k: v + for k, v in __builtins__.items() # type: ignore[union-attr] + if k + not in { + "exec", + "eval", + "compile", + "__import__", + "globals", + "locals", + "breakpoint", + "exit", + "quit", + } + } + + # Add controlled import + def _safe_import(name: str, *args: Any, **kwargs: Any) -> Any: + """Import only allowed modules.""" + # Check if the module or its parent is allowed + parts = name.split(".") + for i in range(len(parts), 0, -1): + if ".".join(parts[:i]) in self._allowed_modules: + return __builtins__["__import__"](name, *args, **kwargs) # type: ignore[index] + raise ImportError( + f"Import of '{name}' is not allowed. " + f"Allowed modules: {', '.join(sorted(self._allowed_modules))}" + ) + + safe_builtins["__import__"] = _safe_import + namespace["__builtins__"] = safe_builtins + + # Add asyncio for await support + namespace["asyncio"] = asyncio + + # Add final_answer function + def final_answer(result: Any) -> None: + """Call this when you have the final answer.""" + namespace["__final_answer__"] = result + + namespace["final_answer"] = final_answer + + # Add tool wrappers + for tool_name in agent.tool_registry.registry: + namespace[tool_name.replace("-", "_")] = self._make_tool_wrapper(agent, tool_name) + + return namespace + + def _make_tool_wrapper(self, agent: Agent, tool_name: str) -> Any: + """Create an async wrapper function for an agent tool. + + The wrapper calls ``agent.tool.X()`` when invoked, making the tool + available as a regular async function in the code execution namespace. + + Args: + agent: The agent instance. + tool_name: Name of the tool to wrap. + + Returns: + An async function that calls the tool. + """ + # Normalize for Python identifier + python_name = tool_name.replace("-", "_") + + async def tool_wrapper(**kwargs: Any) -> Any: + """Async wrapper that calls agent.tool.{tool_name}().""" + logger.debug("tool_name=<%s> | codeact calling tool", tool_name) + try: + caller = getattr(agent.tool, python_name) + result = caller(record_direct_tool_call=False, **kwargs) + # Extract text content from ToolResult + if isinstance(result, dict) and "content" in result: + content = result["content"] + if isinstance(content, list): + texts = [] + for block in content: + if isinstance(block, dict) and "text" in block: + texts.append(block["text"]) + return "\n".join(texts) if texts else str(result) + return result + except Exception as e: + logger.warning("tool_name=<%s>, error=<%s> | codeact tool call failed", tool_name, e) + raise + + tool_wrapper.__name__ = python_name + tool_wrapper.__qualname__ = python_name + + # Add docstring from tool spec for model context + tool_config = agent.tool_registry.get_all_tools_config() + spec = tool_config.get(tool_name, {}) + description = spec.get("description", f"Call the {tool_name} tool.") + tool_wrapper.__doc__ = description + + return tool_wrapper + + def _build_tool_signatures(self, agent: Agent) -> str: + """Generate Python function signatures for all available tools. + + Reads tool specs from the agent's tool registry and formats them + as function signatures with type hints and docstrings. + + Args: + agent: The agent whose tools to describe. + + Returns: + Formatted string of tool function signatures. + """ + tool_config = agent.tool_registry.get_all_tools_config() + signatures = [] + + for tool_name, spec in tool_config.items(): + python_name = tool_name.replace("-", "_") + description = spec.get("description", "") + input_schema = spec.get("inputSchema", {}).get("json", {}) + properties = input_schema.get("properties", {}) + required = set(input_schema.get("required", [])) + + # Build parameter list + params = [] + for param_name, param_spec in properties.items(): + param_type = self._json_type_to_python(param_spec.get("type", "str")) + param_desc = param_spec.get("description", "") + + if param_name in required: + params.append(f"{param_name}: {param_type}") + else: + default = param_spec.get("default", None) + params.append(f"{param_name}: {param_type} = {repr(default)}") + + params_str = ", ".join(params) + sig = f"async def {python_name}({params_str}) -> str" + + # Format with docstring + entry = f'{sig}:\n """{description}"""' + signatures.append(entry) + + return "\n\n".join(signatures) + + @staticmethod + def _json_type_to_python(json_type: str) -> str: + """Convert JSON schema type to Python type hint string. + + Args: + json_type: The JSON schema type string. + + Returns: + Python type hint string. + """ + type_map = { + "string": "str", + "integer": "int", + "number": "float", + "boolean": "bool", + "array": "list", + "object": "dict", + } + return type_map.get(json_type, "str") + + @staticmethod + def _extract_response_text(result: Any) -> str: + """Extract text content from an AgentResult. + + Args: + result: The AgentResult from the model. + + Returns: + The text content of the response, or empty string. + """ + try: + message = result.message + if message and "content" in message: + texts = [] + for block in message["content"]: + if isinstance(block, dict) and "text" in block: + texts.append(block["text"]) + return "\n".join(texts) + except (AttributeError, KeyError, TypeError): + pass + + # Fallback: try string conversion + text = str(result) + return text if text and text != "None" else "" + + @staticmethod + def _parse_code_block(text: str) -> str | None: + """Extract Python code from a markdown code block. + + Supports both ```python and ``` (bare) fenced code blocks. + Returns the first code block found, or None if no code block exists. + + Args: + text: The model's response text. + + Returns: + The extracted code, or None if no code block was found. + """ + # Match ```python ... ``` or ```py ... ``` + patterns = [ + r"```python\s*\n(.*?)```", + r"```py\s*\n(.*?)```", + r"```\s*\n(.*?)```", + ] + + for pattern in patterns: + match = re.search(pattern, text, re.DOTALL) + if match: + code = match.group(1).strip() + if code: + return code + + return None + + def _validate_code(self, code: str) -> str | None: + """Validate code before execution. + + Performs AST-level validation to catch syntax errors and + potentially dangerous constructs. + + Args: + code: The code to validate. + + Returns: + Error message string if validation fails, None if valid. + """ + try: + tree = ast.parse(code) + except SyntaxError as e: + return f"SyntaxError: {e}" + + # Check for disallowed constructs + for node in ast.walk(tree): + # Block eval/exec calls + if isinstance(node, ast.Call): + func = node.func + if isinstance(func, ast.Name) and func.id in ("exec", "eval", "compile"): + return f"Use of '{func.id}()' is not allowed" + + # Block __dunder__ attribute access (except __init__, __name__, etc.) + if isinstance(node, ast.Attribute): + if ( + node.attr.startswith("__") + and node.attr.endswith("__") + and node.attr not in ("__init__", "__name__", "__doc__", "__class__", "__len__", "__str__") + ): + return f"Access to '{node.attr}' is not allowed" + + return None + + def _execute_code(self, code: str, namespace: dict[str, Any]) -> str: + """Execute Python code in the given namespace, capturing output. + + Wraps the code in an async function to support ``await`` calls, + then executes it and captures stdout. Local variables from the code + are copied back to the namespace for persistence across turns. + + Args: + code: The Python code to execute. + namespace: The execution namespace (persistent across turns). + + Returns: + Captured stdout output, or error message if execution failed. + """ + # Wrap code in async function to support top-level await. + # The __ns__ parameter receives the namespace dict so we can + # copy local variables back after execution (otherwise they'd + # be lost when the function returns). + indented_code = textwrap.indent(code, " ") + ns_update = ' __ns__.update({k: v for k, v in locals().items() if not k.startswith("_")})' + wrapped = f"async def __codeact_main__(__ns__):\n{indented_code}\n{ns_update}\n" + + stdout_capture = io.StringIO() + + try: + # Compile and exec the wrapper function definition + compiled = compile(wrapped, "", "exec") + exec(compiled, namespace) # noqa: S102 + + # Run the async main function, capturing stdout + with redirect_stdout(stdout_capture): + loop = asyncio.new_event_loop() + try: + loop.run_until_complete(namespace["__codeact_main__"](namespace)) + finally: + loop.close() + + output = stdout_capture.getvalue() + if not output: + output = "(No output)" + + logger.debug("output_length=<%d> | codeact execution complete", len(output)) + return output + + except Exception as e: + captured = stdout_capture.getvalue() + error_output = f"Error: {type(e).__name__}: {e}" + if captured: + error_output = f"{captured}\n{error_output}" + + logger.debug("error=<%s> | codeact execution failed", e) + return error_output + + finally: + # Clean up the async function from namespace + namespace.pop("__codeact_main__", None) + + def _set_state_field(self, agent: Agent, key: str, value: Any) -> None: + """Set a single field in the plugin's agent state dict. + + Args: + agent: The agent whose state to update. + key: The state field key. + value: The value to set. + """ + state_data = agent.state.get("codeact") + if state_data is not None and not isinstance(state_data, dict): + state_data = {} + if state_data is None: + state_data = {} + state_data[key] = value + agent.state.set("codeact", state_data) diff --git a/tests/strands/vended_plugins/codeact/__init__.py b/tests/strands/vended_plugins/codeact/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/strands/vended_plugins/codeact/test_codeact_plugin.py b/tests/strands/vended_plugins/codeact/test_codeact_plugin.py new file mode 100644 index 000000000..8086bb623 --- /dev/null +++ b/tests/strands/vended_plugins/codeact/test_codeact_plugin.py @@ -0,0 +1,671 @@ +"""Tests for CodeAct plugin. + +Tests cover: +- Plugin initialization and lifecycle +- System prompt injection +- Code block parsing +- Code validation (AST checks) +- Code execution with namespace persistence +- Tool wrapper generation and invocation +- final_answer() termination +- Max iterations safety +- Error handling and self-correction loop +- Import restrictions +""" + +import asyncio +from unittest.mock import MagicMock, patch + +import pytest + +from strands.agent.agent_result import AgentResult +from strands.hooks.events import AfterInvocationEvent, BeforeInvocationEvent +from strands.vended_plugins.codeact.codeact_plugin import ( + CodeActPlugin, + _CODEACT_OBSERVATION_PREFIX, + _CODEACT_OBSERVATION_SUFFIX, + _CODEACT_SYSTEM_PROMPT_PREFIX, +) + + +@pytest.fixture +def plugin(): + """Create a CodeAct plugin with default settings.""" + return CodeActPlugin() + + +@pytest.fixture +def plugin_custom(): + """Create a CodeAct plugin with custom settings.""" + return CodeActPlugin(max_iterations=3, allowed_modules={"math", "json"}) + + +@pytest.fixture +def mock_agent(): + """Create a mock agent with tool registry.""" + agent = MagicMock() + agent.system_prompt = "You are a helpful assistant." + agent.state = MagicMock() + agent.state.get.return_value = None + + # Mock tool registry + mock_tool = MagicMock() + mock_tool.tool_spec = { + "name": "calculator", + "description": "Perform calculations", + "inputSchema": { + "json": { + "type": "object", + "properties": { + "expression": { + "type": "string", + "description": "Math expression to evaluate", + } + }, + "required": ["expression"], + } + }, + } + + mock_shell_tool = MagicMock() + mock_shell_tool.tool_spec = { + "name": "shell", + "description": "Execute shell commands", + "inputSchema": { + "json": { + "type": "object", + "properties": { + "command": { + "type": "string", + "description": "Shell command to execute", + }, + "timeout": { + "type": "integer", + "description": "Timeout in seconds", + }, + }, + "required": ["command"], + } + }, + } + + agent.tool_registry = MagicMock() + agent.tool_registry.registry = {"calculator": mock_tool, "shell": mock_shell_tool} + agent.tool_registry.get_all_tools_config.return_value = { + "calculator": mock_tool.tool_spec, + "shell": mock_shell_tool.tool_spec, + } + + # Mock tool caller + agent.tool = MagicMock() + + return agent + + +@pytest.fixture +def mock_result(): + """Create a mock AgentResult with text content.""" + + def _make_result(text): + result = MagicMock() + result.message = {"content": [{"text": text}]} + result.__str__ = lambda self: text + return result + + return _make_result + + +class TestPluginInitialization: + """Test plugin setup and configuration.""" + + def test_plugin_name(self, plugin): + """Plugin should have correct name.""" + assert plugin.name == "codeact" + + def test_default_max_iterations(self, plugin): + """Default max iterations should be 10.""" + assert plugin._max_iterations == 10 + + def test_custom_max_iterations(self, plugin_custom): + """Custom max iterations should be respected.""" + assert plugin_custom._max_iterations == 3 + + def test_default_allowed_modules(self, plugin): + """Default allowed modules should include safe standard library modules.""" + assert "math" in plugin._allowed_modules + assert "json" in plugin._allowed_modules + assert "re" in plugin._allowed_modules + assert "os.path" in plugin._allowed_modules + + def test_custom_allowed_modules(self, plugin_custom): + """Custom allowed modules should override defaults.""" + assert plugin_custom._allowed_modules == frozenset({"math", "json"}) + + def test_hooks_registered(self, plugin): + """Plugin should have hooks registered.""" + assert len(plugin.hooks) == 2 # before_invocation + after_invocation + + +class TestSystemPromptInjection: + """Test BeforeInvocationEvent hook.""" + + def test_injects_codeact_instructions(self, plugin, mock_agent): + """Should inject CodeAct instructions into system prompt.""" + event = BeforeInvocationEvent(agent=mock_agent, invocation_state={}) + plugin._on_before_invocation(event) + + new_prompt = mock_agent.system_prompt + assert "CodeAct agent" in new_prompt + assert "```python" in new_prompt + + def test_injects_tool_signatures(self, plugin, mock_agent): + """Should inject tool function signatures.""" + event = BeforeInvocationEvent(agent=mock_agent, invocation_state={}) + plugin._on_before_invocation(event) + + new_prompt = mock_agent.system_prompt + assert "calculator" in new_prompt + assert "expression: str" in new_prompt + assert "shell" in new_prompt + assert "command: str" in new_prompt + + def test_initializes_invocation_state(self, plugin, mock_agent): + """Should initialize codeact_namespace and codeact_iteration.""" + invocation_state = {} + event = BeforeInvocationEvent(agent=mock_agent, invocation_state=invocation_state) + plugin._on_before_invocation(event) + + assert "codeact_namespace" in invocation_state + assert "codeact_iteration" in invocation_state + assert invocation_state["codeact_iteration"] == 0 + + def test_namespace_has_tools(self, plugin, mock_agent): + """Namespace should contain tool wrapper functions.""" + invocation_state = {} + event = BeforeInvocationEvent(agent=mock_agent, invocation_state=invocation_state) + plugin._on_before_invocation(event) + + namespace = invocation_state["codeact_namespace"] + assert "calculator" in namespace + assert "shell" in namespace + assert callable(namespace["calculator"]) + assert callable(namespace["shell"]) + + def test_namespace_has_final_answer(self, plugin, mock_agent): + """Namespace should contain final_answer function.""" + invocation_state = {} + event = BeforeInvocationEvent(agent=mock_agent, invocation_state=invocation_state) + plugin._on_before_invocation(event) + + namespace = invocation_state["codeact_namespace"] + assert "final_answer" in namespace + assert callable(namespace["final_answer"]) + + def test_preserves_existing_system_prompt(self, plugin, mock_agent): + """Should preserve existing system prompt content.""" + mock_agent.system_prompt = "Original prompt." + event = BeforeInvocationEvent(agent=mock_agent, invocation_state={}) + plugin._on_before_invocation(event) + + new_prompt = mock_agent.system_prompt + assert "Original prompt." in new_prompt + + def test_handles_none_system_prompt(self, plugin, mock_agent): + """Should handle None system prompt gracefully.""" + mock_agent.system_prompt = None + event = BeforeInvocationEvent(agent=mock_agent, invocation_state={}) + plugin._on_before_invocation(event) + + assert mock_agent.system_prompt is not None + assert "CodeAct" in mock_agent.system_prompt + + +class TestCodeBlockParsing: + """Test code block extraction from model responses.""" + + def test_parse_python_code_block(self, plugin): + """Should parse ```python code blocks.""" + text = "Here's the code:\n```python\nprint('hello')\n```\nDone." + code = plugin._parse_code_block(text) + assert code == "print('hello')" + + def test_parse_py_code_block(self, plugin): + """Should parse ```py code blocks.""" + text = "```py\nx = 1 + 2\nprint(x)\n```" + code = plugin._parse_code_block(text) + assert code == "x = 1 + 2\nprint(x)" + + def test_parse_bare_code_block(self, plugin): + """Should parse bare ``` code blocks.""" + text = "```\nresult = 42\nprint(result)\n```" + code = plugin._parse_code_block(text) + assert code == "result = 42\nprint(result)" + + def test_no_code_block(self, plugin): + """Should return None when no code block present.""" + text = "The answer is 42." + code = plugin._parse_code_block(text) + assert code is None + + def test_empty_code_block(self, plugin): + """Should return None for empty code blocks.""" + text = "```python\n\n```" + code = plugin._parse_code_block(text) + assert code is None + + def test_multiline_code(self, plugin): + """Should handle multi-line code blocks.""" + text = """Here's the solution: +```python +total = 0 +for i in range(10): + total += i +print(f"Sum: {total}") +``` +""" + code = plugin._parse_code_block(text) + assert "total = 0" in code + assert "for i in range(10):" in code + assert "print" in code + + def test_first_code_block_wins(self, plugin): + """Should return the first code block when multiple exist.""" + text = "```python\nfirst_block()\n```\n\n```python\nsecond_block()\n```" + code = plugin._parse_code_block(text) + assert code == "first_block()" + + def test_prefers_python_over_bare(self, plugin): + """Should prefer ```python over bare ``` blocks.""" + text = "```\nbare_block()\n```\n\n```python\npython_block()\n```" + code = plugin._parse_code_block(text) + # The python pattern is checked first + assert code == "python_block()" + + +class TestCodeValidation: + """Test AST-level code validation.""" + + def test_valid_code(self, plugin): + """Should accept valid Python code.""" + assert plugin._validate_code("x = 1 + 2\nprint(x)") is None + + def test_syntax_error(self, plugin): + """Should catch syntax errors.""" + result = plugin._validate_code("def foo(:\n pass") + assert result is not None + assert "SyntaxError" in result + + def test_blocks_exec(self, plugin): + """Should block exec() calls.""" + result = plugin._validate_code("exec('print(1)')") + assert result is not None + assert "exec" in result + + def test_blocks_eval(self, plugin): + """Should block eval() calls.""" + result = plugin._validate_code("eval('1+1')") + assert result is not None + assert "eval" in result + + def test_blocks_compile(self, plugin): + """Should block compile() calls.""" + result = plugin._validate_code("compile('x=1', '', 'exec')") + assert result is not None + assert "compile" in result + + def test_blocks_dunder_access(self, plugin): + """Should block dangerous __dunder__ access.""" + result = plugin._validate_code("obj.__subclasses__()") + assert result is not None + assert "__subclasses__" in result + + def test_allows_safe_dunders(self, plugin): + """Should allow safe __dunder__ attributes.""" + assert plugin._validate_code("x.__name__") is None + assert plugin._validate_code("x.__doc__") is None + assert plugin._validate_code("x.__class__") is None + assert plugin._validate_code("x.__len__()") is None + + def test_allows_loops_and_conditionals(self, plugin): + """Should allow normal control flow.""" + code = """ +for i in range(10): + if i % 2 == 0: + print(i) +""" + assert plugin._validate_code(code) is None + + def test_allows_async_await(self, plugin): + """Should allow async/await constructs.""" + code = """ +async def foo(): + result = await bar() + return result +""" + assert plugin._validate_code(code) is None + + +class TestCodeExecution: + """Test code execution with namespace persistence.""" + + def test_simple_execution(self, plugin): + """Should execute simple code and capture output.""" + namespace = {"__builtins__": __builtins__} + output = plugin._execute_code("print('hello world')", namespace) + assert "hello world" in output + + def test_namespace_persistence(self, plugin): + """Variables should persist across execution calls.""" + namespace = {"__builtins__": __builtins__} + plugin._execute_code("x = 42", namespace) + assert namespace.get("x") == 42 + + output = plugin._execute_code("print(x * 2)", namespace) + assert "84" in output + + def test_error_handling(self, plugin): + """Should capture exceptions as output.""" + namespace = {"__builtins__": __builtins__} + output = plugin._execute_code("1/0", namespace) + assert "ZeroDivisionError" in output + + def test_no_output(self, plugin): + """Should return '(No output)' when nothing is printed.""" + namespace = {"__builtins__": __builtins__} + output = plugin._execute_code("x = 1 + 1", namespace) + assert output == "(No output)" + + def test_partial_output_on_error(self, plugin): + """Should include partial output before error.""" + namespace = {"__builtins__": __builtins__} + output = plugin._execute_code("print('before')\n1/0", namespace) + assert "before" in output + assert "ZeroDivisionError" in output + + def test_async_execution(self, plugin): + """Should support await expressions.""" + namespace = {"__builtins__": __builtins__, "asyncio": asyncio} + output = plugin._execute_code("result = await asyncio.sleep(0)\nprint('async done')", namespace) + assert "async done" in output + + def test_final_answer_in_code(self, plugin, mock_agent): + """final_answer() should set namespace flag.""" + invocation_state = {} + event = BeforeInvocationEvent(agent=mock_agent, invocation_state=invocation_state) + plugin._on_before_invocation(event) + + namespace = invocation_state["codeact_namespace"] + plugin._execute_code('final_answer("the answer is 42")', namespace) + assert namespace["__final_answer__"] == "the answer is 42" + + +class TestAfterInvocationHook: + """Test the after-invocation hook (code execution + resume loop).""" + + def test_no_code_block_stops_loop(self, plugin, mock_agent, mock_result): + """Should not resume when model responds without code.""" + event = AfterInvocationEvent( + agent=mock_agent, + invocation_state={"codeact_iteration": 0, "codeact_namespace": {}}, + result=mock_result("The answer is 42."), + ) + plugin._on_after_invocation(event) + assert event.resume is None + + def test_code_execution_sets_resume(self, plugin, mock_agent, mock_result): + """Should resume with execution output when code block found.""" + namespace = {"__builtins__": __builtins__, "asyncio": asyncio} + namespace["final_answer"] = lambda r: namespace.__setitem__("__final_answer__", r) + + event = AfterInvocationEvent( + agent=mock_agent, + invocation_state={"codeact_iteration": 0, "codeact_namespace": namespace}, + result=mock_result("```python\nprint('hello from code')\n```"), + ) + plugin._on_after_invocation(event) + + assert event.resume is not None + assert "hello from code" in event.resume + + def test_final_answer_stops_loop(self, plugin, mock_agent, mock_result): + """Should include final_answer content in resume.""" + namespace = {"__builtins__": __builtins__, "asyncio": asyncio} + namespace["final_answer"] = lambda r: namespace.__setitem__("__final_answer__", r) + + event = AfterInvocationEvent( + agent=mock_agent, + invocation_state={"codeact_iteration": 0, "codeact_namespace": namespace}, + result=mock_result('```python\nfinal_answer("result is 42")\n```'), + ) + plugin._on_after_invocation(event) + + assert event.resume is not None + assert "final_answer was called" in event.resume + assert "result is 42" in event.resume + assert "Do NOT write any more code" in event.resume + + def test_max_iterations_stops_loop(self, plugin, mock_agent, mock_result): + """Should stop when max iterations reached.""" + event = AfterInvocationEvent( + agent=mock_agent, + invocation_state={"codeact_iteration": 10, "codeact_namespace": {}}, + result=mock_result("```python\nprint('more code')\n```"), + ) + plugin._on_after_invocation(event) + assert event.resume is None + + def test_iteration_counter_increments(self, plugin, mock_agent, mock_result): + """Should increment iteration counter on each execution.""" + namespace = {"__builtins__": __builtins__, "asyncio": asyncio} + namespace["final_answer"] = lambda r: namespace.__setitem__("__final_answer__", r) + invocation_state = {"codeact_iteration": 0, "codeact_namespace": namespace} + + event = AfterInvocationEvent( + agent=mock_agent, + invocation_state=invocation_state, + result=mock_result("```python\nprint('iteration 1')\n```"), + ) + plugin._on_after_invocation(event) + assert invocation_state["codeact_iteration"] == 1 + + def test_validation_error_feeds_back(self, plugin, mock_agent, mock_result): + """Should feed validation errors back for self-correction.""" + namespace = {"__builtins__": __builtins__, "asyncio": asyncio} + namespace["final_answer"] = lambda r: namespace.__setitem__("__final_answer__", r) + + event = AfterInvocationEvent( + agent=mock_agent, + invocation_state={"codeact_iteration": 0, "codeact_namespace": namespace}, + result=mock_result("```python\nexec('malicious')\n```"), + ) + plugin._on_after_invocation(event) + + assert event.resume is not None + assert "validation error" in event.resume.lower() + + def test_none_result_is_noop(self, plugin, mock_agent): + """Should do nothing when result is None.""" + event = AfterInvocationEvent( + agent=mock_agent, + invocation_state={"codeact_iteration": 0}, + result=None, + ) + plugin._on_after_invocation(event) + assert event.resume is None + + def test_post_final_answer_is_noop(self, plugin, mock_agent, mock_result): + """Should not execute code after final_answer was delivered.""" + namespace = {"__builtins__": __builtins__, "__final_answer_delivered__": True} + + event = AfterInvocationEvent( + agent=mock_agent, + invocation_state={"codeact_iteration": 1, "codeact_namespace": namespace}, + result=mock_result("```python\nprint('should not run')\n```"), + ) + plugin._on_after_invocation(event) + assert event.resume is None + + +class TestToolSignatureGeneration: + """Test tool signature formatting for system prompt.""" + + def test_generates_signatures(self, plugin, mock_agent): + """Should generate function signatures from tool specs.""" + signatures = plugin._build_tool_signatures(mock_agent) + assert "async def calculator(expression: str) -> str" in signatures + assert "async def shell(command: str" in signatures + + def test_optional_parameters(self, plugin, mock_agent): + """Should format optional params with defaults.""" + signatures = plugin._build_tool_signatures(mock_agent) + # timeout is optional in shell tool + assert "timeout" in signatures + + def test_includes_descriptions(self, plugin, mock_agent): + """Should include tool descriptions as docstrings.""" + signatures = plugin._build_tool_signatures(mock_agent) + assert "Perform calculations" in signatures + assert "Execute shell commands" in signatures + + +class TestImportRestrictions: + """Test import safety in code execution.""" + + def test_allowed_import(self, plugin, mock_agent): + """Should allow imports of safe modules.""" + invocation_state = {} + event = BeforeInvocationEvent(agent=mock_agent, invocation_state=invocation_state) + plugin._on_before_invocation(event) + + namespace = invocation_state["codeact_namespace"] + output = plugin._execute_code("import math\nprint(math.pi)", namespace) + assert "3.14" in output + + def test_blocked_import(self, plugin, mock_agent): + """Should block imports of disallowed modules.""" + invocation_state = {} + event = BeforeInvocationEvent(agent=mock_agent, invocation_state=invocation_state) + plugin._on_before_invocation(event) + + namespace = invocation_state["codeact_namespace"] + output = plugin._execute_code("import subprocess", namespace) + assert "ImportError" in output + assert "not allowed" in output + + def test_blocked_os_import(self, plugin, mock_agent): + """Should block full os module (only os.path is allowed).""" + invocation_state = {} + event = BeforeInvocationEvent(agent=mock_agent, invocation_state=invocation_state) + plugin._on_before_invocation(event) + + namespace = invocation_state["codeact_namespace"] + output = plugin._execute_code("import os\nos.system('echo pwned')", namespace) + assert "ImportError" in output + + def test_allowed_submodule(self, plugin, mock_agent): + """Should allow importing allowed submodules.""" + invocation_state = {} + event = BeforeInvocationEvent(agent=mock_agent, invocation_state=invocation_state) + plugin._on_before_invocation(event) + + namespace = invocation_state["codeact_namespace"] + output = plugin._execute_code("from os.path import join\nprint(join('a', 'b'))", namespace) + # os.path is allowed + assert "a" in output + + +class TestJsonTypeToPython: + """Test JSON schema type to Python type conversion.""" + + def test_string(self): + assert CodeActPlugin._json_type_to_python("string") == "str" + + def test_integer(self): + assert CodeActPlugin._json_type_to_python("integer") == "int" + + def test_number(self): + assert CodeActPlugin._json_type_to_python("number") == "float" + + def test_boolean(self): + assert CodeActPlugin._json_type_to_python("boolean") == "bool" + + def test_array(self): + assert CodeActPlugin._json_type_to_python("array") == "list" + + def test_object(self): + assert CodeActPlugin._json_type_to_python("object") == "dict" + + def test_unknown_defaults_to_str(self): + assert CodeActPlugin._json_type_to_python("unknown") == "str" + + +class TestEdgeCases: + """Test edge cases and error scenarios.""" + + def test_no_tools_registered(self, plugin): + """Should handle agent with no tools.""" + agent = MagicMock() + agent.system_prompt = "" + agent.state = MagicMock() + agent.state.get.return_value = None + agent.tool_registry = MagicMock() + agent.tool_registry.registry = {} + agent.tool_registry.get_all_tools_config.return_value = {} + + event = BeforeInvocationEvent(agent=agent, invocation_state={}) + plugin._on_before_invocation(event) + + # Should still work, just no tool signatures + assert "CodeAct" in agent.system_prompt + + def test_hyphenated_tool_names(self, plugin): + """Should normalize hyphenated tool names to underscores.""" + agent = MagicMock() + agent.state = MagicMock() + agent.state.get.return_value = None + + mock_tool = MagicMock() + mock_tool.tool_spec = { + "name": "my-tool", + "description": "A tool", + "inputSchema": {"json": {"type": "object", "properties": {}, "required": []}}, + } + agent.tool_registry = MagicMock() + agent.tool_registry.registry = {"my-tool": mock_tool} + agent.tool_registry.get_all_tools_config.return_value = {"my-tool": mock_tool.tool_spec} + agent.tool = MagicMock() + + namespace = plugin._build_initial_namespace(agent) + # Should be accessible as my_tool (underscore) + assert "my_tool" in namespace + + def test_tool_wrapper_exception_handling(self, plugin, mock_agent): + """Tool wrapper should propagate exceptions.""" + mock_agent.tool.calculator.side_effect = RuntimeError("Tool failed") + + wrapper = plugin._make_tool_wrapper(mock_agent, "calculator") + + with pytest.raises(RuntimeError, match="Tool failed"): + asyncio.get_event_loop().run_until_complete(wrapper(expression="1+1")) + + def test_extract_response_text_with_multiple_blocks(self, plugin): + """Should concatenate text from multiple content blocks.""" + result = MagicMock() + result.message = { + "content": [ + {"text": "First part."}, + {"toolUse": {"name": "foo"}}, # Non-text block + {"text": "Second part."}, + ] + } + text = plugin._extract_response_text(result) + assert "First part." in text + assert "Second part." in text + + def test_extract_response_text_no_message(self, plugin): + """Should handle result with no message gracefully.""" + result = MagicMock() + result.message = None + result.__str__ = lambda self: "fallback text" + text = plugin._extract_response_text(result) + assert text == "fallback text"