Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 0 additions & 133 deletions AGENTS.md

This file was deleted.

42 changes: 42 additions & 0 deletions CLAUDE.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# CLAUDE.md

This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.

## Commands

```bash
just install # uv lock --upgrade && uv sync
just lint # ruff format + eof-fixer (auto-fix)
just lint-ci # ruff check only (no fixes)
just build # build docker image
just test # run pytest inside docker (requires postgres)
just # install lint build test (full pipeline)
```

To run tests locally without Docker, set `DB_DSN` and run:
```bash
uv run pytest
uv run pytest tests/test_retry.py::test_postgres_retry # single test
```

The CI `DB_DSN` format: `postgresql+asyncpg://postgres:postgres@localhost:5432/postgres`

## Architecture

The package (`db_retry/`) exposes five public symbols via `__init__.py`:

- **`postgres_retry`** (`retry.py`) — async tenacity decorator that retries on `asyncpg.SerializationError` (40001) and `asyncpg.PostgresConnectionError` (08000/08003). Walks the exception chain via `DBAPIError.orig.__cause__` to distinguish retriable errors from others like `StatementCompletionUnknownError` (40002). Supports bare `@postgres_retry` (uses default) and `@postgres_retry(retries=N)` for per-callsite override.

- **`build_connection_factory`** (`connections.py`) — returns an async callable suitable for SQLAlchemy's `async_engine_from_config`. Handles multi-host DSNs by randomizing host order (load balancing) and attempting all hosts on timeout before raising `TargetServerAttributeNotMatched`.

- **`build_db_dsn`** / **`is_dsn_multihost`** (`dsn.py`) — parse and construct `sqlalchemy.URL` objects. Multi-host DSNs encode additional hosts in query parameters. Existing `target_session_attrs` in the DSN is preserved (not overwritten).

- **`Transaction`** (`transaction.py`) — frozen dataclass context manager wrapping `AsyncSession`. Supports optional isolation level (e.g., `"SERIALIZABLE"`). Auto-rolls back on `__aexit__` if the session is still in a transaction (i.e. no explicit `.commit()` or `.rollback()` was called). Uses `typing.Self` (no `typing_extensions` dependency).

- **`settings.py`** — exposes `get_retries_number()` which reads `DB_RETRY_RETRIES_NUMBER` env var at call time (default: 3), allowing `monkeypatch.setenv` to work in tests.

## Linting / Type Checking

Ruff is configured with `select = ["ALL"]` plus specific exclusions. Line length is 120. Run `just lint` before committing.

Type checking uses `ty` (not mypy). In code, use `ty: ignore` for suppression comments (not `type: ignore`).
17 changes: 13 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ class User(DeclarativeBase):
email: Mapped[str] = mapped_column(sa.String(), index=True)


