Skip to content

Commit c134ba1

Browse files
authored
Merge pull request #9 from modern-python/refactor
refactor and fix
2 parents 33a9666 + 20ac954 commit c134ba1

12 files changed

Lines changed: 147 additions & 181 deletions

AGENTS.md

Lines changed: 0 additions & 133 deletions
This file was deleted.

CLAUDE.md

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
# CLAUDE.md
2+
3+
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
4+
5+
## Commands
6+
7+
```bash
8+
just install # uv lock --upgrade && uv sync
9+
just lint # ruff format + eof-fixer (auto-fix)
10+
just lint-ci # ruff check only (no fixes)
11+
just build # build docker image
12+
just test # run pytest inside docker (requires postgres)
13+
just # install lint build test (full pipeline)
14+
```
15+
16+
To run tests locally without Docker, set `DB_DSN` and run:
17+
```bash
18+
uv run pytest
19+
uv run pytest tests/test_retry.py::test_postgres_retry # single test
20+
```
21+
22+
The CI `DB_DSN` format: `postgresql+asyncpg://postgres:postgres@localhost:5432/postgres`
23+
24+
## Architecture
25+
26+
The package (`db_retry/`) exposes five public symbols via `__init__.py`:
27+
28+
- **`postgres_retry`** (`retry.py`) — async tenacity decorator that retries on `asyncpg.SerializationError` (40001) and `asyncpg.PostgresConnectionError` (08000/08003). Walks the exception chain via `DBAPIError.orig.__cause__` to distinguish retriable errors from others like `StatementCompletionUnknownError` (40002). Supports bare `@postgres_retry` (uses default) and `@postgres_retry(retries=N)` for per-callsite override.
29+
30+
- **`build_connection_factory`** (`connections.py`) — returns an async callable suitable for SQLAlchemy's `async_engine_from_config`. Handles multi-host DSNs by randomizing host order (load balancing) and attempting all hosts on timeout before raising `TargetServerAttributeNotMatched`.
31+
32+
- **`build_db_dsn`** / **`is_dsn_multihost`** (`dsn.py`) — parse and construct `sqlalchemy.URL` objects. Multi-host DSNs encode additional hosts in query parameters. Existing `target_session_attrs` in the DSN is preserved (not overwritten).
33+
34+
- **`Transaction`** (`transaction.py`) — frozen dataclass context manager wrapping `AsyncSession`. Supports optional isolation level (e.g., `"SERIALIZABLE"`). Auto-rolls back on `__aexit__` if the session is still in a transaction (i.e. no explicit `.commit()` or `.rollback()` was called). Uses `typing.Self` (no `typing_extensions` dependency).
35+
36+
- **`settings.py`** — exposes `get_retries_number()` which reads `DB_RETRY_RETRIES_NUMBER` env var at call time (default: 3), allowing `monkeypatch.setenv` to work in tests.
37+
38+
## Linting / Type Checking
39+
40+
Ruff is configured with `select = ["ALL"]` plus specific exclusions. Line length is 120. Run `just lint` before committing.
41+
42+
Type checking uses `ty` (not mypy). In code, use `ty: ignore` for suppression comments (not `type: ignore`).

README.md

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ class User(DeclarativeBase):
4545
email: Mapped[str] = mapped_column(sa.String(), index=True)
4646

4747

