From 20ac9547a869f7bb867fb041f1d026ba3726c912 Mon Sep 17 00:00:00 2001 From: Artur Shiriev Date: Sat, 11 Apr 2026 13:46:56 +0300 Subject: [PATCH] refactor and fix --- AGENTS.md | 133 ------------------------------- CLAUDE.md | 42 ++++++++++ README.md | 17 +++- db_retry/connections.py | 2 +- db_retry/dsn.py | 5 +- db_retry/retry.py | 50 ++++++++---- db_retry/settings.py | 4 +- db_retry/transaction.py | 8 +- pyproject.toml | 6 +- tests/conftest.py | 2 +- tests/test_connection_factory.py | 14 ++-- tests/test_retry.py | 45 +++++++++-- 12 files changed, 147 insertions(+), 181 deletions(-) delete mode 100644 AGENTS.md create mode 100644 CLAUDE.md diff --git a/AGENTS.md b/AGENTS.md deleted file mode 100644 index c85a70e..0000000 --- a/AGENTS.md +++ /dev/null @@ -1,133 +0,0 @@ -# Project Context for Agents - -## Project Overview - -This is a Python library called `db-retry` that provides PostgreSQL and SQLAlchemy utilities, specifically focusing on: - -1. **Retry decorators** for handling database connection issues and serialization errors -2. **Connection factory builders** for managing PostgreSQL connections with multiple hosts -3. **DSN (Data Source Name) utilities** for parsing and manipulating database connection strings -4. **Transaction helpers** for managing SQLAlchemy async sessions - -The library is built with modern Python practices (3.13+) and uses type hints extensively. It's designed to work with PostgreSQL databases using the asyncpg driver and SQLAlchemy's asyncio extension. - -## Key Technologies - -- **Python 3.13+** -- **SQLAlchemy** with asyncio extension -- **asyncpg** PostgreSQL driver -- **tenacity** for retry logic -- **uv** for package management and building -- **Docker** for development and testing environments -- **pytest** for testing -- **ruff** and **ty** for linting and type checking - -## Project Structure - -``` -db_retry/ -├── __init__.py # Exports all public APIs -├── connections.py # Connection factory builders -├── dsn.py # DSN parsing and manipulation utilities -├── retry.py # Retry decorators for database operations -├── settings.py # Configuration settings -├── transaction.py # Transaction helper classes -└── py.typed # Marker file for type checking -tests/ -├── test_connection_factory.py -├── test_dsn.py -├── test_retry.py -├── test_transaction.py -├── conftest.py # pytest configuration -└── __init__.py -``` - -## Main Components - -### Retry Decorators (`retry.py`) -Provides `@postgres_retry` decorator that automatically retries database operations when encountering: -- PostgreSQL connection errors -- Serialization errors - -The retry logic uses exponential backoff with jitter and is configurable via environment variables. - -### Connection Factory (`connections.py`) -Provides `build_connection_factory()` function that creates connection factories for PostgreSQL databases with support for: -- Multiple fallback hosts -- Randomized host selection -- Target session attributes (read-write vs standby) - -### DSN Utilities (`dsn.py`) -Provides functions for: -- `build_db_dsn()`: Parse and modify DSN strings, replacing database names and setting target session attributes -- `is_dsn_multihost()`: Check if a DSN contains multiple hosts - -### Transaction Helpers (`transaction.py`) -Provides `Transaction` class that wraps SQLAlchemy AsyncSession with automatic transaction management. - -## Building and Running - -### Development Environment Setup -```bash -# Install dependencies -just install - -# Run tests -just test - -# Run linting and type checking -just lint - -# Run all checks (default) -just -``` - -### Docker-based Development -```bash -# Run tests in Docker -just test - -# Run shell in Docker container -just sh -``` - -### Testing -Tests are written using pytest and can be run with: -```bash -# Run all tests -just test - -# Run specific test file -just test tests/test_retry.py - -# Run tests with coverage -just test --cov=. -``` - -## Configuration - -The library can be configured using environment variables: - -- `DB_UTILS_RETRIES_NUMBER`: Number of retry attempts (default: 3) - -## Development Conventions - -1. **Type Safety**: Strict ty checking is enforced -2. **Code Style**: Ruff is used for linting with specific rules configured -3. **Testing**: All functionality should have corresponding tests -4. **Async/Await**: All database operations are asynchronous -5. **Documentation**: Public APIs should be documented with docstrings - -## Common Tasks - -### Adding New Features -1. Implement the feature in the appropriate module -2. Add tests in the corresponding test file -3. Update exports in `__init__.py` if adding public APIs -4. Run `just` to ensure all checks pass - -### Modifying Retry Logic -The retry behavior is defined in `retry.py` and uses the tenacity library. Modify the `_retry_handler` function to change which exceptions trigger retries. - -### Working with Connections -Connection handling is in `connections.py`. The `build_connection_factory` function handles connecting to PostgreSQL with support for multiple hosts and fallback mechanisms. diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 0000000..12599b8 --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,42 @@ +# CLAUDE.md + +This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. + +## Commands + +```bash +just install # uv lock --upgrade && uv sync +just lint # ruff format + eof-fixer (auto-fix) +just lint-ci # ruff check only (no fixes) +just build # build docker image +just test # run pytest inside docker (requires postgres) +just # install lint build test (full pipeline) +``` + +To run tests locally without Docker, set `DB_DSN` and run: +```bash +uv run pytest +uv run pytest tests/test_retry.py::test_postgres_retry # single test +``` + +The CI `DB_DSN` format: `postgresql+asyncpg://postgres:postgres@localhost:5432/postgres` + +## Architecture + +The package (`db_retry/`) exposes five public symbols via `__init__.py`: + +- **`postgres_retry`** (`retry.py`) — async tenacity decorator that retries on `asyncpg.SerializationError` (40001) and `asyncpg.PostgresConnectionError` (08000/08003). Walks the exception chain via `DBAPIError.orig.__cause__` to distinguish retriable errors from others like `StatementCompletionUnknownError` (40002). Supports bare `@postgres_retry` (uses default) and `@postgres_retry(retries=N)` for per-callsite override. + +- **`build_connection_factory`** (`connections.py`) — returns an async callable suitable for SQLAlchemy's `async_engine_from_config`. Handles multi-host DSNs by randomizing host order (load balancing) and attempting all hosts on timeout before raising `TargetServerAttributeNotMatched`. + +- **`build_db_dsn`** / **`is_dsn_multihost`** (`dsn.py`) — parse and construct `sqlalchemy.URL` objects. Multi-host DSNs encode additional hosts in query parameters. Existing `target_session_attrs` in the DSN is preserved (not overwritten). + +- **`Transaction`** (`transaction.py`) — frozen dataclass context manager wrapping `AsyncSession`. Supports optional isolation level (e.g., `"SERIALIZABLE"`). Auto-rolls back on `__aexit__` if the session is still in a transaction (i.e. no explicit `.commit()` or `.rollback()` was called). Uses `typing.Self` (no `typing_extensions` dependency). + +- **`settings.py`** — exposes `get_retries_number()` which reads `DB_RETRY_RETRIES_NUMBER` env var at call time (default: 3), allowing `monkeypatch.setenv` to work in tests. + +## Linting / Type Checking + +Ruff is configured with `select = ["ALL"]` plus specific exclusions. Line length is 120. Run `just lint` before committing. + +Type checking uses `ty` (not mypy). In code, use `ty: ignore` for suppression comments (not `type: ignore`). diff --git a/README.md b/README.md index cc0300b..192d9b4 100644 --- a/README.md +++ b/README.md @@ -45,7 +45,7 @@ class User(DeclarativeBase): email: Mapped[str] = mapped_column(sa.String(), index=True) -# Apply retry logic to ORM operations +# Apply retry logic to ORM operations (uses DB_RETRY_RETRIES_NUMBER, default 3) @postgres_retry async def get_user_by_email(session: AsyncSession, email: str) -> User: return await session.scalar( @@ -65,6 +65,14 @@ async def main(): asyncio.run(main()) ``` +Per-callsite retry count override: + +```python +@postgres_retry(retries=5) +async def create_order(session: AsyncSession, order: Order) -> Order: + ... +``` + ### 2. High Availability Database Connections Set up resilient database connections with multiple fallback hosts: @@ -173,13 +181,14 @@ The library can be configured using environment variables: Example: ```bash -export DB_UTILS_RETRIES_NUMBER=5 +export DB_RETRY_RETRIES_NUMBER=5 ``` ## API Reference ### Retry Decorator -- `@postgres_retry` - Decorator for async functions that should retry on database errors +- `@postgres_retry` - Decorator for async functions that should retry on database errors (uses `DB_RETRY_RETRIES_NUMBER`) +- `@postgres_retry(retries=N)` - Override retry count per callsite ### Connection Utilities - `build_connection_factory(url, timeout)` - Creates a connection factory for multi-host setups @@ -187,7 +196,7 @@ export DB_UTILS_RETRIES_NUMBER=5 - `is_dsn_multihost(db_dsn)` - Checks if a DSN contains multiple hosts ### Transaction Helper -- `Transaction(session, isolation_level=None)` - Context manager for simplified transaction handling +- `Transaction(session, isolation_level=None)` - Context manager for transaction handling; auto-rolls back on exit if no explicit `.commit()` or `.rollback()` was called ## Requirements diff --git a/db_retry/connections.py b/db_retry/connections.py index 4198b43..953dd1e 100644 --- a/db_retry/connections.py +++ b/db_retry/connections.py @@ -32,7 +32,7 @@ def build_connection_factory( hosts: str | list[str] ports: int | list[int] | None if isinstance(raw_hosts, list) and isinstance(raw_ports, list): - hosts_and_ports = list(zip(raw_hosts, raw_ports, strict=True)) + hosts_and_ports = typing.cast("list[tuple[str, int]]", list(zip(raw_hosts, raw_ports, strict=True))) random.shuffle(hosts_and_ports) hosts = list(map(itemgetter(0), hosts_and_ports)) ports = list(map(itemgetter(1), hosts_and_ports)) diff --git a/db_retry/dsn.py b/db_retry/dsn.py index 273b612..ecdd45f 100644 --- a/db_retry/dsn.py +++ b/db_retry/dsn.py @@ -21,10 +21,7 @@ def build_db_dsn( return parsed_db_dsn.set( database=database_name, drivername=drivername, - query=db_dsn_query - | { - "target_session_attrs": "prefer-standby" if use_replica else "read-write", - }, + query={"target_session_attrs": "prefer-standby" if use_replica else "read-write"} | db_dsn_query, ) diff --git a/db_retry/retry.py b/db_retry/retry.py index 0e46db7..eb9cee6 100644 --- a/db_retry/retry.py +++ b/db_retry/retry.py @@ -16,7 +16,8 @@ def _retry_handler(exception: BaseException) -> bool: if ( isinstance(exception, DBAPIError) and hasattr(exception, "orig") - and isinstance(exception.orig.__cause__, (asyncpg.SerializationError, asyncpg.PostgresConnectionError)) # type: ignore[union-attr] + and exception.orig is not None + and isinstance(exception.orig.__cause__, (asyncpg.SerializationError, asyncpg.PostgresConnectionError)) ): logger.debug("postgres_retry, retrying") return True @@ -25,18 +26,37 @@ def _retry_handler(exception: BaseException) -> bool: return False +type _Func[**P, T] = typing.Callable[P, typing.Coroutine[None, None, T]] +type _Decorator[**P, T] = typing.Callable[[_Func[P, T]], _Func[P, T]] + + +@typing.overload +def postgres_retry[**P, T](func: _Func[P, T], *, retries: int | None = ...) -> _Func[P, T]: ... + + +@typing.overload +def postgres_retry[**P, T](func: None = ..., *, retries: int | None = ...) -> _Decorator[P, T]: ... + + def postgres_retry[**P, T]( - func: typing.Callable[P, typing.Coroutine[None, None, T]], -) -> typing.Callable[P, typing.Coroutine[None, None, T]]: - @tenacity.retry( - stop=tenacity.stop_after_attempt(settings.DB_RETRY_RETRIES_NUMBER), - wait=tenacity.wait_exponential_jitter(), - retry=tenacity.retry_if_exception(_retry_handler), - reraise=True, - before=tenacity.before_log(logger, logging.DEBUG), # ty: ignore[invalid-argument-type] - ) - @functools.wraps(func) - async def wrapped_method(*args: P.args, **kwargs: P.kwargs) -> T: - return await func(*args, **kwargs) - - return wrapped_method + func: _Func[P, T] | None = None, + *, + retries: int | None = None, +) -> _Func[P, T] | _Decorator[P, T]: + def decorator(f: _Func[P, T]) -> _Func[P, T]: + @functools.wraps(f) + async def wrapped_method(*args: P.args, **kwargs: P.kwargs) -> T: + retryer = tenacity.AsyncRetrying( + stop=tenacity.stop_after_attempt(retries if retries is not None else settings.get_retries_number()), + wait=tenacity.wait_exponential_jitter(), + retry=tenacity.retry_if_exception(_retry_handler), + reraise=True, + before=tenacity.before_log(logger, logging.DEBUG), + ) + return await retryer(f, *args, **kwargs) + + return wrapped_method + + if func is not None: + return decorator(func) + return decorator diff --git a/db_retry/settings.py b/db_retry/settings.py index c04d63e..d654c37 100644 --- a/db_retry/settings.py +++ b/db_retry/settings.py @@ -1,5 +1,5 @@ import os -import typing -DB_RETRY_RETRIES_NUMBER: typing.Final = int(os.getenv("DB_RETRY_RETRIES_NUMBER", "3")) +def get_retries_number() -> int: + return int(os.getenv("DB_RETRY_RETRIES_NUMBER", "3")) diff --git a/db_retry/transaction.py b/db_retry/transaction.py index c6f422d..efcaa89 100644 --- a/db_retry/transaction.py +++ b/db_retry/transaction.py @@ -1,6 +1,6 @@ import dataclasses +import typing -import typing_extensions from sqlalchemy.engine.interfaces import IsolationLevel from sqlalchemy.ext import asyncio as sa_async @@ -10,7 +10,7 @@ class Transaction: session: sa_async.AsyncSession isolation_level: IsolationLevel | None = None - async def __aenter__(self) -> typing_extensions.Self: + async def __aenter__(self) -> typing.Self: if self.isolation_level: await self.session.connection(execution_options={"isolation_level": self.isolation_level}) @@ -18,7 +18,9 @@ async def __aenter__(self) -> typing_extensions.Self: await self.session.begin() return self - async def __aexit__(self, *args: object, **kwargs: object) -> None: + async def __aexit__(self, exc_type: object, *args: object, **kwargs: object) -> None: + if self.session.in_transaction(): + await self.session.rollback() await self.session.close() async def commit(self) -> None: diff --git a/pyproject.toml b/pyproject.toml index 9f32d2b..12989f3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -2,7 +2,7 @@ name = "db-retry" description = "PostgreSQL and SQLAlchemy Tools" authors = [ - { name = "community-of-python" }, + { name = "Artur Shiriev", email = "me@shiriev.ru" }, ] readme = "README.md" requires-python = ">=3.13,<4" @@ -18,7 +18,7 @@ classifiers = [ "Typing :: Typed", "Topic :: Software Development :: Libraries", ] -version = "0" +version = "0.0.0" dependencies = [ "tenacity", "sqlalchemy[asyncio]", @@ -26,7 +26,7 @@ dependencies = [ ] [project.urls] -repository = "https://github.com/modern-python/sa-utils" +repository = "https://github.com/modern-python/db-retry" [dependency-groups] dev = [ diff --git a/tests/conftest.py b/tests/conftest.py index 534a99d..fbf341d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -7,7 +7,7 @@ @pytest.fixture async def async_engine() -> typing.AsyncIterator[sa_async.AsyncEngine]: - engine: typing.Final = sa_async.create_async_engine(url=os.getenv("DB_DSN", ""), echo=True, echo_pool=True) + engine: typing.Final = sa_async.create_async_engine(url=os.environ["DB_DSN"], echo=True, echo_pool=True) try: yield engine finally: diff --git a/tests/test_connection_factory.py b/tests/test_connection_factory.py index 0886f7a..507f062 100644 --- a/tests/test_connection_factory.py +++ b/tests/test_connection_factory.py @@ -55,15 +55,11 @@ async def test_connection_factory_failure_several_hosts( async def test_connection_factory_failure_and_success(monkeypatch: pytest.MonkeyPatch) -> None: - monkeypatch.setattr("asyncpg.connect", mock.AsyncMock(side_effect=(TimeoutError, ""))) + mock_connection: typing.Final = mock.AsyncMock(spec=asyncpg.Connection) + monkeypatch.setattr("asyncpg.connect", mock.AsyncMock(side_effect=(TimeoutError, mock_connection))) url: typing.Final = sqlalchemy.make_url( "postgresql+asyncpg://user:password@/database?host=host1:5432&host=host2:5432" ) - engine: typing.Final = sa_async.create_async_engine( - url=url, echo=True, echo_pool=True, async_creator=build_connection_factory(url=url, timeout=1.0) - ) - try: - with pytest.raises(AttributeError): - await engine.connect().__aenter__() - finally: - await engine.dispose() + factory: typing.Final = build_connection_factory(url=url, timeout=1.0) + result = await factory() + assert result is mock_connection diff --git a/tests/test_retry.py b/tests/test_retry.py index 44500ee..7abe751 100644 --- a/tests/test_retry.py +++ b/tests/test_retry.py @@ -7,15 +7,15 @@ @pytest.mark.parametrize( - "error_code", + ("error_code", "expected_calls"), [ - "08000", # PostgresConnectionError - backoff triggered - "08003", # subclass of PostgresConnectionError - backoff triggered - "40001", # SerializationError - backoff triggered - "40002", # StatementCompletionUnknownError - backoff not triggered + ("08000", 2), # PostgresConnectionError - backoff triggered, 1 retry + ("08003", 2), # subclass of PostgresConnectionError - backoff triggered, 1 retry + ("40001", 2), # SerializationError - backoff triggered, 1 retry + ("40002", 1), # StatementCompletionUnknownError - backoff not triggered ], ) -async def test_postgres_retry(async_engine: sa_async.AsyncEngine, error_code: str) -> None: +async def test_postgres_retry(async_engine: sa_async.AsyncEngine, error_code: str, expected_calls: int) -> None: async with async_engine.connect() as connection: await connection.execute( sqlalchemy.text( @@ -30,9 +30,42 @@ async def test_postgres_retry(async_engine: sa_async.AsyncEngine, error_code: st ), ) + call_count = 0 + @postgres_retry async def raise_error() -> None: + nonlocal call_count + call_count += 1 + await connection.execute(sqlalchemy.text("SELECT raise_error()")) + + with pytest.raises(DBAPIError): + await raise_error() + + assert call_count == expected_calls + + +async def test_postgres_retry_with_retries(async_engine: sa_async.AsyncEngine) -> None: + async with async_engine.connect() as connection: + await connection.execute( + sqlalchemy.text(""" + CREATE OR REPLACE FUNCTION raise_error() + RETURNS VOID AS $$ + BEGIN + RAISE SQLSTATE '40001'; + END; + $$ LANGUAGE plpgsql; + """), + ) + + call_count = 0 + + @postgres_retry(retries=1) + async def raise_error() -> None: + nonlocal call_count + call_count += 1 await connection.execute(sqlalchemy.text("SELECT raise_error()")) with pytest.raises(DBAPIError): await raise_error() + + assert call_count == 1