# Apply retry logic to ORM operations
# Apply retry logic to ORM operations (uses DB_RETRY_RETRIES_NUMBER, default 3)
@postgres_retry
async def get_user_by_email(session: AsyncSession, email: str) -> User:
return await session.scalar(
Expand All @@ -65,6 +65,14 @@ async def main():
asyncio.run(main())
```

Per-callsite retry count override:

```python
@postgres_retry(retries=5)
async def create_order(session: AsyncSession, order: Order) -> Order:
...
```

### 2. High Availability Database Connections

Set up resilient database connections with multiple fallback hosts:
Expand Down Expand Up @@ -173,21 +181,22 @@ The library can be configured using environment variables:

Example:
```bash
export DB_UTILS_RETRIES_NUMBER=5
export DB_RETRY_RETRIES_NUMBER=5
```

## API Reference

### Retry Decorator
- `@postgres_retry` - Decorator for async functions that should retry on database errors
- `@postgres_retry` - Decorator for async functions that should retry on database errors (uses `DB_RETRY_RETRIES_NUMBER`)
- `@postgres_retry(retries=N)` - Override retry count per callsite

### Connection Utilities
- `build_connection_factory(url, timeout)` - Creates a connection factory for multi-host setups
- `build_db_dsn(db_dsn, database_name, use_replica=False, drivername="postgresql")` - Builds a DSN with specified parameters
- `is_dsn_multihost(db_dsn)` - Checks if a DSN contains multiple hosts

### Transaction Helper
- `Transaction(session, isolation_level=None)` - Context manager for simplified transaction handling
- `Transaction(session, isolation_level=None)` - Context manager for transaction handling; auto-rolls back on exit if no explicit `.commit()` or `.rollback()` was called

## Requirements

Expand Down
2 changes: 1 addition & 1 deletion db_retry/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def build_connection_factory(
hosts: str | list[str]
ports: int | list[int] | None
if isinstance(raw_hosts, list) and isinstance(raw_ports, list):
hosts_and_ports = list(zip(raw_hosts, raw_ports, strict=True))
hosts_and_ports = typing.cast("list[tuple[str, int]]", list(zip(raw_hosts, raw_ports, strict=True)))
random.shuffle(hosts_and_ports)
hosts = list(map(itemgetter(0), hosts_and_ports))
ports = list(map(itemgetter(1), hosts_and_ports))
Expand Down
5 changes: 1 addition & 4 deletions db_retry/dsn.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,7 @@ def build_db_dsn(
return parsed_db_dsn.set(
database=database_name,
drivername=drivername,
query=db_dsn_query
| {
"target_session_attrs": "prefer-standby" if use_replica else "read-write",
},
query={"target_session_attrs": "prefer-standby" if use_replica else "read-write"} | db_dsn_query,
)


Expand Down
50 changes: 35 additions & 15 deletions db_retry/retry.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ def _retry_handler(exception: BaseException) -> bool:
if (
isinstance(exception, DBAPIError)
and hasattr(exception, "orig")
and isinstance(exception.orig.__cause__, (asyncpg.SerializationError, asyncpg.PostgresConnectionError)) # type: ignore[union-attr]
and exception.orig is not None
and isinstance(exception.orig.__cause__, (asyncpg.SerializationError, asyncpg.PostgresConnectionError))
):
logger.debug("postgres_retry, retrying")
return True
Expand All @@ -25,18 +26,37 @@ def _retry_handler(exception: BaseException) -> bool:
return False


type _Func[**P, T] = typing.Callable[P, typing.Coroutine[None, None, T]]
type _Decorator[**P, T] = typing.Callable[[_Func[P, T]], _Func[P, T]]


@typing.overload
def postgres_retry[**P, T](func: _Func[P, T], *, retries: int | None = ...) -> _Func[P, T]: ...


@typing.overload
def postgres_retry[**P, T](func: None = ..., *, retries: int | None = ...) -> _Decorator[P, T]: ...


def postgres_retry[**P, T](
func: typing.Callable[P, typing.Coroutine[None, None, T]],
) -> typing.Callable[P, typing.Coroutine[None, None, T]]:
@tenacity.retry(
stop=tenacity.stop_after_attempt(settings.DB_RETRY_RETRIES_NUMBER),
wait=tenacity.wait_exponential_jitter(),
retry=tenacity.retry_if_exception(_retry_handler),
reraise=True,
before=tenacity.before_log(logger, logging.DEBUG), # ty: ignore[invalid-argument-type]
)
@functools.wraps(func)
async def wrapped_method(*args: P.args, **kwargs: P.kwargs) -> T:
return await func(*args, **kwargs)

return wrapped_method
func: _Func[P, T] | None = None,
*,
retries: int | None = None,
) -> _Func[P, T] | _Decorator[P, T]:
def decorator(f: _Func[P, T]) -> _Func[P, T]:
@functools.wraps(f)
async def wrapped_method(*args: P.args, **kwargs: P.kwargs) -> T:
retryer = tenacity.AsyncRetrying(
stop=tenacity.stop_after_attempt(retries if retries is not None else settings.get_retries_number()),
wait=tenacity.wait_exponential_jitter(),
retry=tenacity.retry_if_exception(_retry_handler),
reraise=True,
before=tenacity.before_log(logger, logging.DEBUG),
)
return await retryer(f, *args, **kwargs)

return wrapped_method

if func is not None:
return decorator(func)
return decorator
4 changes: 2 additions & 2 deletions db_retry/settings.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import os
import typing


DB_RETRY_RETRIES_NUMBER: typing.Final = int(os.getenv("DB_RETRY_RETRIES_NUMBER", "3"))
def get_retries_number() -> int:
return int(os.getenv("DB_RETRY_RETRIES_NUMBER", "3"))
8 changes: 5 additions & 3 deletions db_retry/transaction.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import dataclasses
import typing

import typing_extensions
from sqlalchemy.engine.interfaces import IsolationLevel
from sqlalchemy.ext import asyncio as sa_async

Expand All @@ -10,15 +10,17 @@ class Transaction:
session: sa_async.AsyncSession
isolation_level: IsolationLevel | None = None

async def __aenter__(self) -> typing_extensions.Self:
async def __aenter__(self) -> typing.Self:
if self.isolation_level:
await self.session.connection(execution_options={"isolation_level": self.isolation_level})

if not self.session.in_transaction():
await self.session.begin()
return self

async def __aexit__(self, *args: object, **kwargs: object) -> None:
async def __aexit__(self, exc_type: object, *args: object, **kwargs: object) -> None:
if self.session.in_transaction():
await self.session.rollback()
await self.session.close()

async def commit(self) -> None:
Expand Down
6 changes: 3 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
name = "db-retry"
description = "PostgreSQL and SQLAlchemy Tools"
authors = [
{ name = "community-of-python" },
{ name = "Artur Shiriev", email = "me@shiriev.ru" },
]
readme = "README.md"
requires-python = ">=3.13,<4"
Expand All @@ -18,15 +18,15 @@ classifiers = [
"Typing :: Typed",
"Topic :: Software Development :: Libraries",
]
version = "0"
version = "0.0.0"
dependencies = [
"tenacity",
"sqlalchemy[asyncio]",
"asyncpg",
]

[project.urls]
repository = "https://github.com/modern-python/sa-utils"
repository = "https://github.com/modern-python/db-retry"

[dependency-groups]
dev = [
Expand Down
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

@pytest.fixture
async def async_engine() -> typing.AsyncIterator[sa_async.AsyncEngine]:
engine: typing.Final = sa_async.create_async_engine(url=os.getenv("DB_DSN", ""), echo=True, echo_pool=True)
engine: typing.Final = sa_async.create_async_engine(url=os.environ["DB_DSN"], echo=True, echo_pool=True)
try:
yield engine
finally:
Expand Down
Loading
Loading