48-
# Apply retry logic to ORM operations
48+
# Apply retry logic to ORM operations (uses DB_RETRY_RETRIES_NUMBER, default 3)
4949
@postgres_retry
5050
async def get_user_by_email(session: AsyncSession, email: str) -> User:
5151
return await session.scalar(
@@ -65,6 +65,14 @@ async def main():
6565
asyncio.run(main())
6666
```
6767

68+
Per-callsite retry count override:
69+
70+
```python
71+
@postgres_retry(retries=5)
72+
async def create_order(session: AsyncSession, order: Order) -> Order:
73+
...
74+
```
75+
6876
### 2. High Availability Database Connections
6977

7078
Set up resilient database connections with multiple fallback hosts:
@@ -173,21 +181,22 @@ The library can be configured using environment variables:
173181

174182
Example:
175183
```bash
176-
export DB_UTILS_RETRIES_NUMBER=5
184+
export DB_RETRY_RETRIES_NUMBER=5
177185
```
178186

179187
## API Reference
180188

181189
### Retry Decorator
182-
- `@postgres_retry` - Decorator for async functions that should retry on database errors
190+
- `@postgres_retry` - Decorator for async functions that should retry on database errors (uses `DB_RETRY_RETRIES_NUMBER`)
191+
- `@postgres_retry(retries=N)` - Override retry count per callsite
183192

184193
### Connection Utilities
185194
- `build_connection_factory(url, timeout)` - Creates a connection factory for multi-host setups
186195
- `build_db_dsn(db_dsn, database_name, use_replica=False, drivername="postgresql")` - Builds a DSN with specified parameters
187196
- `is_dsn_multihost(db_dsn)` - Checks if a DSN contains multiple hosts
188197

189198
### Transaction Helper
190-
- `Transaction(session, isolation_level=None)` - Context manager for simplified transaction handling
199+
- `Transaction(session, isolation_level=None)` - Context manager for transaction handling; auto-rolls back on exit if no explicit `.commit()` or `.rollback()` was called
191200

192201
## Requirements
193202

db_retry/connections.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def build_connection_factory(
3232
hosts: str | list[str]
3333
ports: int | list[int] | None
3434
if isinstance(raw_hosts, list) and isinstance(raw_ports, list):
35-
hosts_and_ports = list(zip(raw_hosts, raw_ports, strict=True))
35+
hosts_and_ports = typing.cast("list[tuple[str, int]]", list(zip(raw_hosts, raw_ports, strict=True)))
3636
random.shuffle(hosts_and_ports)
3737
hosts = list(map(itemgetter(0), hosts_and_ports))
3838
ports = list(map(itemgetter(1), hosts_and_ports))

db_retry/dsn.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,7 @@ def build_db_dsn(
2121
return parsed_db_dsn.set(
2222
database=database_name,
2323
drivername=drivername,
24-
query=db_dsn_query
25-
| {
26-
"target_session_attrs": "prefer-standby" if use_replica else "read-write",
27-
},
24+
query={"target_session_attrs": "prefer-standby" if use_replica else "read-write"} | db_dsn_query,
2825
)
2926

3027

db_retry/retry.py

Lines changed: 35 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@ def _retry_handler(exception: BaseException) -> bool:
1616
if (
1717
isinstance(exception, DBAPIError)
1818
and hasattr(exception, "orig")
19-
and isinstance(exception.orig.__cause__, (asyncpg.SerializationError, asyncpg.PostgresConnectionError)) # type: ignore[union-attr]
19+
and exception.orig is not None
20+
and isinstance(exception.orig.__cause__, (asyncpg.SerializationError, asyncpg.PostgresConnectionError))
2021
):
2122
logger.debug("postgres_retry, retrying")
2223
return True
@@ -25,18 +26,37 @@ def _retry_handler(exception: BaseException) -> bool:
2526
return False
2627

2728

29+
type _Func[**P, T] = typing.Callable[P, typing.Coroutine[None, None, T]]
30+
type _Decorator[**P, T] = typing.Callable[[_Func[P, T]], _Func[P, T]]
31+
32+
33+
@typing.overload
34+
def postgres_retry[**P, T](func: _Func[P, T], *, retries: int | None = ...) -> _Func[P, T]: ...
35+
36+
37+
@typing.overload
38+
def postgres_retry[**P, T](func: None = ..., *, retries: int | None = ...) -> _Decorator[P, T]: ...
39+
40+
2841
def postgres_retry[**P, T](
29-
func: typing.Callable[P, typing.Coroutine[None, None, T]],
30-
) -> typing.Callable[P, typing.Coroutine[None, None, T]]:
31-
@tenacity.retry(
32-
stop=tenacity.stop_after_attempt(settings.DB_RETRY_RETRIES_NUMBER),
33-
wait=tenacity.wait_exponential_jitter(),
34-
retry=tenacity.retry_if_exception(_retry_handler),
35-
reraise=True,
36-
before=tenacity.before_log(logger, logging.DEBUG), # ty: ignore[invalid-argument-type]
37-
)
38-
@functools.wraps(func)
39-
async def wrapped_method(*args: P.args, **kwargs: P.kwargs) -> T:
40-
return await func(*args, **kwargs)
41-
42-
return wrapped_method
42+
func: _Func[P, T] | None = None,
43+
*,
44+
retries: int | None = None,
45+
) -> _Func[P, T] | _Decorator[P, T]:
46+
def decorator(f: _Func[P, T]) -> _Func[P, T]:
47+
@functools.wraps(f)
48+
async def wrapped_method(*args: P.args, **kwargs: P.kwargs) -> T:
49+
retryer = tenacity.AsyncRetrying(
50+
stop=tenacity.stop_after_attempt(retries if retries is not None else settings.get_retries_number()),
51+
wait=tenacity.wait_exponential_jitter(),
52+
retry=tenacity.retry_if_exception(_retry_handler),
53+
reraise=True,
54+
before=tenacity.before_log(logger, logging.DEBUG),
55+
)
56+
return await retryer(f, *args, **kwargs)
57+
58+
return wrapped_method
59+
60+
if func is not None:
61+
return decorator(func)
62+
return decorator

db_retry/settings.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import os
2-
import typing
32

43

5-
DB_RETRY_RETRIES_NUMBER: typing.Final = int(os.getenv("DB_RETRY_RETRIES_NUMBER", "3"))
4+
def get_retries_number() -> int:
5+
return int(os.getenv("DB_RETRY_RETRIES_NUMBER", "3"))

db_retry/transaction.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import dataclasses
2+
import typing
23

3-
import typing_extensions
44
from sqlalchemy.engine.interfaces import IsolationLevel
55
from sqlalchemy.ext import asyncio as sa_async
66

@@ -10,15 +10,17 @@ class Transaction:
1010
session: sa_async.AsyncSession
1111
isolation_level: IsolationLevel | None = None
1212

13-
async def __aenter__(self) -> typing_extensions.Self:
13+
async def __aenter__(self) -> typing.Self:
1414
if self.isolation_level:
1515
await self.session.connection(execution_options={"isolation_level": self.isolation_level})
1616

1717
if not self.session.in_transaction():
1818
await self.session.begin()
1919
return self
2020

21-
async def __aexit__(self, *args: object, **kwargs: object) -> None:
21+
async def __aexit__(self, exc_type: object, *args: object, **kwargs: object) -> None:
22+
if self.session.in_transaction():
23+
await self.session.rollback()
2224
await self.session.close()
2325

2426
async def commit(self) -> None:

pyproject.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
name = "db-retry"
33
description = "PostgreSQL and SQLAlchemy Tools"
44
authors = [
5-
{ name = "community-of-python" },
5+
{ name = "Artur Shiriev", email = "me@shiriev.ru" },
66
]
77
readme = "README.md"
88
requires-python = ">=3.13,<4"
@@ -18,15 +18,15 @@ classifiers = [
1818
"Typing :: Typed",
1919
"Topic :: Software Development :: Libraries",
2020
]
21-
version = "0"
21+
version = "0.0.0"
2222
dependencies = [
2323
"tenacity",
2424
"sqlalchemy[asyncio]",
2525
"asyncpg",
2626
]
2727

2828
[project.urls]
29-
repository = "https://github.com/modern-python/sa-utils"
29+
repository = "https://github.com/modern-python/db-retry"
3030

3131
[dependency-groups]
3232
dev = [

tests/conftest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
@pytest.fixture
99
async def async_engine() -> typing.AsyncIterator[sa_async.AsyncEngine]:
10-
engine: typing.Final = sa_async.create_async_engine(url=os.getenv("DB_DSN", ""), echo=True, echo_pool=True)
10+
engine: typing.Final = sa_async.create_async_engine(url=os.environ["DB_DSN"], echo=True, echo_pool=True)
1111
try:
1212
yield engine
1313
finally:

0 commit comments

Comments
 (0)