diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 0000000..dc785f7 --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,26 @@ +version: 2 +updates: + # Python dependencies (pyproject.toml + uv.lock) + - package-ecosystem: "pip" + directory: "/" + schedule: + interval: "weekly" + open-pull-requests-limit: 5 + labels: + - "dependencies" + - "python" + groups: + python-dev-dependencies: + dependency-type: "development" + patterns: + - "*" + + # GitHub Actions used in CI / release workflows + - package-ecosystem: "github-actions" + directory: "/" + schedule: + interval: "weekly" + open-pull-requests-limit: 5 + labels: + - "dependencies" + - "github-actions" diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 8c9e871..48fcf13 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -50,3 +50,16 @@ repos: args: [-c, pyproject.toml] additional_dependencies: ["bandit[toml]"] exclude: ^(tests|docs|examples)/ + + # Run mypy from the project venv (synced via `uv sync --dev`) so it resolves + # real dependencies and honours the [tool.mypy] config in pyproject.toml. + # Checks the whole package once (pass_filenames: false) rather than per-file. + - repo: local + hooks: + - id: mypy + name: mypy + entry: uv run mypy agentflow/ + language: system + types: [python] + pass_filenames: false + exclude: ^(tests|docs|examples|normal_tests)/ diff --git a/CLAUDE.md b/CLAUDE.md index ed46249..9c47986 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -24,7 +24,7 @@ persistence, tools, memory, evaluation, and event publishing. Inspired by LangGr the README and several docstrings still show pre-refactor paths (see Known Doc Drift). - **Surgical edits.** This is `Development Status :: 5 - Production/Stable`. Don't refactor module boundaries or rename exports without checking every `__init__.py` that re-exports them. -- **Keep coverage green.** `pytest` enforces `--cov-fail-under=70`. New code needs tests. +- **Keep coverage green.** `pytest` enforces `--cov-fail-under=80`. New code needs tests. - **Optional deps are optional.** Provider SDKs, MCP, Postgres, Redis, Qdrant, Mem0, Kafka, RabbitMQ, OTEL, a2a are all extras. Guard imports; never make core import a hard optional dep. @@ -152,7 +152,7 @@ already present. ```bash # from this folder (agentflow/) -.venv/bin/python -m pytest # full suite (enforces coverage >= 70%) +.venv/bin/python -m pytest # full suite (enforces coverage >= 80%) .venv/bin/python -m pytest tests/graph # one area ruff check . && ruff format . # lint + format (line-length 100, py312) # editable install with extras for local dev: diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000..c8cd2f6 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,112 @@ +# Contributing to Agentflow + +Thanks for your interest in improving `10xscale-agentflow`. This guide covers the +core Python framework that lives in this folder. For the API server, TypeScript +client, docs, or playground, see the `CONTRIBUTING`/`CLAUDE.md` in their +respective packages. + +- Package (PyPI): `10xscale-agentflow` +- Requires: Python >= 3.12 +- The importable package is the nested `agentflow/` directory; this folder is the + repo root for the core library. + +## Getting set up + +We use [`uv`](https://docs.astral.sh/uv/) for environment and dependency +management. + +```bash +# from this folder (the core library root) +uv sync --dev # create .venv and install the package + dev tools +uv run pre-commit install # enable the git hooks (optional but recommended) +``` + +If you work on optional subsystems, install the matching extras, e.g.: + +```bash +uv pip install -e ".[google-genai,openai,mcp,pg_checkpoint]" +``` + +## Before you open a pull request + +Run the same checks CI runs. All must pass: + +```bash +uv run pre-commit run --all-files # ruff format + lint, bandit, mypy, hooks +uv run pytest --cov --cov-branch # tests + coverage gate (>= 80%) +``` + +You can also run pieces individually: + +```bash +uv run ruff check . && uv run ruff format . +uv run mypy agentflow/ +uv run pytest tests/graph # one area +``` + +### What the gates enforce + +- **Formatting & linting:** `ruff` (line length 100, target py312). Most issues + are auto-fixed by `ruff format` / `ruff check --fix`. +- **Types:** `mypy` runs in pre-commit. The codebase is on *phased* typing: a set + of modules with pre-existing errors is listed under `[[tool.mypy.overrides]]` + in `pyproject.toml` with `ignore_errors = true`. New code is type-checked. + Improving a listed module's types and removing it from that list is a welcome + contribution; please don't add new modules to it. +- **Security:** `bandit`. +- **Coverage:** `pytest` fails under 80% line coverage. New code needs tests. + +## Tests + +- Tests live in `tests/`, mirroring the package layout (`graph/`, `state/`, + `storage/`, `publisher/`, `prebuilt/`, `evaluation/`, `testing/`, plus + `chaos/`, `benchmarks/`, `integration/`). +- Markers: `asyncio`, `integration` (needs real databases — Redis/Postgres), + `slow`. Integration tests are skipped unless their backends are available. +- Prefer the in-repo test helpers in `agentflow.qa.testing` (`TestAgent`, + `MockMCPClient`, `MockToolRegistry`) to exercise graphs without live LLM calls. + +## Import paths (read this before referencing symbols) + +The package is organised into `core/`, `storage/`, `runtime/`, `qa/`. There are +**no** top-level `agentflow.graph` / `agentflow.state` / `agentflow.checkpointer` +shims — use the canonical paths: + +```python +from agentflow.core.graph import StateGraph, Agent, ToolNode, CompiledGraph +from agentflow.core.state import AgentState, Message +from agentflow.core.llm import call_llm, create_llm_client, detect_provider +from agentflow.storage.checkpointer import InMemoryCheckpointer, PgCheckpointer +``` + +`examples/` uses current import paths and is the most reliable usage reference. + +## Optional dependencies + +Provider SDKs (OpenAI, Google GenAI), MCP, Postgres, Redis, Qdrant, Mem0, Kafka, +RabbitMQ, OTEL, and a2a are all **extras**. Guard their imports inside the +functions that need them so the core package never hard-imports an optional +dependency. See `agentflow/core/llm/client_factory.py` for the pattern. + +## Commit and PR conventions + +- Use clear, conventional-style commit subjects (`feat:`, `fix:`, `docs:`, + `refactor:`, `test:`, `chore:`), matching the existing history. +- Keep changes surgical. This package is `Development Status :: 5 - + Production/Stable`; avoid renaming exports or moving module boundaries without + checking every `__init__.py` that re-exports the symbol. +- Update docs/examples when you change public behaviour. Prefer fixing a stale + doc/example to match the code over the reverse. +- One logical change per PR. Describe the motivation and how you tested it. + +## Reporting bugs and security issues + +- **Bugs / feature requests:** open an issue at + https://github.com/10xHub/agentflow/issues with a minimal reproduction. +- **Security vulnerabilities:** do **not** open a public issue — follow + [`SECURITY.md`](SECURITY.md). + +## License + +By contributing, you agree that your contributions are licensed under the +project's [MIT License](LICENSE). diff --git a/SECURITY.md b/SECURITY.md new file mode 100644 index 0000000..9a94a8f --- /dev/null +++ b/SECURITY.md @@ -0,0 +1,64 @@ +# Security Policy + +## Supported versions + +`10xscale-agentflow` is pre-1.0 and ships from a single release line. Security +fixes are applied to the latest published release only. Pin a known-good version +in production and upgrade promptly when a security release is announced. + +| Version | Supported | +| ------- | ------------------ | +| 0.7.x | :white_check_mark: | +| < 0.7 | :x: | + +## Reporting a vulnerability + +**Please do not open a public GitHub issue for security problems.** + +Report privately through either channel: + +- **GitHub Security Advisories** (preferred): open a private report at + https://github.com/10xHub/agentflow/security/advisories/new +- **Email:** contact@10xscale.ai (you may also CC shudiptotrafder@gmail.com) + +Include as much of the following as you can: + +- A description of the issue and the impact you believe it has. +- The affected version(s) and, if known, the affected module/import path + (e.g. `agentflow.core.llm.client_factory`). +- A minimal reproduction or proof of concept. +- Any suggested remediation. + +### What to expect + +- **Acknowledgement** within 3 business days. +- An initial assessment and severity triage within 7 business days. +- Coordinated disclosure: we will agree on a disclosure timeline with you and + credit you in the advisory unless you prefer to remain anonymous. + +## Scope + +This policy covers the `10xscale-agentflow` core Python package in this +repository. Issues in the API server (`10xscale-agentflow-cli`), the TypeScript +client, or third-party dependencies should be reported against their respective +projects, though we are happy to help route a report. + +### Things that are expected behaviour, not vulnerabilities + +- **Tools execute arbitrary code by design.** Tools you register with a + `ToolNode` run with the privileges of the host process. Only register trusted + tools and treat tool inputs derived from model output as untrusted. +- **Provider API keys are read from the environment** (`OPENAI_API_KEY`, + `GEMINI_API_KEY`, etc.). Protecting that environment is the deployer's + responsibility. +- **Prompt injection** against an LLM is a property of the model/application + design. Reports demonstrating a concrete privilege escalation or data + exfiltration path *through the framework* are in scope; generic "the model can + be jailbroken" reports are not. + +## Good practice for deployers + +- Keep `IS_DEBUG=false` and `MODE=production` in production. +- Never set `ORIGINS=*` in production. +- Use a secrets manager rather than committing `.env` files. +- Constrain which tools and MCP servers an agent can reach. diff --git a/agentflow/core/graph/agent_internal/circuit_breaker.py b/agentflow/core/graph/agent_internal/circuit_breaker.py new file mode 100644 index 0000000..0d5531e --- /dev/null +++ b/agentflow/core/graph/agent_internal/circuit_breaker.py @@ -0,0 +1,95 @@ +"""A small circuit breaker for LLM calls. + +Complements retry + fallback: once a model/provider has failed +``failure_threshold`` times in a row, its circuit *opens* and further calls to +it are short-circuited (skipped, moving straight to the next fallback) for +``reset_timeout`` seconds. After that cooldown a single trial is allowed +(*half-open*); success closes the circuit, another failure re-opens it. + +This stops a dead provider from being retried on every single invocation. +""" + +from __future__ import annotations + +import time +from collections.abc import Callable +from enum import Enum + + +class CircuitState(str, Enum): + """Lifecycle state of a :class:`CircuitBreaker`.""" + + closed = "closed" # normal operation, calls allowed + open = "open" # failing, calls skipped until the cooldown elapses + half_open = "half_open" # cooldown elapsed, one trial call allowed + + +class CircuitBreakerOpenError(RuntimeError): + """Raised/used as the recorded error when a call is skipped by an open circuit.""" + + def __init__(self, key: object, retry_after: float) -> None: + self.key = key + self.retry_after = retry_after + super().__init__(f"Circuit breaker open for {key!r}; retry in {retry_after:.1f}s") + + +class CircuitBreaker: + """Per-target failure tracker with open/half-open/closed states. + + Args: + failure_threshold: Consecutive failures that trip the circuit (>= 1). + reset_timeout: Seconds to stay open before allowing a half-open trial. + time_func: Monotonic clock source; injectable for testing. + """ + + def __init__( + self, + failure_threshold: int = 5, + reset_timeout: float = 30.0, + time_func: Callable[[], float] = time.monotonic, + ) -> None: + if failure_threshold < 1: + raise ValueError("failure_threshold must be >= 1") + if reset_timeout <= 0: + raise ValueError("reset_timeout must be > 0") + self.failure_threshold = failure_threshold + self.reset_timeout = reset_timeout + self._time = time_func + self._failures = 0 + self._state = CircuitState.closed + self._opened_at = 0.0 + + @property + def state(self) -> CircuitState: + return self._state + + @property + def failure_count(self) -> int: + return self._failures + + def allow(self) -> bool: + """Return True if a call may proceed, transitioning open -> half-open if due.""" + if self._state is CircuitState.open: + if self._time() - self._opened_at >= self.reset_timeout: + self._state = CircuitState.half_open + return True + return False + return True + + def record_success(self) -> None: + """Reset the breaker to closed after a successful call.""" + self._failures = 0 + self._state = CircuitState.closed + + def record_failure(self) -> None: + """Register a failure, opening the circuit at/over threshold or from half-open.""" + self._failures += 1 + if self._state is CircuitState.half_open or self._failures >= self.failure_threshold: + self._state = CircuitState.open + self._opened_at = self._time() + + def retry_after(self) -> float: + """Seconds remaining before an open circuit allows a half-open trial.""" + if self._state is not CircuitState.open: + return 0.0 + return max(0.0, self.reset_timeout - (self._time() - self._opened_at)) diff --git a/agentflow/core/graph/agent_internal/constants.py b/agentflow/core/graph/agent_internal/constants.py index 866d127..ce029b8 100644 --- a/agentflow/core/graph/agent_internal/constants.py +++ b/agentflow/core/graph/agent_internal/constants.py @@ -16,6 +16,13 @@ class RetryConfig: max_delay: Upper-bound cap on exponential back-off delay (default ``30.0``). backoff_factor: Multiplier applied after each retry (default ``2.0``). retryable_status_codes: HTTP status codes considered transient/retryable. + circuit_breaker_enabled: When True, track failures per (provider, model) + and skip a target whose circuit is open, moving straight to the next + fallback instead of retrying a known-dead provider (default ``False``). + circuit_breaker_threshold: Consecutive failures that open a circuit + (default ``5``). + circuit_breaker_reset_timeout: Seconds a circuit stays open before a + single half-open trial is allowed (default ``30.0``). """ max_retries: int = 3 @@ -25,6 +32,9 @@ class RetryConfig: retryable_status_codes: frozenset[int] = field( default_factory=lambda: frozenset({429, 500, 502, 503, 529}), ) + circuit_breaker_enabled: bool = False + circuit_breaker_threshold: int = 5 + circuit_breaker_reset_timeout: float = 30.0 DEFAULT_RETRY_CONFIG = RetryConfig() diff --git a/agentflow/core/graph/agent_internal/execution.py b/agentflow/core/graph/agent_internal/execution.py index b48a751..3c9f7c7 100644 --- a/agentflow/core/graph/agent_internal/execution.py +++ b/agentflow/core/graph/agent_internal/execution.py @@ -18,6 +18,7 @@ strip_media_blocks, ) +from .circuit_breaker import CircuitBreaker, CircuitBreakerOpenError from .constants import RetryConfig @@ -255,7 +256,37 @@ def _is_retryable_error(self, exc: Exception, retry_cfg: RetryConfig) -> bool: for keyword in ("timeout", "connection", "unavailable", "serviceunav") ) - async def _call_llm_with_retry( # noqa: PLR0912 + def _get_circuit_breaker( + self, + provider: str, + model: str, + retry_cfg: RetryConfig | None, + ) -> CircuitBreaker | None: + """Return the circuit breaker for ``(provider, model)``, or None if disabled. + + Breakers are created lazily and cached on the instance so their state + persists across calls (the whole point: stop hammering a dead provider on + every invocation). + """ + if retry_cfg is None or not retry_cfg.circuit_breaker_enabled: + return None + registry: dict[tuple[str, str], CircuitBreaker] | None = self.__dict__.get( + "_circuit_breakers" + ) + if registry is None: + registry = {} + self._circuit_breakers = registry + key = (provider, model) + breaker = registry.get(key) + if breaker is None: + breaker = CircuitBreaker( + failure_threshold=retry_cfg.circuit_breaker_threshold, + reset_timeout=retry_cfg.circuit_breaker_reset_timeout, + ) + registry[key] = breaker + return breaker + + async def _call_llm_with_retry( # noqa: PLR0912, PLR0915 self, messages: list[dict[str, Any]], tools: list | None = None, @@ -297,6 +328,18 @@ async def _call_llm_with_retry( # noqa: PLR0912 provider, ) + breaker = self._get_circuit_breaker(provider, model, retry_cfg) + if breaker is not None and not breaker.allow(): + retry_after = breaker.retry_after() + logger.warning( + "Circuit open for %s (provider=%s); skipping (retry in %.1fs)", + model, + provider, + retry_after, + ) + last_exc = CircuitBreakerOpenError((provider, model), retry_after) + continue + for retry in range(max_retries + 1): # 0 .. max_retries try: if is_fallback: @@ -331,6 +374,8 @@ async def _call_llm_with_retry( # noqa: PLR0912 else: result = await self._call_llm(messages, tools, stream, **kwargs) + if breaker is not None: + breaker.record_success() if is_fallback or retry > 0: logger.info( "LLM call succeeded on %s (attempt %d/%d, model_index=%d)", @@ -374,6 +419,11 @@ async def _call_llm_with_retry( # noqa: PLR0912 model, ) + # This model's attempt cycle failed (retries exhausted or + # non-retryable). Record one failure against its circuit. + if breaker is not None: + breaker.record_failure() + # Every model exhausted → re-raise the last exception assert last_exc is not None # noqa: S101 raise last_exc diff --git a/agentflow/core/graph/compiled_graph.py b/agentflow/core/graph/compiled_graph.py index 750dfb4..437146f 100644 --- a/agentflow/core/graph/compiled_graph.py +++ b/agentflow/core/graph/compiled_graph.py @@ -30,6 +30,8 @@ if TYPE_CHECKING: + from types import TracebackType + from .state_graph import StateGraph @@ -135,6 +137,26 @@ def __init__( self._interrupt_after: list[str] = interrupt_after # generate task manager self._task_manager = task_manager + # Guards aclose() against being run more than once (e.g. an explicit + # aclose() inside an ``async with`` block followed by __aexit__). + self._closed = False + + async def __aenter__(self) -> CompiledGraph[StateT]: + """Enter an async context; returns this graph unchanged. + + Enables ``async with compiled_graph as graph: ...``, which guarantees + :meth:`aclose` runs on exit even if the body raises. + """ + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc: BaseException | None, + tb: TracebackType | None, + ) -> None: + """Exit the async context, releasing all resources via :meth:`aclose`.""" + await self.aclose() def _prepare_config( self, @@ -505,7 +527,7 @@ def attach_remote_tools( node_name, ) - async def aclose(self) -> dict[str, Any]: + async def aclose(self) -> dict[str, Any]: # noqa: PLR0915 """ Close the graph and release all resources gracefully. @@ -519,9 +541,20 @@ async def aclose(self) -> dict[str, Any]: Returns: Dictionary with detailed shutdown statistics for each component. + Calling this more than once is a no-op; the second call returns + ``{"status": "already_closed"}``. Prefer the async-context-manager form, + which calls this automatically on exit: + Example: ```python async def main(): + async with await build_and_compile_graph() as graph: + await graph.ainvoke(input_data) + # graph.aclose() has run here, even if ainvoke raised + + + # Or manage the lifecycle manually: + async def main_manual(): graph = await build_and_compile_graph() try: await graph.ainvoke(input_data) @@ -532,6 +565,11 @@ async def main(): """ from agentflow.utils.shutdown import shutdown_with_timeout + if self._closed: + logger.debug("CompiledGraph.aclose() called again; already closed") + return {"status": "already_closed"} + self._closed = True + logger.info("Initiating graceful shutdown of CompiledGraph") stats: dict[str, Any] = {} start_time = asyncio.get_event_loop().time() diff --git a/agentflow/core/llm/__init__.py b/agentflow/core/llm/__init__.py index e2b2178..89dc4a4 100644 --- a/agentflow/core/llm/__init__.py +++ b/agentflow/core/llm/__init__.py @@ -1,7 +1,20 @@ """LLM client creation utilities shared across agents and evaluators.""" from .caller import call_llm -from .client_factory import create_llm_client, detect_provider +from .client_factory import ( + DEFAULT_LLM_TIMEOUT_SECONDS, + create_llm_client, + detect_provider, + get_default_llm_timeout, + set_default_llm_timeout, +) -__all__ = ["call_llm", "create_llm_client", "detect_provider"] +__all__ = [ + "DEFAULT_LLM_TIMEOUT_SECONDS", + "call_llm", + "create_llm_client", + "detect_provider", + "get_default_llm_timeout", + "set_default_llm_timeout", +] diff --git a/agentflow/core/llm/client_factory.py b/agentflow/core/llm/client_factory.py index 5fb8329..59afeb2 100644 --- a/agentflow/core/llm/client_factory.py +++ b/agentflow/core/llm/client_factory.py @@ -14,6 +14,69 @@ logger = logging.getLogger("agentflow.llm") +# Default timeout (in seconds) applied to LLM client construction when the +# caller does not pass an explicit ``timeout``. Bounds every request so a stalled +# provider connection cannot hang a graph run indefinitely. Override globally via +# the ``AGENTFLOW_LLM_TIMEOUT`` environment variable (seconds) or programmatically +# via :func:`set_default_llm_timeout`. +DEFAULT_LLM_TIMEOUT_SECONDS = 600.0 + +# Single-element holder so the override can be mutated without a ``global`` +# statement (which ruff's PLW0603 flags). +_default_timeout_override: dict[str, float | None] = {"value": None} + + +def _env_timeout() -> float | None: + """Read ``AGENTFLOW_LLM_TIMEOUT`` (seconds), or None if unset/invalid.""" + raw = os.getenv("AGENTFLOW_LLM_TIMEOUT") + if raw is None or not raw.strip(): + return None + try: + value = float(raw) + except ValueError: + logger.warning( + "Invalid AGENTFLOW_LLM_TIMEOUT=%r; expected a number of seconds. Ignoring.", + raw, + ) + return None + if value <= 0: + logger.warning("AGENTFLOW_LLM_TIMEOUT=%s must be positive. Ignoring.", value) + return None + return value + + +def get_default_llm_timeout() -> float: + """Return the default LLM request timeout in seconds. + + Resolution order (first match wins): + + 1. A programmatic override set via :func:`set_default_llm_timeout`. + 2. The ``AGENTFLOW_LLM_TIMEOUT`` environment variable (seconds). + 3. :data:`DEFAULT_LLM_TIMEOUT_SECONDS`. + """ + override = _default_timeout_override["value"] + if override is not None: + return override + env = _env_timeout() + if env is not None: + return env + return DEFAULT_LLM_TIMEOUT_SECONDS + + +def set_default_llm_timeout(seconds: float | None) -> None: + """Globally override the default LLM request timeout, in seconds. + + Pass ``None`` to clear the override and fall back to the + ``AGENTFLOW_LLM_TIMEOUT`` environment variable / built-in default. + + Raises: + ValueError: If ``seconds`` is not a positive number. + """ + if seconds is not None and seconds <= 0: + raise ValueError("LLM timeout must be a positive number of seconds.") + _default_timeout_override["value"] = seconds + + # Recognised ``provider/`` prefixes mapped to the concrete provider the client # factory can build. Anything not listed here is an unknown prefix and resolves # to ``"openai"`` (the OpenAI SDK is used for OpenAI-compatible endpoints). @@ -141,12 +204,16 @@ def create_llm_client( def _create_google_client(*, use_vertex_ai: bool) -> Any: try: from google import genai + from google.genai.types import HttpOptions except ImportError as exc: raise ImportError( "google-genai SDK is required for the Google provider. " "Install it with: pip install 10xscale-agentflow[google-genai]" ) from exc + # google-genai expresses the request timeout in milliseconds. + http_options = HttpOptions(timeout=int(get_default_llm_timeout() * 1000)) + if use_vertex_ai: project = os.getenv("GOOGLE_CLOUD_PROJECT") location = os.getenv("GOOGLE_CLOUD_LOCATION", "us-central1") @@ -157,7 +224,9 @@ def _create_google_client(*, use_vertex_ai: bool) -> Any: project, location, ) - return genai.Client(vertexai=True, project=project, location=location) + return genai.Client( + vertexai=True, project=project, location=location, http_options=http_options + ) api_key = os.getenv("GEMINI_API_KEY") or os.getenv("GOOGLE_API_KEY") if not api_key: @@ -170,7 +239,7 @@ def _create_google_client(*, use_vertex_ai: bool) -> Any: # GOOGLE_GENAI_USE_VERTEXAI env var and silently switches to Vertex mode, # which rejects API keys (401 UNAUTHENTICATED). The caller asked for the # Developer API (use_vertex_ai=False), so honour that over the env. - return genai.Client(vertexai=False, api_key=api_key) + return genai.Client(vertexai=False, api_key=api_key, http_options=http_options) def _create_openai_client( @@ -195,6 +264,8 @@ def _create_openai_client( ) client_kwargs = {k: v for k, v in extra_kwargs.items() if k in _CLIENT_CONSTRUCTOR_KWARGS} + # Bound the request unless the caller opted into their own timeout. + client_kwargs.setdefault("timeout", get_default_llm_timeout()) if base_url: logger.info("Creating OpenAI client with custom base_url: %s", base_url) return AsyncOpenAI(api_key=resolved_key, base_url=base_url, **client_kwargs) diff --git a/agentflow/py.typed b/agentflow/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/agentflow/runtime/publisher/console_publisher.py b/agentflow/runtime/publisher/console_publisher.py index 5c55936..559f023 100644 --- a/agentflow/runtime/publisher/console_publisher.py +++ b/agentflow/runtime/publisher/console_publisher.py @@ -15,15 +15,25 @@ class ConsolePublisher(BasePublisher): - """Publisher that prints events to the console for debugging and testing. + """Publisher that writes events to the console for debugging and testing. - This publisher is useful for development and debugging purposes, as it outputs event information - to the standard output. + This is a development/debugging publisher. It is opt-in: nothing wires it up + unless you explicitly construct it and pass it to ``compile()``. **For + production, use a real transport** (Redis, Kafka, RabbitMQ, or OTEL) rather + than this one. + + By default events are written to stdout via ``print`` so they are visible in + a quick script without any logging setup. In a server context, where writing + to stdout is undesirable, set ``use_logger=True`` to route events through the + ``agentflow.publisher`` logger at ``INFO`` level instead, so they respect your + logging configuration. Attributes: format: Output format ('json' by default). include_timestamp: Whether to include timestamp (True by default). indent: Indentation for output (2 by default). + use_logger: Emit via the ``agentflow.publisher`` logger instead of stdout + (False by default). """ def __init__(self, config: dict[str, Any] | None = None): @@ -34,15 +44,21 @@ def __init__(self, config: dict[str, Any] | None = None): - format: Output format (default: 'json'). - include_timestamp: Whether to include timestamp (default: True). - indent: Indentation for output (default: 2). + - use_logger: Emit via the logger instead of stdout + (default: False). """ super().__init__(config or {}) self.format = config.get("format", "json") if config else "json" self.include_timestamp = config.get("include_timestamp", True) if config else True self.indent = config.get("indent", 2) if config else 2 + self.use_logger = config.get("use_logger", False) if config else False async def publish(self, event: EventModel) -> Any: """Publish an event to the console. + Writes to stdout by default, or emits via the ``agentflow.publisher`` + logger when ``use_logger=True`` was set in the config. + Args: event: The event to publish. @@ -58,7 +74,10 @@ async def publish(self, event: EventModel) -> Any: msg = f"{event.timestamp} -> Source: {event.node_name}.{event.event_type}:" msg += f"-> Payload: {event.data}" msg += f" -> {event.metadata}" - print(msg) # noqa: T201 + if self.use_logger: + logger.info(msg) + else: + print(msg) # noqa: T201 async def close(self): """Close the publisher and release any resources. diff --git a/agentflow/storage/media/media_resolver.py b/agentflow/storage/media/media_resolver.py index fcecac0..521e1cb 100644 --- a/agentflow/storage/media/media_resolver.py +++ b/agentflow/storage/media/media_resolver.py @@ -84,6 +84,7 @@ async def resolve( transport_order = caps.get_transport_order(media_type) transports_attempted: list[MediaTransportMode] = [] + last_error: Exception | None = None for transport in transport_order: transports_attempted.append(transport) @@ -97,12 +98,16 @@ async def resolve( ) if result is not None: return result - except Exception: - logger.debug( - "Transport %s failed for %s/%s, trying next fallback", + except Exception as exc: + last_error = exc + logger.warning( + "Media transport %s failed for %s/%s (%s: %s); trying next fallback", transport.value, provider, model, + type(exc).__name__, + exc, + exc_info=True, ) continue @@ -118,7 +123,7 @@ async def resolve( f"All transports failed: " f"{', '.join(t.value for t in transports_attempted)}." ), - ) + ) from last_error async def _try_transport( self, @@ -190,7 +195,15 @@ async def _transport_inline_bytes( from google.genai import types return types.Part.from_bytes(data=data, mime_type=mime) - except Exception: + except Exception as exc: + logger.warning( + "inline_bytes transport failed to fetch %s for provider %s (%s: %s)", + ref.url, + provider, + type(exc).__name__, + exc, + exc_info=True, + ) return None return None @@ -231,7 +244,14 @@ async def _transport_provider_file( # noqa: PLR0911 return await upload_to_google_file_api(data, mime) - except Exception: + except Exception as exc: + logger.warning( + "provider_file transport failed (Google File API) for ref kind=%s (%s: %s)", + ref.kind, + type(exc).__name__, + exc, + exc_info=True, + ) return None return None diff --git a/agentflow/storage/media/processor.py b/agentflow/storage/media/processor.py index 443ed4e..088ba73 100644 --- a/agentflow/storage/media/processor.py +++ b/agentflow/storage/media/processor.py @@ -178,7 +178,12 @@ def _load_inline_image(self, block: ImageBlock) -> Any | None: raw = base64.b64decode(block.media.data_base64) try: return Image.open(io.BytesIO(raw)) - except Exception: + except Exception as exc: + logger.debug( + "Could not open inline image for orientation fix (%s: %s); skipping", + type(exc).__name__, + exc, + ) return None @staticmethod diff --git a/agentflow/storage/media/resolver.py b/agentflow/storage/media/resolver.py index 565001f..18d0307 100644 --- a/agentflow/storage/media/resolver.py +++ b/agentflow/storage/media/resolver.py @@ -202,6 +202,7 @@ async def _resolve_with_capabilities( transport_order = caps.get_transport_order(media_type) transports_attempted: list[MediaTransportMode] = [] + last_error: Exception | None = None for transport in transport_order: transports_attempted.append(transport) @@ -216,12 +217,16 @@ async def _resolve_with_capabilities( return result except UnsupportedMediaInputError: raise - except Exception: - logger.debug( - "Transport %s failed for %s/%s, trying next fallback", + except Exception as exc: + last_error = exc + logger.warning( + "Media transport %s failed for %s/%s (%s: %s); trying next fallback", transport.value, provider, model, + type(exc).__name__, + exc, + exc_info=True, ) continue @@ -231,7 +236,7 @@ async def _resolve_with_capabilities( media_type=media_type, source_kind=_source_kind(ref), transports_attempted=transports_attempted, - ) + ) from last_error async def _try_transport( self, @@ -309,7 +314,15 @@ async def _transport_inline_bytes( from google.genai import types return types.Part.from_bytes(data=data, mime_type=mime) - except Exception: + except Exception as exc: + logger.warning( + "inline_bytes transport failed to fetch %s for provider %s (%s: %s)", + ref.url, + provider, + type(exc).__name__, + exc, + exc_info=True, + ) return None return None @@ -343,7 +356,14 @@ async def _transport_provider_file( return await upload_to_google_file_api(data, mime) - except Exception: + except Exception as exc: + logger.warning( + "provider_file transport failed (Google File API) for ref kind=%s (%s: %s)", + ref.kind, + type(exc).__name__, + exc, + exc_info=True, + ) return None return None diff --git a/agentflow/storage/media/storage/cloud_store.py b/agentflow/storage/media/storage/cloud_store.py index 5f3a333..6cb82f4 100644 --- a/agentflow/storage/media/storage/cloud_store.py +++ b/agentflow/storage/media/storage/cloud_store.py @@ -250,7 +250,13 @@ async def _download_meta(self, storage_key: str) -> dict[str, Any] | None: url = await self._storage.get_public_url(meta_path, expiration=60) raw = await self._download_from_url(url) return json.loads(raw) - except Exception: + except Exception as exc: + logger.debug( + "Could not download/parse sidecar metadata at %s (%s: %s)", + meta_path, + type(exc).__name__, + exc, + ) return None @staticmethod diff --git a/agentflow/utils/__init__.py b/agentflow/utils/__init__.py index 9387e09..7af7eea 100644 --- a/agentflow/utils/__init__.py +++ b/agentflow/utils/__init__.py @@ -55,7 +55,12 @@ TimestampIDGenerator, UUIDGenerator, ) -from .logging import logger +from .logging import ( + SecretRedactionFilter, + install_secret_redaction, + logger, + mask_secrets, +) from .shutdown import ( DelayedKeyboardInterrupt, GracefulShutdownManager, @@ -101,6 +106,7 @@ "OnErrorCallback", "PromptInjectionValidator", "ResponseGranularity", + "SecretRedactionFilter", "ShortIDGenerator", "TaskMetadata", "ThreadInfo", @@ -114,7 +120,9 @@ "delayed_keyboard_interrupt", "get_tool_metadata", "has_tool_decorator", + "install_secret_redaction", "logger", + "mask_secrets", "register_default_validators", "replace_messages", "replace_value", diff --git a/agentflow/utils/logging.py b/agentflow/utils/logging.py index ccefec5..6eb1ea1 100644 --- a/agentflow/utils/logging.py +++ b/agentflow/utils/logging.py @@ -40,6 +40,8 @@ """ import logging +import re +from collections.abc import Callable # Create the main agentflow logger @@ -49,6 +51,114 @@ # Users can configure their own handlers as needed logger.addHandler(logging.NullHandler()) + +# ── Secret redaction ───────────────────────────────────────────────────────── +# +# Best-effort masking of credentials that may otherwise surface in debug logs +# (e.g. signed URLs with query-string tokens, Authorization headers, provider +# API keys). This is defence-in-depth, not a guarantee: prefer never logging +# secrets in the first place. + +_REDACTED = "***REDACTED***" + +_Replacement = str | Callable[[re.Match[str]], str] + +# (pattern, replacement) pairs. Replacement is either the placeholder string +# (full match redacted) or a callable that preserves the key name and redacts +# only the value. ``Bearer`` is handled before the generic key=value rule so an +# Authorization header keeps its scheme instead of being double-redacted. +_SECRET_SUBS: list[tuple[re.Pattern[str], _Replacement]] = [ + # OpenAI-style secret keys: sk-... and sk-proj-... + (re.compile(r"sk-(?:proj-)?[A-Za-z0-9_-]{16,}"), _REDACTED), + # Google API keys + (re.compile(r"AIza[0-9A-Za-z_-]{35}"), _REDACTED), + # GitHub tokens (ghp_, gho_, ghu_, ghs_, ghr_) + (re.compile(r"gh[pousr]_[A-Za-z0-9]{20,}"), _REDACTED), + # Slack tokens + (re.compile(r"xox[baprs]-[A-Za-z0-9-]{10,}"), _REDACTED), + # AWS access key id + (re.compile(r"AKIA[0-9A-Z]{16}"), _REDACTED), + # Bearer tokens (e.g. in Authorization headers) — keep the scheme + (re.compile(r"(?i)\bBearer\s+[A-Za-z0-9._\-]+"), "Bearer " + _REDACTED), + # key/secret/token/password = value (JSON or key=value form) + ( + re.compile( + r"(?i)(api[_-]?key|access[_-]?token|secret|password)" + r"""(["']?\s*[:=]\s*["']?)""" + r"""([^\s"',&}]{4,})""" + ), + lambda m: f"{m.group(1)}{m.group(2)}{_REDACTED}", + ), + # Signed-URL credential query params (?token=…, &sig=…, &X-Amz-Signature=…) + ( + re.compile( + r"(?i)([?&](?:token|sig|signature|x-amz-signature|" + r"x-goog-signature|key|password)=)([^&\s]+)" + ), + lambda m: f"{m.group(1)}{_REDACTED}", + ), +] + + +def mask_secrets(text: str) -> str: + """Redact common credential formats from a string. + + Masks OpenAI/Google/GitHub/Slack/AWS keys, ``Bearer`` tokens, + ``key=value`` secrets, and signed-URL credential query parameters. Returns + the input unchanged when it contains nothing that looks like a secret. + + This is a heuristic. It will not catch every possible secret and may + occasionally over-redact; treat it as a safety net, not a guarantee. + """ + if not text: + return text + for pattern, repl in _SECRET_SUBS: + text = pattern.sub(repl, text) + return text + + +class SecretRedactionFilter(logging.Filter): + """Logging filter that redacts secrets from a record's formatted message. + + Add it to a *handler* to cover every logger that propagates to that handler:: + + handler.addFilter(SecretRedactionFilter()) + + Adding it to a logger only redacts records emitted directly on that logger, + not its children (Python applies logger-level filters only at the originating + logger, while handler-level filters run for propagated records too). + """ + + def filter(self, record: logging.LogRecord) -> bool: + try: + message = record.getMessage() + except Exception: # pragma: no cover - never block logging on redaction + return True + redacted = mask_secrets(message) + if redacted != message: + record.msg = redacted + record.args = () + return True + + +def install_secret_redaction(logger_name: str = "agentflow") -> SecretRedactionFilter: + """Attach a :class:`SecretRedactionFilter` to ``logger_name`` and its handlers. + + Call this *after* configuring your logging handlers so the filter covers the + records they emit. For complete coverage of child loggers, prefer adding the + filter to your handler(s) directly. Returns the installed filter. + """ + target = logging.getLogger(logger_name) + redactor = SecretRedactionFilter() + target.addFilter(redactor) + for handler in target.handlers: + handler.addFilter(redactor) + return redactor + + __all__ = [ + "SecretRedactionFilter", + "install_secret_redaction", "logger", + "mask_secrets", ] diff --git a/pyproject.toml b/pyproject.toml index 76a4b0b..5ab7f2c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -61,9 +61,9 @@ requires-python = ">=3.12" dependencies = [ "injectq>=0.4.0", "pillow>=12.2.0", - "pydantic", - "PyYAML", - "python-dotenv", + "pydantic>=2.0,<3", + "PyYAML>=6.0", + "python-dotenv>=1.0", ] authors = [ {name = "10xScale", email = "contact@10xscale.ai"} @@ -123,6 +123,9 @@ all_publishers = [ include = ["agentflow*"] exclude = ["normal_tests*", "tests*", "examples*", "docs*"] +[tool.setuptools.package-data] +agentflow = ["py.typed"] + [tool.ruff] line-length = 100 @@ -222,6 +225,11 @@ convention = "google" [tool.mypy] +python_version = "3.12" +ignore_missing_imports = true +namespace_packages = true +explicit_package_bases = true +warn_unused_configs = true exclude = [ "normal_tests/*", "tests/*", @@ -229,6 +237,56 @@ exclude = [ "docs/*", ] +# Phased adoption: mypy was wired into CI after the codebase already existed, so +# it does not block on the modules that still carry pre-existing type errors. +# These are still parsed (so downstream type info is available) but their own +# errors are silenced. Burn the list down by cleaning a module's types and +# deleting its entry here; everything NOT listed is already gating. +[[tool.mypy.overrides]] +module = [ + "agentflow.core.graph.agent", + "agentflow.core.graph.agent_internal.execution", + "agentflow.core.graph.agent_internal.google", + "agentflow.core.graph.agent_internal.memory", + "agentflow.core.graph.agent_internal.openai", + "agentflow.core.graph.compiled_graph", + "agentflow.core.graph.node", + "agentflow.core.graph.state_graph", + "agentflow.core.graph.tool_node.base", + "agentflow.core.graph.tool_node.mcp_exec", + "agentflow.core.graph.tool_node.schema", + "agentflow.core.graph.utils.handler_utils", + "agentflow.core.graph.utils.invoke_handler", + "agentflow.core.graph.utils.invoke_node_handler", + "agentflow.core.graph.utils.stream_handler", + "agentflow.core.graph.utils.stream_node_handler", + "agentflow.core.graph.utils.utils", + "agentflow.prebuilt.tools.calculator", + "agentflow.prebuilt.tools.fetch", + "agentflow.prebuilt.tools.files", + "agentflow.prebuilt.tools.memory", + "agentflow.prebuilt.tools.search", + "agentflow.qa.evaluation.criteria.llm_utils", + "agentflow.qa.evaluation.evaluator", + "agentflow.qa.evaluation.quick_eval", + "agentflow.qa.evaluation.reporters.console", + "agentflow.qa.evaluation.reporters._html_render", + "agentflow.qa.evaluation.simulators.user_simulator", + "agentflow.qa.testing.quick_test", + "agentflow.qa.testing.test_agent", + "agentflow.runtime.adapters.llm.google_genai_converter", + "agentflow.runtime.adapters.llm.model_response_converter", + "agentflow.runtime.adapters.llm.openai_converter", + "agentflow.runtime.adapters.llm.openai_responses_converter", + "agentflow.runtime.publisher.publish", + "agentflow.storage.checkpointer.pg_checkpointer", + "agentflow.storage.media.processor", + "agentflow.storage.store.embedding.google_embedding", + "agentflow.storage.store.long_term_memory", + "agentflow.storage.store.qdrant_store", +] +ignore_errors = true + [tool.pytest.ini_options] env = [ "ENVIRONMENT=pytest", @@ -261,7 +319,7 @@ addopts = [ "--cov-report=html", "--cov-report=term-missing", "--cov-report=xml", - "--cov-fail-under=70", + "--cov-fail-under=80", "--strict-markers", "-v" ] @@ -305,6 +363,8 @@ dev = [ "pytest-benchmark>=5.1.0", "hypothesis>=6.100.0", "pre-commit>=3.8.0", + "mypy>=1.11.0", + "types-PyYAML", "asyncpg>=0.30.0", "redis>=6.4.0", "fastmcp>=2.12.3", diff --git a/tests/graph/test_agent_retry_fallback.py b/tests/graph/test_agent_retry_fallback.py index f17ce1b..c9039d4 100644 --- a/tests/graph/test_agent_retry_fallback.py +++ b/tests/graph/test_agent_retry_fallback.py @@ -20,6 +20,7 @@ import pytest from agentflow.core.graph.agent import Agent +from agentflow.core.graph.agent_internal.circuit_breaker import CircuitBreakerOpenError from agentflow.core.graph.agent_internal.constants import DEFAULT_RETRY_CONFIG, RetryConfig @@ -780,3 +781,84 @@ async def test_google_string_503_retried(self): ) assert result is response + + +# ═════════════════════════════════════════════════════════════════════════════ +# Circuit breaker (opt-in via RetryConfig) +# ═════════════════════════════════════════════════════════════════════════════ + + +@pytest.mark.asyncio +class TestCircuitBreakerIntegration: + async def test_disabled_by_default_primary_retried_every_call(self): + """Without circuit_breaker_enabled, a dead primary is retried every call.""" + cfg = RetryConfig(max_retries=0, initial_delay=0.01) + agent = _make_agent(retry_config=cfg, fallback_models=["gpt-4o-mini"]) + original_model = agent.model + primary_calls = 0 + + async def mock_call_llm(*args, **kwargs): + nonlocal primary_calls + if agent.model == original_model: + primary_calls += 1 + raise _FakeAPIStatusError(503, "primary_down") + return _chat_response("fallback") + + agent._call_llm = AsyncMock(side_effect=mock_call_llm) + with patch("agentflow.core.graph.agent_internal.execution.asyncio.sleep", new_callable=AsyncMock): + for _ in range(4): + await agent._call_llm_with_retry([{"role": "user", "content": "Hi"}]) + + # Primary is attempted on every one of the 4 invocations. + assert primary_calls == 4 + + async def test_open_circuit_skips_dead_primary(self): + """After threshold failures the primary is skipped, going straight to fallback.""" + cfg = RetryConfig( + max_retries=0, + initial_delay=0.01, + circuit_breaker_enabled=True, + circuit_breaker_threshold=2, + ) + agent = _make_agent(retry_config=cfg, fallback_models=["gpt-4o-mini"]) + original_model = agent.model + primary_calls = 0 + fallback_response = _chat_response("fallback") + + async def mock_call_llm(*args, **kwargs): + nonlocal primary_calls + if agent.model == original_model: + primary_calls += 1 + raise _FakeAPIStatusError(503, "primary_down") + return fallback_response + + agent._call_llm = AsyncMock(side_effect=mock_call_llm) + with patch("agentflow.core.graph.agent_internal.execution.asyncio.sleep", new_callable=AsyncMock): + results = [ + await agent._call_llm_with_retry([{"role": "user", "content": "Hi"}]) + for _ in range(4) + ] + + # Primary fails on invocations 1 and 2 (opening the circuit at 2), then is + # skipped on 3 and 4. Every call still succeeds via the fallback. + assert primary_calls == 2 + assert all(r is fallback_response for r in results) + + async def test_all_circuits_open_raises_circuit_breaker_error(self): + """If every model's circuit is open, the recorded CircuitBreakerOpenError is raised.""" + cfg = RetryConfig( + max_retries=0, + initial_delay=0.01, + circuit_breaker_enabled=True, + circuit_breaker_threshold=1, + ) + agent = _make_agent(retry_config=cfg, fallback_models=None) + agent._call_llm = AsyncMock(side_effect=_FakeAPIStatusError(503, "down")) + + with patch("agentflow.core.graph.agent_internal.execution.asyncio.sleep", new_callable=AsyncMock): + # First call fails normally and opens the circuit (threshold=1). + with pytest.raises(_FakeAPIStatusError): + await agent._call_llm_with_retry([{"role": "user", "content": "Hi"}]) + # Second call is short-circuited. + with pytest.raises(CircuitBreakerOpenError): + await agent._call_llm_with_retry([{"role": "user", "content": "Hi"}]) diff --git a/tests/graph/test_circuit_breaker.py b/tests/graph/test_circuit_breaker.py new file mode 100644 index 0000000..35823ff --- /dev/null +++ b/tests/graph/test_circuit_breaker.py @@ -0,0 +1,110 @@ +"""Unit tests for the CircuitBreaker used by the Agent retry/fallback loop.""" + +from __future__ import annotations + +import pytest + +from agentflow.core.graph.agent_internal.circuit_breaker import ( + CircuitBreaker, + CircuitBreakerOpenError, + CircuitState, +) + + +class _FakeClock: + """Deterministic monotonic clock; advance() to move time forward.""" + + def __init__(self) -> None: + self.now = 1000.0 + + def __call__(self) -> float: + return self.now + + def advance(self, seconds: float) -> None: + self.now += seconds + + +def _breaker(threshold: int = 3, reset: float = 30.0) -> tuple[CircuitBreaker, _FakeClock]: + clock = _FakeClock() + return CircuitBreaker(failure_threshold=threshold, reset_timeout=reset, time_func=clock), clock + + +class TestCircuitBreaker: + def test_starts_closed_and_allows(self): + cb, _ = _breaker() + assert cb.state is CircuitState.closed + assert cb.allow() is True + + def test_failures_below_threshold_stay_closed(self): + cb, _ = _breaker(threshold=3) + cb.record_failure() + cb.record_failure() + assert cb.state is CircuitState.closed + assert cb.allow() is True + + def test_opens_at_threshold(self): + cb, _ = _breaker(threshold=3) + for _ in range(3): + cb.record_failure() + assert cb.state is CircuitState.open + assert cb.allow() is False + + def test_success_resets_failure_count(self): + cb, _ = _breaker(threshold=3) + cb.record_failure() + cb.record_failure() + cb.record_success() + assert cb.failure_count == 0 + assert cb.state is CircuitState.closed + + def test_half_open_after_reset_timeout(self): + cb, clock = _breaker(threshold=2, reset=30.0) + cb.record_failure() + cb.record_failure() + assert cb.allow() is False # still open before timeout + clock.advance(30.0) + assert cb.allow() is True # transitions to half-open + assert cb.state is CircuitState.half_open + + def test_half_open_success_closes(self): + cb, clock = _breaker(threshold=1, reset=10.0) + cb.record_failure() # opens + clock.advance(10.0) + assert cb.allow() is True # half-open + cb.record_success() + assert cb.state is CircuitState.closed + assert cb.allow() is True + + def test_half_open_failure_reopens(self): + cb, clock = _breaker(threshold=1, reset=10.0) + cb.record_failure() # opens + clock.advance(10.0) + cb.allow() # half-open + cb.record_failure() # fails the trial + assert cb.state is CircuitState.open + assert cb.allow() is False # reopened, cooldown restarts + + def test_retry_after_counts_down(self): + cb, clock = _breaker(threshold=1, reset=30.0) + cb.record_failure() + assert cb.retry_after() == pytest.approx(30.0) + clock.advance(10.0) + assert cb.retry_after() == pytest.approx(20.0) + assert _breaker()[0].retry_after() == 0.0 # closed → 0 + + @pytest.mark.parametrize("bad", [0, -1]) + def test_invalid_threshold(self, bad): + with pytest.raises(ValueError, match="failure_threshold"): + CircuitBreaker(failure_threshold=bad) + + @pytest.mark.parametrize("bad", [0, -5.0]) + def test_invalid_reset_timeout(self, bad): + with pytest.raises(ValueError, match="reset_timeout"): + CircuitBreaker(reset_timeout=bad) + + +def test_circuit_breaker_open_error_carries_context(): + err = CircuitBreakerOpenError(("openai", "gpt-4o"), 12.5) + assert err.key == ("openai", "gpt-4o") + assert err.retry_after == 12.5 + assert "gpt-4o" in str(err) diff --git a/tests/graph/test_graph.py b/tests/graph/test_graph.py index 187d6f4..fce9f8c 100644 --- a/tests/graph/test_graph.py +++ b/tests/graph/test_graph.py @@ -267,6 +267,30 @@ async def test_aclose(self): assert isinstance(stats, dict) # noqa: S101 assert "background_tasks" in stats # noqa: S101 + @pytest.mark.asyncio + async def test_aclose_is_idempotent(self): + """Closing twice is a no-op and reports already_closed.""" + first = await self.compiled.aclose() + assert "background_tasks" in first # noqa: S101 + second = await self.compiled.aclose() + assert second == {"status": "already_closed"} # noqa: S101 + + @pytest.mark.asyncio + async def test_async_context_manager(self): + """`async with` returns the graph and closes it on exit.""" + async with self.compiled as ctx: + assert ctx is self.compiled # noqa: S101 + assert self.compiled._closed is False # noqa: S101, SLF001 + assert self.compiled._closed is True # noqa: S101, SLF001 + + @pytest.mark.asyncio + async def test_async_context_manager_closes_on_exception(self): + """aclose() still runs when the context body raises.""" + with pytest.raises(ValueError, match="boom"): + async with self.compiled: + raise ValueError("boom") + assert self.compiled._closed is True # noqa: S101, SLF001 + def test_generate_graph(self): """Test generating graph representation.""" graph_dict = self.compiled.generate_graph() diff --git a/tests/publisher/test_console_publisher.py b/tests/publisher/test_console_publisher.py index 1efa70c..6be4d29 100644 --- a/tests/publisher/test_console_publisher.py +++ b/tests/publisher/test_console_publisher.py @@ -321,6 +321,34 @@ async def test_publish_message_format(self, mock_print): for part in expected_parts: assert part in printed_msg + @pytest.mark.asyncio + async def test_publish_uses_logger_when_configured(self, caplog): + """With use_logger=True, events go through the logger, not stdout.""" + publisher = ConsolePublisher({"use_logger": True}) + assert publisher.use_logger is True + + event = EventModel( + event=Event.GRAPH_EXECUTION, + event_type=EventType.START, + node_name="logger_node", + data={"key": "value"}, + metadata={"user": "test_user"}, + ) + + with patch('builtins.print') as mock_print: + with caplog.at_level(logging.INFO, logger="agentflow.publisher"): + await publisher.publish(event) + + # Routed to the logger, never to stdout + mock_print.assert_not_called() + + assert "logger_node.start" in caplog.text + assert "{'key': 'value'}" in caplog.text + + def test_use_logger_defaults_to_false(self): + """Default sink is stdout (use_logger is False).""" + assert ConsolePublisher().use_logger is False + @pytest.mark.asyncio @patch('builtins.print') async def test_publish_with_unicode_content(self, mock_print): diff --git a/tests/utils/test_secret_redaction.py b/tests/utils/test_secret_redaction.py new file mode 100644 index 0000000..c4e9a01 --- /dev/null +++ b/tests/utils/test_secret_redaction.py @@ -0,0 +1,100 @@ +"""Tests for secret redaction in logging utilities.""" + +import logging + +import pytest + +from agentflow.utils.logging import ( + SecretRedactionFilter, + install_secret_redaction, + mask_secrets, +) + + +class TestMaskSecrets: + """Unit tests for the mask_secrets() pure function.""" + + @pytest.mark.parametrize( + "secret", + [ + "sk-abcdefghijklmnopqrstuvwxyz0123", + "sk-proj-abcdEFGH1234ijklMNOP5678qrst", + "AIzaSyA1234567890abcdefghijklmnopqrstuvw", + "ghp_0123456789abcdefghijklmnopqrstuvwx", + "xoxb-1234567890-abcdEFGHijkl", + "AKIAIOSFODNN7EXAMPLE", + ], + ) + def test_known_key_formats_are_redacted(self, secret): + out = mask_secrets(f"using credential {secret} now") + assert secret not in out + assert "***REDACTED***" in out + + def test_bearer_token_redacted_keeps_scheme(self): + out = mask_secrets("Authorization: Bearer abc.def.ghi-123") + assert "abc.def.ghi-123" not in out + assert "Bearer ***REDACTED***" in out + + def test_key_value_redacts_value_keeps_key(self): + out = mask_secrets('{"api_key": "super-secret-value-123"}') + assert "super-secret-value-123" not in out + assert "api_key" in out + assert "***REDACTED***" in out + + def test_signed_url_query_token_redacted(self): + url = "https://storage.example.com/blob/abc?X-Amz-Signature=deadbeefcafe&expires=99" + out = mask_secrets(url) + assert "deadbeefcafe" not in out + assert "***REDACTED***" in out + # Non-secret query params are preserved + assert "expires=99" in out + + def test_plain_text_is_unchanged(self): + text = "Switching to fallback model gemini-2.5-flash (provider=google)" + assert mask_secrets(text) == text + + def test_empty_string(self): + assert mask_secrets("") == "" + + +class TestSecretRedactionFilter: + """The logging.Filter integration.""" + + def test_filter_redacts_record_message(self): + record = logging.LogRecord( + name="agentflow.test", + level=logging.INFO, + pathname=__file__, + lineno=1, + msg="key sk-abcdefghijklmnopqrstuvwxyz0123 leaked", + args=(), + exc_info=None, + ) + assert SecretRedactionFilter().filter(record) is True + assert "sk-abcdefghijklmnopqrstuvwxyz0123" not in record.getMessage() + assert "***REDACTED***" in record.getMessage() + + def test_filter_redacts_args_interpolated_message(self): + record = logging.LogRecord( + name="agentflow.test", + level=logging.INFO, + pathname=__file__, + lineno=1, + msg="token=%s", + args=("Bearer secret-token-value",), + exc_info=None, + ) + SecretRedactionFilter().filter(record) + assert "secret-token-value" not in record.getMessage() + + def test_install_attaches_filter_to_handlers(self): + log = logging.getLogger("agentflow.test.install") + handler = logging.StreamHandler() + log.addHandler(handler) + try: + redactor = install_secret_redaction("agentflow.test.install") + assert redactor in log.filters + assert redactor in handler.filters + finally: + log.removeHandler(handler) + log.filters.clear()