diff --git a/src/mlpa/core/config.py b/src/mlpa/core/config.py index b3de32e..ae10d9d 100644 --- a/src/mlpa/core/config.py +++ b/src/mlpa/core/config.py @@ -281,6 +281,25 @@ def valid_service_type_for_model(self, service_type: str, model: str) -> bool: PG_POOL_MIN_SIZE: int = 1 PG_POOL_MAX_SIZE: int = 10 PG_PREPARED_STMT_CACHE_MAX_SIZE: int = 100 + # Postgres query timeouts. Server-enforced (statement_timeout) so a runaway + # query is killed even if the client/event loop hangs, preventing connection + # pile-up. Values are milliseconds; 0 = unlimited (Postgres semantics). + PG_STATEMENT_TIMEOUT_MS: int = 3000 + # Reaps transactions left idle between statements (releases held locks). + # Should be >= statement_timeout since a tx legitimately spans round-trips. + PG_IDLE_IN_TX_TIMEOUT_MS: int = 10000 + # Raised budget for heavy startup work (capacity reconciliation), applied + # per-transaction via SET LOCAL. 0 = unlimited. + PG_MAINTENANCE_STATEMENT_TIMEOUT_MS: int = 30000 + # Raised budget for client-facing admin reads that do unindexable full-table + # scans (user listing, counts-by-service-type). 0 = unlimited. + PG_ADMIN_READ_TIMEOUT_MS: int = 15000 + # Optional asyncpg client-side backstop (seconds). None = disabled. + # WARNING: this is a pool-level client-side cancel that is NOT relaxed by the + # per-transaction SET LOCAL statement_timeout. If enabled, set it above the + # largest per-statement budget (PG_MAINTENANCE_STATEMENT_TIMEOUT_MS), or it + # will silently cancel the maintenance/admin-read queries. + PG_COMMAND_TIMEOUT_S: float | None = None # LLM request default values TEMPERATURE: float = 0.1 diff --git a/src/mlpa/core/pg_services/app_attest_pg_service.py b/src/mlpa/core/pg_services/app_attest_pg_service.py index 1d81ff6..707cb28 100644 --- a/src/mlpa/core/pg_services/app_attest_pg_service.py +++ b/src/mlpa/core/pg_services/app_attest_pg_service.py @@ -102,14 +102,14 @@ async def delete_key(self, key_id_b64: str): async def ensure_capacity_state(self) -> None: """ - Ensure the singleton capacity row and base-identity claim table exist. + Seed the singleton capacity row, then reconcile the claim table. - Reconciles the claim table with current LiteLLM end-user rows for - cap-managed service types on every startup (blocked rows included) so - the counter reflects reality after external writes and config changes. + The seed is critical and fatal on failure: without the row every + admission 500s, so a failure should crash startup rather than serve + broken. Reconciliation is best-effort (see _reconcile_capacity_claims): + if it fails the row still exists with a stale count and admissions work. """ - managed_service_types = list(env.MLPA_CAPPED_SERVICE_TYPES) - + # Seed the singleton row (fatal on failure). async with self.pool.acquire() as conn: async with conn.transaction(): await conn.execute( @@ -120,12 +120,6 @@ async def ensure_capacity_state(self) -> None: """, env.MLPA_MAX_SIGNED_IN_USERS, ) - - # Serialize seeding and reconciliation so concurrent app startups - # do not race on the claim table. - await conn.fetchrow( - "SELECT 1 FROM mlpa_user_capacity WHERE id = 1 FOR UPDATE" - ) await conn.execute( """ UPDATE mlpa_user_capacity @@ -136,35 +130,62 @@ async def ensure_capacity_state(self) -> None: env.MLPA_MAX_SIGNED_IN_USERS, ) - # Rebuild claims from LiteLLM so the counter matches reality after deletes - # or manual DB edits. Blocked rows still count toward capacity. - await conn.execute("DELETE FROM mlpa_user_capacity_identities") + # Reconcile the claim table (best-effort). + try: + await self._reconcile_capacity_claims() + except Exception as e: + logger.error( + f"Capacity reconciliation failed; serving with last-known " + f"current_identities count: {e}" + ) - base_identities = await self.litellm_pg.list_managed_base_identities( - managed_service_types - ) - if base_identities: - await conn.executemany( - """ - INSERT INTO mlpa_user_capacity_identities (base_identity) - VALUES ($1) - """, - [(base_identity,) for base_identity in base_identities], - ) - - seeded_claims = await conn.fetchval( - "SELECT COUNT(*) FROM mlpa_user_capacity_identities" - ) - await conn.execute( + async def _reconcile_capacity_claims(self) -> None: + """Rebuild the claim table from LiteLLM and refresh current_identities.""" + managed_service_types = list(env.MLPA_CAPPED_SERVICE_TYPES) + + # Read from the litellm pool before opening the app_attest transaction: + # doing it inside would leave the session idle-in-transaction across a + # cross-pool await, where idle_in_transaction_session_timeout could reap it. + base_identities = await self.litellm_pg.list_managed_base_identities( + managed_service_types + ) + + # Bulk delete + insert scales with the user base and can exceed the tight + # pool-wide statement_timeout. Statements run back-to-back (no inter- + # statement await), so the raised statement_timeout alone suffices. + async with self.statement_timeout( + env.PG_MAINTENANCE_STATEMENT_TIMEOUT_MS + ) as conn: + # Serialize so concurrent app startups do not race on the claim table. + await conn.fetchrow( + "SELECT 1 FROM mlpa_user_capacity WHERE id = 1 FOR UPDATE" + ) + + # Blocked rows still count toward capacity. + await conn.execute("DELETE FROM mlpa_user_capacity_identities") + + if base_identities: + await conn.executemany( """ - UPDATE mlpa_user_capacity - SET current_identities = $1, - updated_at = NOW() - WHERE id = 1 + INSERT INTO mlpa_user_capacity_identities (base_identity) + VALUES ($1) """, - seeded_claims, + [(base_identity,) for base_identity in base_identities], ) + seeded_claims = await conn.fetchval( + "SELECT COUNT(*) FROM mlpa_user_capacity_identities" + ) + await conn.execute( + """ + UPDATE mlpa_user_capacity + SET current_identities = $1, + updated_at = NOW() + WHERE id = 1 + """, + seeded_claims, + ) + async def admit_managed_base_identity( self, base_identity: str ) -> tuple[bool, bool]: @@ -177,57 +198,53 @@ async def admit_managed_base_identity( if not env.MLPA_ENFORCE_SIGNIN_CAP: return True, False - async with self.pool.acquire() as conn: - async with conn.transaction(): - await conn.execute( - f"SET LOCAL lock_timeout = '{env.MLPA_ADMISSION_LOCK_TIMEOUT_MS}ms'" - ) - capacity_row = await conn.fetchrow( - """ - SELECT max_identities, current_identities - FROM mlpa_user_capacity - WHERE id = 1 - FOR UPDATE - """ + async with self.admission_transaction() as conn: + capacity_row = await conn.fetchrow( + """ + SELECT max_identities, current_identities + FROM mlpa_user_capacity + WHERE id = 1 + FOR UPDATE + """ + ) + if capacity_row is None: + raise HTTPException( + status_code=500, + detail="Capacity state not initialized", ) - if capacity_row is None: - raise HTTPException( - status_code=500, - detail="Capacity state not initialized", - ) - already_claimed = await conn.fetchval( - """ - SELECT 1 - FROM mlpa_user_capacity_identities - WHERE base_identity = $1 - """, - base_identity, - ) - if already_claimed: - return True, False + already_claimed = await conn.fetchval( + """ + SELECT 1 + FROM mlpa_user_capacity_identities + WHERE base_identity = $1 + """, + base_identity, + ) + if already_claimed: + return True, False - max_identities = int(capacity_row["max_identities"]) - current_identities = int(capacity_row["current_identities"]) - if current_identities >= max_identities: - return False, False + max_identities = int(capacity_row["max_identities"]) + current_identities = int(capacity_row["current_identities"]) + if current_identities >= max_identities: + return False, False - await conn.execute( - """ - INSERT INTO mlpa_user_capacity_identities (base_identity) - VALUES ($1) - """, - base_identity, - ) - await conn.execute( - """ - UPDATE mlpa_user_capacity - SET current_identities = current_identities + 1, - updated_at = NOW() - WHERE id = 1 - """ - ) - return True, True + await conn.execute( + """ + INSERT INTO mlpa_user_capacity_identities (base_identity) + VALUES ($1) + """, + base_identity, + ) + await conn.execute( + """ + UPDATE mlpa_user_capacity + SET current_identities = current_identities + 1, + updated_at = NOW() + WHERE id = 1 + """ + ) + return True, True async def maybe_release_managed_base_identity_if_no_managed_users( self, base_identity: str @@ -241,55 +258,55 @@ async def maybe_release_managed_base_identity_if_no_managed_users( managed_service_types = list(env.MLPA_CAPPED_SERVICE_TYPES) - async with self.pool.acquire() as conn: - async with conn.transaction(): - await conn.execute( - f"SET LOCAL lock_timeout = '{env.MLPA_ADMISSION_LOCK_TIMEOUT_MS}ms'" - ) - capacity_row = await conn.fetchrow( - """ - SELECT max_identities, current_identities - FROM mlpa_user_capacity - WHERE id = 1 - FOR UPDATE - """ - ) - if capacity_row is None: - return + # Read the litellm state before opening the app_attest transaction: doing + # it inside would hold the FOR UPDATE lock idle-in-transaction across a + # cross-pool await, where idle_in_transaction_session_timeout could reap + # it and abort the release, leaking the claim (mirrors ensure_capacity_state). + has_managed_user_rows = await self.litellm_pg.has_managed_user_rows( + base_identity, + managed_service_types, + ) + if has_managed_user_rows: + return - claimed = await conn.fetchval( - """ - SELECT 1 - FROM mlpa_user_capacity_identities - WHERE base_identity = $1 - """, - base_identity, - ) - if not claimed: - return + async with self.admission_transaction() as conn: + capacity_row = await conn.fetchrow( + """ + SELECT max_identities, current_identities + FROM mlpa_user_capacity + WHERE id = 1 + FOR UPDATE + """ + ) + if capacity_row is None: + return - has_managed_user_rows = await self.litellm_pg.has_managed_user_rows( - base_identity, - managed_service_types, - ) - if has_managed_user_rows: - return + claimed = await conn.fetchval( + """ + SELECT 1 + FROM mlpa_user_capacity_identities + WHERE base_identity = $1 + """, + base_identity, + ) + if not claimed: + return - await conn.execute( - """ - DELETE FROM mlpa_user_capacity_identities - WHERE base_identity = $1 - """, - base_identity, - ) - await conn.execute( - """ - UPDATE mlpa_user_capacity - SET current_identities = GREATEST(current_identities - 1, 0), - updated_at = NOW() - WHERE id = 1 - """ - ) + await conn.execute( + """ + DELETE FROM mlpa_user_capacity_identities + WHERE base_identity = $1 + """, + base_identity, + ) + await conn.execute( + """ + UPDATE mlpa_user_capacity + SET current_identities = GREATEST(current_identities - 1, 0), + updated_at = NOW() + WHERE id = 1 + """ + ) async def get_signup_cap_status(self) -> dict: """ diff --git a/src/mlpa/core/pg_services/litellm_pg_service.py b/src/mlpa/core/pg_services/litellm_pg_service.py index 3259154..5911664 100644 --- a/src/mlpa/core/pg_services/litellm_pg_service.py +++ b/src/mlpa/core/pg_services/litellm_pg_service.py @@ -77,14 +77,17 @@ async def block_user(self, user_id: str, blocked: bool = True) -> dict: async def list_users(self, limit: int = 50, offset: int = 0) -> dict: try: - total = await self.pool.fetchval( - 'SELECT COUNT(*) FROM "LiteLLM_EndUserTable"' - ) - users = await self.pool.fetch( - 'SELECT * FROM "LiteLLM_EndUserTable" ORDER BY user_id LIMIT $1 OFFSET $2', - limit, - offset, - ) + # COUNT(*) + deep OFFSET scan the full table; admin-read budget + # rather than the tight pool-wide default. + async with self.statement_timeout(env.PG_ADMIN_READ_TIMEOUT_MS) as conn: + total = await conn.fetchval( + 'SELECT COUNT(*) FROM "LiteLLM_EndUserTable"' + ) + users = await conn.fetch( + 'SELECT * FROM "LiteLLM_EndUserTable" ORDER BY user_id LIMIT $1 OFFSET $2', + limit, + offset, + ) return { "users": [dict(user) for user in users], @@ -106,16 +109,19 @@ async def count_users_by_service_type(self) -> dict: `{base_user_id}:{service_type}`. """ try: - rows = await self.pool.fetch( - """ - SELECT - split_part(user_id, ':', 2) AS service_type, - COUNT(*)::int AS total_users - FROM "LiteLLM_EndUserTable" - WHERE position(':' in user_id) > 0 - GROUP BY service_type - """ - ) + # GROUP BY split_part(...) is unindexable, so always a full-table + # scan; admin-read budget rather than the tight pool-wide default. + async with self.statement_timeout(env.PG_ADMIN_READ_TIMEOUT_MS) as conn: + rows = await conn.fetch( + """ + SELECT + split_part(user_id, ':', 2) AS service_type, + COUNT(*)::int AS total_users + FROM "LiteLLM_EndUserTable" + WHERE position(':' in user_id) > 0 + GROUP BY service_type + """ + ) service_type_counts: dict[str, int] = {} for row in rows: @@ -141,16 +147,23 @@ async def list_managed_base_identities( ) -> list[str]: """ Return distinct base identities for cap-managed service types. + + The DISTINCT scan over the full end-user table can exceed the tight + pool-wide statement_timeout on a large user base, so it runs under the + maintenance budget (startup reconciliation work). """ - rows = await self.pool.fetch( - """ - SELECT DISTINCT split_part(user_id, ':', 1) AS base_identity - FROM "LiteLLM_EndUserTable" - WHERE position(':' in user_id) > 0 - AND split_part(user_id, ':', 2) = ANY($1::text[]) - """, - managed_service_types, - ) + async with self.statement_timeout( + env.PG_MAINTENANCE_STATEMENT_TIMEOUT_MS + ) as conn: + rows = await conn.fetch( + """ + SELECT DISTINCT split_part(user_id, ':', 1) AS base_identity + FROM "LiteLLM_EndUserTable" + WHERE position(':' in user_id) > 0 + AND split_part(user_id, ':', 2) = ANY($1::text[]) + """, + managed_service_types, + ) return [row["base_identity"] for row in rows if row.get("base_identity")] async def has_managed_user_rows( @@ -187,6 +200,8 @@ async def create_budget(self): for service_type, budget_config in user_feature_budgets.items(): try: + # Fast single-row PK upsert: stays a plain autocommit call (it + # cannot realistically hit the pool-wide statement_timeout). await self.pool.fetchrow( """ INSERT INTO "LiteLLM_BudgetTable" diff --git a/src/mlpa/core/pg_services/pg_service.py b/src/mlpa/core/pg_services/pg_service.py index b7d3ee2..e8fabc4 100644 --- a/src/mlpa/core/pg_services/pg_service.py +++ b/src/mlpa/core/pg_services/pg_service.py @@ -1,4 +1,5 @@ import sys +from contextlib import asynccontextmanager from typing import cast import asyncpg @@ -10,9 +11,14 @@ class PGService: pg: asyncpg.Pool | None - def __init__(self, db_name: str): + def __init__(self, db_name: str, statement_timeout_ms: int | None = None): self.db_name = db_name self.db_url = f"{cast(str, env.PG_DB_URL).rstrip('/')}/{db_name}" + self.statement_timeout_ms = ( + statement_timeout_ms + if statement_timeout_ms is not None + else env.PG_STATEMENT_TIMEOUT_MS + ) self.connected = False self.pg = None @@ -22,11 +28,22 @@ def pool(self) -> asyncpg.Pool: async def connect(self): try: + # asyncpg re-applies server_settings on every reconnect, so these + # are durable for the pool's lifetime. Values are ms-integer strings. + server_settings = { + "statement_timeout": str(self.statement_timeout_ms), + "idle_in_transaction_session_timeout": str( + env.PG_IDLE_IN_TX_TIMEOUT_MS + ), + "application_name": f"mlpa:{self.db_name}", + } self.pg = await asyncpg.create_pool( self.db_url, min_size=env.PG_POOL_MIN_SIZE, max_size=env.PG_POOL_MAX_SIZE, statement_cache_size=env.PG_PREPARED_STMT_CACHE_MAX_SIZE, + server_settings=server_settings, + command_timeout=env.PG_COMMAND_TIMEOUT_S, ) self.connected = True logger.info(f"Connected to /{self.db_name}") @@ -40,6 +57,56 @@ async def disconnect(self): await self.pg.close() self.connected = False + @asynccontextmanager + async def _timed_transaction( + self, + statement_timeout_ms: int, + idle_in_tx_timeout_ms: int | None = None, + lock_timeout_ms: int | None = None, + ): + """ + Yield a connection in a transaction with statement_timeout (and + optionally idle_in_transaction_session_timeout / lock_timeout) set via + SET LOCAL, scoped to the transaction so the connection reverts to the + pool-wide defaults on release. Timeout values are config ints, not input. + """ + async with self.pool.acquire() as conn: + async with conn.transaction(): + await conn.execute( + f"SET LOCAL statement_timeout = '{statement_timeout_ms}'" + ) + if idle_in_tx_timeout_ms is not None: + await conn.execute( + f"SET LOCAL idle_in_transaction_session_timeout = '{idle_in_tx_timeout_ms}'" + ) + if lock_timeout_ms is not None: + await conn.execute( + f"SET LOCAL lock_timeout = '{lock_timeout_ms}ms'" + ) + yield conn + + @asynccontextmanager + async def statement_timeout(self, timeout_ms: int): + """ + Raise statement_timeout for statements that legitimately exceed the + tight pool-wide default (e.g. unindexable full-table scans). + """ + async with self._timed_transaction(timeout_ms) as conn: + yield conn + + @asynccontextmanager + async def admission_transaction(self): + """ + Signup-capacity admission path: a bounded lock_timeout for the FOR UPDATE + on the singleton capacity row, plus a statement_timeout set above it so + the lock wait is governed by lock_timeout rather than silently capped by + the pool-wide statement_timeout (Postgres counts lock-wait toward it). + """ + lock_ms = env.MLPA_ADMISSION_LOCK_TIMEOUT_MS + stmt_ms = lock_ms + env.PG_STATEMENT_TIMEOUT_MS + async with self._timed_transaction(stmt_ms, lock_timeout_ms=lock_ms) as conn: + yield conn + def check_status(self): if self.pg is None or not self.connected: return False diff --git a/src/tests/unit/test_pg_timeouts.py b/src/tests/unit/test_pg_timeouts.py new file mode 100644 index 0000000..e8be16d --- /dev/null +++ b/src/tests/unit/test_pg_timeouts.py @@ -0,0 +1,261 @@ +import os +from unittest.mock import AsyncMock, MagicMock, patch + +from mlpa.core.config import Env, env +from mlpa.core.pg_services.app_attest_pg_service import AppAttestPGService +from mlpa.core.pg_services.pg_service import PGService + + +def test_pg_timeout_config_from_env(): + """PG timeout settings are overridable via environment variables.""" + env_vars = { + "PG_STATEMENT_TIMEOUT_MS": "5000", + "PG_IDLE_IN_TX_TIMEOUT_MS": "15000", + "PG_MAINTENANCE_STATEMENT_TIMEOUT_MS": "45000", + "PG_ADMIN_READ_TIMEOUT_MS": "20000", + "PG_COMMAND_TIMEOUT_S": "6.5", + } + + with patch.dict(os.environ, env_vars): + test_env = Env() + + assert test_env.PG_STATEMENT_TIMEOUT_MS == 5000 + assert test_env.PG_IDLE_IN_TX_TIMEOUT_MS == 15000 + assert test_env.PG_MAINTENANCE_STATEMENT_TIMEOUT_MS == 45000 + assert test_env.PG_ADMIN_READ_TIMEOUT_MS == 20000 + assert test_env.PG_COMMAND_TIMEOUT_S == 6.5 + + +def test_pg_timeout_defaults(): + """Sane defaults: tight statement timeout, larger idle-in-tx reaper, no client backstop.""" + assert env.PG_STATEMENT_TIMEOUT_MS == 3000 + assert env.PG_IDLE_IN_TX_TIMEOUT_MS == 10000 + assert env.PG_MAINTENANCE_STATEMENT_TIMEOUT_MS == 30000 + assert env.PG_ADMIN_READ_TIMEOUT_MS == 15000 + assert env.PG_COMMAND_TIMEOUT_S is None + + +async def test_connect_passes_timeout_server_settings(mocker): + """The pool is created with server-enforced statement / idle-in-tx timeouts.""" + create_pool = mocker.patch( + "mlpa.core.pg_services.pg_service.asyncpg.create_pool", + new=AsyncMock(return_value=MagicMock()), + ) + + service = PGService("some_db") + await service.connect() + + _, kwargs = create_pool.call_args + server_settings = kwargs["server_settings"] + assert server_settings["statement_timeout"] == str(env.PG_STATEMENT_TIMEOUT_MS) + assert server_settings["idle_in_transaction_session_timeout"] == str( + env.PG_IDLE_IN_TX_TIMEOUT_MS + ) + assert server_settings["application_name"] == "mlpa:some_db" + assert kwargs["command_timeout"] == env.PG_COMMAND_TIMEOUT_S + + +async def test_connect_respects_per_pool_statement_timeout_override(mocker): + """A subclass/per-pool override flows into server_settings.""" + create_pool = mocker.patch( + "mlpa.core.pg_services.pg_service.asyncpg.create_pool", + new=AsyncMock(return_value=MagicMock()), + ) + + service = PGService("some_db", statement_timeout_ms=1234) + await service.connect() + + _, kwargs = create_pool.call_args + assert kwargs["server_settings"]["statement_timeout"] == "1234" + + +def _mock_maintenance_conn(): + conn = MagicMock() + conn.execute = AsyncMock() + conn.executemany = AsyncMock() + conn.fetch = AsyncMock(return_value=[]) + conn.fetchrow = AsyncMock(return_value={"?column?": 1}) + conn.fetchval = AsyncMock(return_value=0) + conn.transaction.return_value.__aenter__ = AsyncMock(return_value=None) + conn.transaction.return_value.__aexit__ = AsyncMock(return_value=None) + return conn + + +def _mock_pool(conn): + acquire_cm = MagicMock() + acquire_cm.__aenter__ = AsyncMock(return_value=conn) + acquire_cm.__aexit__ = AsyncMock(return_value=None) + pool = MagicMock() + pool.acquire.return_value = acquire_cm + return pool + + +def _set_local_calls(conn, guc): + return [ + call.args[0] + for call in conn.execute.call_args_list + if call.args and f"SET LOCAL {guc}" in call.args[0] + ] + + +async def test_admission_transaction_sets_lock_and_statement_timeout(mocker): + """admission_transaction lifts lock_timeout and a statement_timeout above it.""" + conn = _mock_maintenance_conn() + mocker.patch.object(PGService, "pool", new=_mock_pool(conn)) + + service = PGService("some_db") + async with service.admission_transaction() as yielded: + assert yielded is conn + + assert _set_local_calls(conn, "lock_timeout") + stmt_calls = _set_local_calls(conn, "statement_timeout") + assert stmt_calls + # statement_timeout must exceed lock_timeout so the lock wait is governed by + # lock_timeout, not silently capped by the tight pool-wide statement_timeout. + expected = env.MLPA_ADMISSION_LOCK_TIMEOUT_MS + env.PG_STATEMENT_TIMEOUT_MS + assert str(expected) in stmt_calls[0] + + +async def test_statement_timeout_sets_only_statement_timeout(mocker): + """statement_timeout() lifts statement_timeout but not idle-in-tx (read-only helper).""" + conn = _mock_maintenance_conn() + mocker.patch.object(PGService, "pool", new=_mock_pool(conn)) + + service = PGService("some_db") + async with service.statement_timeout(15000) as yielded: + assert yielded is conn + + stmt_calls = _set_local_calls(conn, "statement_timeout") + assert stmt_calls + assert "15000" in stmt_calls[0] + assert not _set_local_calls(conn, "idle_in_transaction_session_timeout") + + +async def test_count_users_by_service_type_uses_admin_read_timeout(mocker): + """The unindexable full-table GROUP BY runs under the admin-read timeout, not 3s.""" + from mlpa.core.pg_services.litellm_pg_service import LiteLLMPGService + + conn = _mock_maintenance_conn() + mocker.patch.object(PGService, "pool", new=_mock_pool(conn)) + + service = LiteLLMPGService() + await service.count_users_by_service_type() + + timeout_calls = _set_local_calls(conn, "statement_timeout") + assert timeout_calls + assert str(env.PG_ADMIN_READ_TIMEOUT_MS) in timeout_calls[0] + conn.fetch.assert_awaited_once() + + +async def test_list_users_uses_admin_read_timeout(mocker): + """The full-table COUNT(*) + deep OFFSET page run under the admin-read timeout.""" + from mlpa.core.pg_services.litellm_pg_service import LiteLLMPGService + + conn = _mock_maintenance_conn() + mocker.patch.object(PGService, "pool", new=_mock_pool(conn)) + + service = LiteLLMPGService() + await service.list_users() + + timeout_calls = _set_local_calls(conn, "statement_timeout") + assert timeout_calls + assert str(env.PG_ADMIN_READ_TIMEOUT_MS) in timeout_calls[0] + + +async def test_list_managed_base_identities_uses_maintenance_timeout(mocker): + """The heavy reconciliation read runs under the maintenance timeout, not the 3s default.""" + from mlpa.core.pg_services.litellm_pg_service import LiteLLMPGService + + conn = _mock_maintenance_conn() + mocker.patch.object(PGService, "pool", new=_mock_pool(conn)) + + service = LiteLLMPGService() + await service.list_managed_base_identities(["ai"]) + + timeout_calls = _set_local_calls(conn, "statement_timeout") + assert timeout_calls + assert str(env.PG_MAINTENANCE_STATEMENT_TIMEOUT_MS) in timeout_calls[0] + conn.fetch.assert_awaited_once() + + +async def test_ensure_capacity_state_reads_identities_before_reconcile_transaction( + mocker, +): + """The cross-pool read must not be issued inside the reconcile transaction. + + The session must never sit idle-in-transaction across the cross-pool await, + so the litellm read has to land AFTER the (self-contained) seed transaction + and BEFORE the destructive claim rebuild (DELETE ...). + """ + order: list[str] = [] + + litellm_pg = MagicMock() + + async def _list(*_args, **_kwargs): + order.append("read") + return [] + + litellm_pg.list_managed_base_identities = _list + + conn = _mock_maintenance_conn() + + async def _execute(sql, *_args, **_kwargs): + order.append(f"exec:{sql.strip()}") + + conn.execute = AsyncMock(side_effect=_execute) + mocker.patch.object(PGService, "pool", new=_mock_pool(conn)) + + service = AppAttestPGService(litellm_pg) + await service.ensure_capacity_state() + + assert order.count("read") == 1 + read_idx = order.index("read") + + # The claim rebuild (DELETE) — the first statement of the reconcile + # transaction — must come strictly after the cross-pool read. + delete_idx = next( + i for i, step in enumerate(order) if step.startswith("exec:DELETE") + ) + assert read_idx < delete_idx + + # The seed (INSERT/UPDATE on mlpa_user_capacity) runs in its own transaction + # and commits before the read, so no cross-pool await spans an open tx. + assert any( + step.startswith("exec:INSERT INTO mlpa_user_capacity ") for step in order + ) + + +async def test_ensure_capacity_state_raises_maintenance_timeout(mocker): + """Reconciliation lifts the tight pool timeout via SET LOCAL for its transaction.""" + litellm_pg = MagicMock() + litellm_pg.list_managed_base_identities = AsyncMock(return_value=[]) + + service = AppAttestPGService(litellm_pg) + + conn = MagicMock() + conn.execute = AsyncMock() + conn.executemany = AsyncMock() + conn.fetchrow = AsyncMock(return_value={"?column?": 1}) + conn.fetchval = AsyncMock(return_value=0) + + # async context managers for pool.acquire() and conn.transaction() + conn.transaction.return_value.__aenter__ = AsyncMock(return_value=None) + conn.transaction.return_value.__aexit__ = AsyncMock(return_value=None) + + acquire_cm = MagicMock() + acquire_cm.__aenter__ = AsyncMock(return_value=conn) + acquire_cm.__aexit__ = AsyncMock(return_value=None) + + pool = MagicMock() + pool.acquire.return_value = acquire_cm + mocker.patch.object(PGService, "pool", new=pool) + + await service.ensure_capacity_state() + + set_local_calls = [ + call.args[0] + for call in conn.execute.call_args_list + if call.args and "SET LOCAL statement_timeout" in call.args[0] + ] + assert set_local_calls, "expected a SET LOCAL statement_timeout in reconciliation" + assert str(env.PG_MAINTENANCE_STATEMENT_TIMEOUT_MS) in set_local_calls[0]