From 8703c21a49ffc5e2fa61582d11be0dfec98aab89 Mon Sep 17 00:00:00 2001 From: Alexander Osipenko <11722602+subpath@users.noreply.github.com> Date: Tue, 16 Jun 2026 10:41:14 +0200 Subject: [PATCH 1/2] feat: add Postgres query timeouts (AIPLAT-921) --- src/mlpa/core/config.py | 13 ++ .../core/pg_services/app_attest_pg_service.py | 104 +++++----- .../core/pg_services/litellm_pg_service.py | 70 ++++--- src/mlpa/core/pg_services/pg_service.py | 43 +++- src/mlpa/run.py | 15 +- src/tests/unit/test_pg_timeouts.py | 190 ++++++++++++++++++ 6 files changed, 354 insertions(+), 81 deletions(-) create mode 100644 src/tests/unit/test_pg_timeouts.py diff --git a/src/mlpa/core/config.py b/src/mlpa/core/config.py index b3de32e..da3a83d 100644 --- a/src/mlpa/core/config.py +++ b/src/mlpa/core/config.py @@ -281,6 +281,19 @@ def valid_service_type_for_model(self, service_type: str, model: str) -> bool: PG_POOL_MIN_SIZE: int = 1 PG_POOL_MAX_SIZE: int = 10 PG_PREPARED_STMT_CACHE_MAX_SIZE: int = 100 + # Postgres query timeouts. Server-enforced (statement_timeout) so a runaway + # query is killed even if the client/event loop hangs, preventing connection + # pile-up. Values are milliseconds; 0 = unlimited (Postgres semantics). + PG_STATEMENT_TIMEOUT_MS: int = 3000 + # Reaps transactions left idle between statements (releases held locks). + # Should be >= statement_timeout since a tx legitimately spans round-trips. + PG_IDLE_IN_TX_TIMEOUT_MS: int = 10000 + # Override for known-heavy startup maintenance (capacity reconciliation), + # applied per-transaction via SET LOCAL. 0 = unlimited. + PG_MAINTENANCE_STATEMENT_TIMEOUT_MS: int = 30000 + # Optional asyncpg client-side backstop (seconds). None = disabled; set + # slightly above PG_STATEMENT_TIMEOUT_MS if you want belt-and-suspenders. + PG_COMMAND_TIMEOUT_S: float | None = None # LLM request default values TEMPERATURE: float = 0.1 diff --git a/src/mlpa/core/pg_services/app_attest_pg_service.py b/src/mlpa/core/pg_services/app_attest_pg_service.py index 1d81ff6..741f6f2 100644 --- a/src/mlpa/core/pg_services/app_attest_pg_service.py +++ b/src/mlpa/core/pg_services/app_attest_pg_service.py @@ -110,61 +110,69 @@ async def ensure_capacity_state(self) -> None: """ managed_service_types = list(env.MLPA_CAPPED_SERVICE_TYPES) - async with self.pool.acquire() as conn: - async with conn.transaction(): - await conn.execute( - """ - INSERT INTO mlpa_user_capacity (id, max_identities, current_identities) - VALUES (1, $1, 0) - ON CONFLICT (id) DO NOTHING - """, - env.MLPA_MAX_SIGNED_IN_USERS, - ) - - # Serialize seeding and reconciliation so concurrent app startups - # do not race on the claim table. - await conn.fetchrow( - "SELECT 1 FROM mlpa_user_capacity WHERE id = 1 FOR UPDATE" - ) - await conn.execute( - """ - UPDATE mlpa_user_capacity - SET max_identities = $1, - updated_at = NOW() - WHERE id = 1 - """, - env.MLPA_MAX_SIGNED_IN_USERS, - ) + # Read cap-managed identities from the *litellm* pool before opening the + # app_attest transaction. Doing this inside the transaction would leave + # the app_attest session idle-in-transaction across a cross-pool await, + # where idle_in_transaction_session_timeout could reap it mid-rebuild. + base_identities = await self.litellm_pg.list_managed_base_identities( + managed_service_types + ) + + # Reconciliation rebuilds the entire claim table (bulk delete + insert of + # all base identities), which can exceed the tight pool-wide + # statement_timeout as the user base grows; maintenance_transaction + # raises both timeouts for this work. + async with self.maintenance_transaction() as conn: + await conn.execute( + """ + INSERT INTO mlpa_user_capacity (id, max_identities, current_identities) + VALUES (1, $1, 0) + ON CONFLICT (id) DO NOTHING + """, + env.MLPA_MAX_SIGNED_IN_USERS, + ) - # Rebuild claims from LiteLLM so the counter matches reality after deletes - # or manual DB edits. Blocked rows still count toward capacity. - await conn.execute("DELETE FROM mlpa_user_capacity_identities") + # Serialize seeding and reconciliation so concurrent app startups + # do not race on the claim table. + await conn.fetchrow( + "SELECT 1 FROM mlpa_user_capacity WHERE id = 1 FOR UPDATE" + ) + await conn.execute( + """ + UPDATE mlpa_user_capacity + SET max_identities = $1, + updated_at = NOW() + WHERE id = 1 + """, + env.MLPA_MAX_SIGNED_IN_USERS, + ) - base_identities = await self.litellm_pg.list_managed_base_identities( - managed_service_types - ) - if base_identities: - await conn.executemany( - """ - INSERT INTO mlpa_user_capacity_identities (base_identity) - VALUES ($1) - """, - [(base_identity,) for base_identity in base_identities], - ) + # Rebuild claims from LiteLLM so the counter matches reality after deletes + # or manual DB edits. Blocked rows still count toward capacity. + await conn.execute("DELETE FROM mlpa_user_capacity_identities") - seeded_claims = await conn.fetchval( - "SELECT COUNT(*) FROM mlpa_user_capacity_identities" - ) - await conn.execute( + if base_identities: + await conn.executemany( """ - UPDATE mlpa_user_capacity - SET current_identities = $1, - updated_at = NOW() - WHERE id = 1 + INSERT INTO mlpa_user_capacity_identities (base_identity) + VALUES ($1) """, - seeded_claims, + [(base_identity,) for base_identity in base_identities], ) + seeded_claims = await conn.fetchval( + "SELECT COUNT(*) FROM mlpa_user_capacity_identities" + ) + await conn.execute( + """ + UPDATE mlpa_user_capacity + SET current_identities = $1, + updated_at = NOW() + WHERE id = 1 + """, + seeded_claims, + ) + async def admit_managed_base_identity( self, base_identity: str ) -> tuple[bool, bool]: diff --git a/src/mlpa/core/pg_services/litellm_pg_service.py b/src/mlpa/core/pg_services/litellm_pg_service.py index 3259154..8055bb0 100644 --- a/src/mlpa/core/pg_services/litellm_pg_service.py +++ b/src/mlpa/core/pg_services/litellm_pg_service.py @@ -141,16 +141,21 @@ async def list_managed_base_identities( ) -> list[str]: """ Return distinct base identities for cap-managed service types. + + Runs under the maintenance timeout: the DISTINCT scan over the full + end-user table is startup reconciliation work that can exceed the tight + pool-wide statement_timeout on a large user base. """ - rows = await self.pool.fetch( - """ - SELECT DISTINCT split_part(user_id, ':', 1) AS base_identity - FROM "LiteLLM_EndUserTable" - WHERE position(':' in user_id) > 0 - AND split_part(user_id, ':', 2) = ANY($1::text[]) - """, - managed_service_types, - ) + async with self.maintenance_transaction() as conn: + rows = await conn.fetch( + """ + SELECT DISTINCT split_part(user_id, ':', 1) AS base_identity + FROM "LiteLLM_EndUserTable" + WHERE position(':' in user_id) > 0 + AND split_part(user_id, ':', 2) = ANY($1::text[]) + """, + managed_service_types, + ) return [row["base_identity"] for row in rows if row.get("base_identity")] async def has_managed_user_rows( @@ -187,27 +192,32 @@ async def create_budget(self): for service_type, budget_config in user_feature_budgets.items(): try: - await self.pool.fetchrow( - """ - INSERT INTO "LiteLLM_BudgetTable" - (budget_id, max_budget, rpm_limit, tpm_limit, budget_duration, created_at, updated_at, created_by, updated_by) - VALUES ($1, $2, $3, $4, $5, NOW(), NOW(), $6, $6) - ON CONFLICT (budget_id) DO UPDATE SET - max_budget = EXCLUDED.max_budget, - rpm_limit = EXCLUDED.rpm_limit, - tpm_limit = EXCLUDED.tpm_limit, - budget_duration = EXCLUDED.budget_duration, - updated_at = NOW(), - updated_by = EXCLUDED.updated_by - RETURNING * - """, - budget_config["budget_id"], - budget_config["max_budget"], - budget_config["rpm_limit"], - budget_config["tpm_limit"], - budget_config["budget_duration"], - "default_user_id", - ) + # Each upsert runs in its own maintenance transaction so a single + # failure rolls back only that budget (and is logged below) while + # the loop continues, and so a slow cold-start upsert is not + # cancelled by the tight pool-wide statement_timeout. + async with self.maintenance_transaction() as conn: + await conn.fetchrow( + """ + INSERT INTO "LiteLLM_BudgetTable" + (budget_id, max_budget, rpm_limit, tpm_limit, budget_duration, created_at, updated_at, created_by, updated_by) + VALUES ($1, $2, $3, $4, $5, NOW(), NOW(), $6, $6) + ON CONFLICT (budget_id) DO UPDATE SET + max_budget = EXCLUDED.max_budget, + rpm_limit = EXCLUDED.rpm_limit, + tpm_limit = EXCLUDED.tpm_limit, + budget_duration = EXCLUDED.budget_duration, + updated_at = NOW(), + updated_by = EXCLUDED.updated_by + RETURNING * + """, + budget_config["budget_id"], + budget_config["max_budget"], + budget_config["rpm_limit"], + budget_config["tpm_limit"], + budget_config["budget_duration"], + "default_user_id", + ) logger.info( f"Budget created/updated: budget_id={budget_config['budget_id']}, " f"service_type={service_type}, max_budget={budget_config['max_budget']}" diff --git a/src/mlpa/core/pg_services/pg_service.py b/src/mlpa/core/pg_services/pg_service.py index b7d3ee2..f813fd7 100644 --- a/src/mlpa/core/pg_services/pg_service.py +++ b/src/mlpa/core/pg_services/pg_service.py @@ -1,4 +1,5 @@ import sys +from contextlib import asynccontextmanager from typing import cast import asyncpg @@ -10,9 +11,16 @@ class PGService: pg: asyncpg.Pool | None - def __init__(self, db_name: str): + def __init__(self, db_name: str, statement_timeout_ms: int | None = None): self.db_name = db_name self.db_url = f"{cast(str, env.PG_DB_URL).rstrip('/')}/{db_name}" + # Pool-wide server-enforced statement timeout. Subclasses (or a future + # third pool) may override per-DB; defaults to the global config value. + self.statement_timeout_ms = ( + statement_timeout_ms + if statement_timeout_ms is not None + else env.PG_STATEMENT_TIMEOUT_MS + ) self.connected = False self.pg = None @@ -22,11 +30,23 @@ def pool(self) -> asyncpg.Pool: async def connect(self): try: + # Applied at connect and automatically re-applied on every reconnect, + # so the timeout is durable across the pool's lifetime without + # touching call sites. Values are passed as bare-integer ms strings. + server_settings = { + "statement_timeout": str(self.statement_timeout_ms), + "idle_in_transaction_session_timeout": str( + env.PG_IDLE_IN_TX_TIMEOUT_MS + ), + "application_name": f"mlpa:{self.db_name}", + } self.pg = await asyncpg.create_pool( self.db_url, min_size=env.PG_POOL_MIN_SIZE, max_size=env.PG_POOL_MAX_SIZE, statement_cache_size=env.PG_PREPARED_STMT_CACHE_MAX_SIZE, + server_settings=server_settings, + command_timeout=env.PG_COMMAND_TIMEOUT_S, ) self.connected = True logger.info(f"Connected to /{self.db_name}") @@ -40,6 +60,27 @@ async def disconnect(self): await self.pg.close() self.connected = False + @asynccontextmanager + async def maintenance_transaction(self): + """ + Yield a connection inside a transaction whose statement and + idle-in-transaction timeouts are raised to the maintenance value. + + For known-heavy startup work (capacity reconciliation, budget upsert) + that legitimately exceeds the tight pool-wide statement_timeout. Both + GUCs are raised via SET LOCAL: statement_timeout for slow statements, + and idle_in_transaction_session_timeout because such work may await + other queries between statements without the session being reaped. + """ + timeout_ms = env.PG_MAINTENANCE_STATEMENT_TIMEOUT_MS + async with self.pool.acquire() as conn: + async with conn.transaction(): + await conn.execute(f"SET LOCAL statement_timeout = '{timeout_ms}'") + await conn.execute( + f"SET LOCAL idle_in_transaction_session_timeout = '{timeout_ms}'" + ) + yield conn + def check_status(self): if self.pg is None or not self.connected: return False diff --git a/src/mlpa/run.py b/src/mlpa/run.py index 6a98a06..b700e13 100644 --- a/src/mlpa/run.py +++ b/src/mlpa/run.py @@ -70,8 +70,19 @@ async def lifespan(app: FastAPI): await app_attest_pg.connect() app_attest_connected = True - await litellm_pg.create_budget() - await app_attest_pg.ensure_capacity_state() + # Startup maintenance (budget upsert, capacity reconciliation) is + # best-effort: it runs under a raised maintenance statement_timeout, but + # if it still fails (e.g. timeout on a very large/slow DB) we log and + # continue rather than crash-loop the whole app. Both paths leave prior + # state intact on failure, so the app serves with last-known config. + try: + await litellm_pg.create_budget() + except Exception as e: + logger.error(f"Startup budget creation failed; continuing: {e}") + try: + await app_attest_pg.ensure_capacity_state() + except Exception as e: + logger.error(f"Startup capacity reconciliation failed; continuing: {e}") yield finally: diff --git a/src/tests/unit/test_pg_timeouts.py b/src/tests/unit/test_pg_timeouts.py new file mode 100644 index 0000000..2cd0c82 --- /dev/null +++ b/src/tests/unit/test_pg_timeouts.py @@ -0,0 +1,190 @@ +import os +from unittest.mock import AsyncMock, MagicMock, patch + +from mlpa.core.config import Env, env +from mlpa.core.pg_services.app_attest_pg_service import AppAttestPGService +from mlpa.core.pg_services.pg_service import PGService + + +def test_pg_timeout_config_from_env(): + """PG timeout settings are overridable via environment variables.""" + env_vars = { + "PG_STATEMENT_TIMEOUT_MS": "5000", + "PG_IDLE_IN_TX_TIMEOUT_MS": "15000", + "PG_MAINTENANCE_STATEMENT_TIMEOUT_MS": "45000", + "PG_COMMAND_TIMEOUT_S": "6.5", + } + + with patch.dict(os.environ, env_vars): + test_env = Env() + + assert test_env.PG_STATEMENT_TIMEOUT_MS == 5000 + assert test_env.PG_IDLE_IN_TX_TIMEOUT_MS == 15000 + assert test_env.PG_MAINTENANCE_STATEMENT_TIMEOUT_MS == 45000 + assert test_env.PG_COMMAND_TIMEOUT_S == 6.5 + + +def test_pg_timeout_defaults(): + """Sane defaults: tight statement timeout, larger idle-in-tx reaper, no client backstop.""" + assert env.PG_STATEMENT_TIMEOUT_MS == 3000 + assert env.PG_IDLE_IN_TX_TIMEOUT_MS == 10000 + assert env.PG_MAINTENANCE_STATEMENT_TIMEOUT_MS == 30000 + assert env.PG_COMMAND_TIMEOUT_S is None + + +async def test_connect_passes_timeout_server_settings(mocker): + """The pool is created with server-enforced statement / idle-in-tx timeouts.""" + create_pool = mocker.patch( + "mlpa.core.pg_services.pg_service.asyncpg.create_pool", + new=AsyncMock(return_value=MagicMock()), + ) + + service = PGService("some_db") + await service.connect() + + _, kwargs = create_pool.call_args + server_settings = kwargs["server_settings"] + assert server_settings["statement_timeout"] == str(env.PG_STATEMENT_TIMEOUT_MS) + assert server_settings["idle_in_transaction_session_timeout"] == str( + env.PG_IDLE_IN_TX_TIMEOUT_MS + ) + assert server_settings["application_name"] == "mlpa:some_db" + assert kwargs["command_timeout"] == env.PG_COMMAND_TIMEOUT_S + + +async def test_connect_respects_per_pool_statement_timeout_override(mocker): + """A subclass/per-pool override flows into server_settings.""" + create_pool = mocker.patch( + "mlpa.core.pg_services.pg_service.asyncpg.create_pool", + new=AsyncMock(return_value=MagicMock()), + ) + + service = PGService("some_db", statement_timeout_ms=1234) + await service.connect() + + _, kwargs = create_pool.call_args + assert kwargs["server_settings"]["statement_timeout"] == "1234" + + +def _mock_maintenance_conn(): + conn = MagicMock() + conn.execute = AsyncMock() + conn.executemany = AsyncMock() + conn.fetch = AsyncMock(return_value=[]) + conn.fetchrow = AsyncMock(return_value={"?column?": 1}) + conn.fetchval = AsyncMock(return_value=0) + conn.transaction.return_value.__aenter__ = AsyncMock(return_value=None) + conn.transaction.return_value.__aexit__ = AsyncMock(return_value=None) + return conn + + +def _mock_pool(conn): + acquire_cm = MagicMock() + acquire_cm.__aenter__ = AsyncMock(return_value=conn) + acquire_cm.__aexit__ = AsyncMock(return_value=None) + pool = MagicMock() + pool.acquire.return_value = acquire_cm + return pool + + +def _set_local_calls(conn, guc): + return [ + call.args[0] + for call in conn.execute.call_args_list + if call.args and f"SET LOCAL {guc}" in call.args[0] + ] + + +async def test_maintenance_transaction_raises_both_timeouts(mocker): + """The helper lifts statement_timeout AND idle_in_transaction_session_timeout.""" + conn = _mock_maintenance_conn() + mocker.patch.object(PGService, "pool", new=_mock_pool(conn)) + + service = PGService("some_db") + async with service.maintenance_transaction() as yielded: + assert yielded is conn + + assert _set_local_calls(conn, "statement_timeout") + assert _set_local_calls(conn, "idle_in_transaction_session_timeout") + + +async def test_list_managed_base_identities_uses_maintenance_timeout(mocker): + """The heavy reconciliation read runs under the maintenance timeout, not the 3s default.""" + from mlpa.core.pg_services.litellm_pg_service import LiteLLMPGService + + conn = _mock_maintenance_conn() + mocker.patch.object(PGService, "pool", new=_mock_pool(conn)) + + service = LiteLLMPGService() + await service.list_managed_base_identities(["ai"]) + + timeout_calls = _set_local_calls(conn, "statement_timeout") + assert timeout_calls + assert str(env.PG_MAINTENANCE_STATEMENT_TIMEOUT_MS) in timeout_calls[0] + conn.fetch.assert_awaited_once() + + +async def test_ensure_capacity_state_reads_identities_before_transaction(mocker): + """The cross-pool read happens before the app_attest transaction is opened.""" + order: list[str] = [] + + litellm_pg = MagicMock() + + async def _list(*_args, **_kwargs): + order.append("read") + return [] + + litellm_pg.list_managed_base_identities = _list + + conn = _mock_maintenance_conn() + + async def _execute(sql, *_args, **_kwargs): + order.append(f"exec:{sql.strip().split()[0]}") + + conn.execute = AsyncMock(side_effect=_execute) + mocker.patch.object(PGService, "pool", new=_mock_pool(conn)) + + service = AppAttestPGService(litellm_pg) + await service.ensure_capacity_state() + + # The litellm read must precede every statement issued inside the app_attest + # transaction, so the session is never idle-in-transaction across that await. + assert order[0] == "read" + assert all(step == "read" or step.startswith("exec:") for step in order) + assert order.count("read") == 1 + + +async def test_ensure_capacity_state_raises_maintenance_timeout(mocker): + """Reconciliation lifts the tight pool timeout via SET LOCAL for its transaction.""" + litellm_pg = MagicMock() + litellm_pg.list_managed_base_identities = AsyncMock(return_value=[]) + + service = AppAttestPGService(litellm_pg) + + conn = MagicMock() + conn.execute = AsyncMock() + conn.executemany = AsyncMock() + conn.fetchrow = AsyncMock(return_value={"?column?": 1}) + conn.fetchval = AsyncMock(return_value=0) + + # async context managers for pool.acquire() and conn.transaction() + conn.transaction.return_value.__aenter__ = AsyncMock(return_value=None) + conn.transaction.return_value.__aexit__ = AsyncMock(return_value=None) + + acquire_cm = MagicMock() + acquire_cm.__aenter__ = AsyncMock(return_value=conn) + acquire_cm.__aexit__ = AsyncMock(return_value=None) + + pool = MagicMock() + pool.acquire.return_value = acquire_cm + mocker.patch.object(PGService, "pool", new=pool) + + await service.ensure_capacity_state() + + set_local_calls = [ + call.args[0] + for call in conn.execute.call_args_list + if call.args and "SET LOCAL statement_timeout" in call.args[0] + ] + assert set_local_calls, "expected a SET LOCAL statement_timeout in reconciliation" + assert str(env.PG_MAINTENANCE_STATEMENT_TIMEOUT_MS) in set_local_calls[0] From 2689b67770f59807ab546363a03461dfa0427d55 Mon Sep 17 00:00:00 2001 From: Alexander Osipenko <11722602+subpath@users.noreply.github.com> Date: Tue, 16 Jun 2026 11:15:04 +0200 Subject: [PATCH 2/2] refactoring --- src/mlpa/core/config.py | 14 +- .../core/pg_services/app_attest_pg_service.py | 265 +++++++++--------- .../core/pg_services/litellm_pg_service.py | 101 +++---- src/mlpa/core/pg_services/pg_service.py | 60 ++-- src/mlpa/run.py | 15 +- src/tests/unit/test_pg_timeouts.py | 95 ++++++- 6 files changed, 328 insertions(+), 222 deletions(-) diff --git a/src/mlpa/core/config.py b/src/mlpa/core/config.py index da3a83d..ae10d9d 100644 --- a/src/mlpa/core/config.py +++ b/src/mlpa/core/config.py @@ -288,11 +288,17 @@ def valid_service_type_for_model(self, service_type: str, model: str) -> bool: # Reaps transactions left idle between statements (releases held locks). # Should be >= statement_timeout since a tx legitimately spans round-trips. PG_IDLE_IN_TX_TIMEOUT_MS: int = 10000 - # Override for known-heavy startup maintenance (capacity reconciliation), - # applied per-transaction via SET LOCAL. 0 = unlimited. + # Raised budget for heavy startup work (capacity reconciliation), applied + # per-transaction via SET LOCAL. 0 = unlimited. PG_MAINTENANCE_STATEMENT_TIMEOUT_MS: int = 30000 - # Optional asyncpg client-side backstop (seconds). None = disabled; set - # slightly above PG_STATEMENT_TIMEOUT_MS if you want belt-and-suspenders. + # Raised budget for client-facing admin reads that do unindexable full-table + # scans (user listing, counts-by-service-type). 0 = unlimited. + PG_ADMIN_READ_TIMEOUT_MS: int = 15000 + # Optional asyncpg client-side backstop (seconds). None = disabled. + # WARNING: this is a pool-level client-side cancel that is NOT relaxed by the + # per-transaction SET LOCAL statement_timeout. If enabled, set it above the + # largest per-statement budget (PG_MAINTENANCE_STATEMENT_TIMEOUT_MS), or it + # will silently cancel the maintenance/admin-read queries. PG_COMMAND_TIMEOUT_S: float | None = None # LLM request default values diff --git a/src/mlpa/core/pg_services/app_attest_pg_service.py b/src/mlpa/core/pg_services/app_attest_pg_service.py index 741f6f2..707cb28 100644 --- a/src/mlpa/core/pg_services/app_attest_pg_service.py +++ b/src/mlpa/core/pg_services/app_attest_pg_service.py @@ -102,53 +102,66 @@ async def delete_key(self, key_id_b64: str): async def ensure_capacity_state(self) -> None: """ - Ensure the singleton capacity row and base-identity claim table exist. + Seed the singleton capacity row, then reconcile the claim table. - Reconciles the claim table with current LiteLLM end-user rows for - cap-managed service types on every startup (blocked rows included) so - the counter reflects reality after external writes and config changes. + The seed is critical and fatal on failure: without the row every + admission 500s, so a failure should crash startup rather than serve + broken. Reconciliation is best-effort (see _reconcile_capacity_claims): + if it fails the row still exists with a stale count and admissions work. """ + # Seed the singleton row (fatal on failure). + async with self.pool.acquire() as conn: + async with conn.transaction(): + await conn.execute( + """ + INSERT INTO mlpa_user_capacity (id, max_identities, current_identities) + VALUES (1, $1, 0) + ON CONFLICT (id) DO NOTHING + """, + env.MLPA_MAX_SIGNED_IN_USERS, + ) + await conn.execute( + """ + UPDATE mlpa_user_capacity + SET max_identities = $1, + updated_at = NOW() + WHERE id = 1 + """, + env.MLPA_MAX_SIGNED_IN_USERS, + ) + + # Reconcile the claim table (best-effort). + try: + await self._reconcile_capacity_claims() + except Exception as e: + logger.error( + f"Capacity reconciliation failed; serving with last-known " + f"current_identities count: {e}" + ) + + async def _reconcile_capacity_claims(self) -> None: + """Rebuild the claim table from LiteLLM and refresh current_identities.""" managed_service_types = list(env.MLPA_CAPPED_SERVICE_TYPES) - # Read cap-managed identities from the *litellm* pool before opening the - # app_attest transaction. Doing this inside the transaction would leave - # the app_attest session idle-in-transaction across a cross-pool await, - # where idle_in_transaction_session_timeout could reap it mid-rebuild. + # Read from the litellm pool before opening the app_attest transaction: + # doing it inside would leave the session idle-in-transaction across a + # cross-pool await, where idle_in_transaction_session_timeout could reap it. base_identities = await self.litellm_pg.list_managed_base_identities( managed_service_types ) - # Reconciliation rebuilds the entire claim table (bulk delete + insert of - # all base identities), which can exceed the tight pool-wide - # statement_timeout as the user base grows; maintenance_transaction - # raises both timeouts for this work. - async with self.maintenance_transaction() as conn: - await conn.execute( - """ - INSERT INTO mlpa_user_capacity (id, max_identities, current_identities) - VALUES (1, $1, 0) - ON CONFLICT (id) DO NOTHING - """, - env.MLPA_MAX_SIGNED_IN_USERS, - ) - - # Serialize seeding and reconciliation so concurrent app startups - # do not race on the claim table. + # Bulk delete + insert scales with the user base and can exceed the tight + # pool-wide statement_timeout. Statements run back-to-back (no inter- + # statement await), so the raised statement_timeout alone suffices. + async with self.statement_timeout( + env.PG_MAINTENANCE_STATEMENT_TIMEOUT_MS + ) as conn: + # Serialize so concurrent app startups do not race on the claim table. await conn.fetchrow( "SELECT 1 FROM mlpa_user_capacity WHERE id = 1 FOR UPDATE" ) - await conn.execute( - """ - UPDATE mlpa_user_capacity - SET max_identities = $1, - updated_at = NOW() - WHERE id = 1 - """, - env.MLPA_MAX_SIGNED_IN_USERS, - ) - # Rebuild claims from LiteLLM so the counter matches reality after deletes - # or manual DB edits. Blocked rows still count toward capacity. + # Blocked rows still count toward capacity. await conn.execute("DELETE FROM mlpa_user_capacity_identities") if base_identities: @@ -185,57 +198,53 @@ async def admit_managed_base_identity( if not env.MLPA_ENFORCE_SIGNIN_CAP: return True, False - async with self.pool.acquire() as conn: - async with conn.transaction(): - await conn.execute( - f"SET LOCAL lock_timeout = '{env.MLPA_ADMISSION_LOCK_TIMEOUT_MS}ms'" - ) - capacity_row = await conn.fetchrow( - """ - SELECT max_identities, current_identities - FROM mlpa_user_capacity - WHERE id = 1 - FOR UPDATE - """ + async with self.admission_transaction() as conn: + capacity_row = await conn.fetchrow( + """ + SELECT max_identities, current_identities + FROM mlpa_user_capacity + WHERE id = 1 + FOR UPDATE + """ + ) + if capacity_row is None: + raise HTTPException( + status_code=500, + detail="Capacity state not initialized", ) - if capacity_row is None: - raise HTTPException( - status_code=500, - detail="Capacity state not initialized", - ) - already_claimed = await conn.fetchval( - """ - SELECT 1 - FROM mlpa_user_capacity_identities - WHERE base_identity = $1 - """, - base_identity, - ) - if already_claimed: - return True, False + already_claimed = await conn.fetchval( + """ + SELECT 1 + FROM mlpa_user_capacity_identities + WHERE base_identity = $1 + """, + base_identity, + ) + if already_claimed: + return True, False - max_identities = int(capacity_row["max_identities"]) - current_identities = int(capacity_row["current_identities"]) - if current_identities >= max_identities: - return False, False + max_identities = int(capacity_row["max_identities"]) + current_identities = int(capacity_row["current_identities"]) + if current_identities >= max_identities: + return False, False - await conn.execute( - """ - INSERT INTO mlpa_user_capacity_identities (base_identity) - VALUES ($1) - """, - base_identity, - ) - await conn.execute( - """ - UPDATE mlpa_user_capacity - SET current_identities = current_identities + 1, - updated_at = NOW() - WHERE id = 1 - """ - ) - return True, True + await conn.execute( + """ + INSERT INTO mlpa_user_capacity_identities (base_identity) + VALUES ($1) + """, + base_identity, + ) + await conn.execute( + """ + UPDATE mlpa_user_capacity + SET current_identities = current_identities + 1, + updated_at = NOW() + WHERE id = 1 + """ + ) + return True, True async def maybe_release_managed_base_identity_if_no_managed_users( self, base_identity: str @@ -249,55 +258,55 @@ async def maybe_release_managed_base_identity_if_no_managed_users( managed_service_types = list(env.MLPA_CAPPED_SERVICE_TYPES) - async with self.pool.acquire() as conn: - async with conn.transaction(): - await conn.execute( - f"SET LOCAL lock_timeout = '{env.MLPA_ADMISSION_LOCK_TIMEOUT_MS}ms'" - ) - capacity_row = await conn.fetchrow( - """ - SELECT max_identities, current_identities - FROM mlpa_user_capacity - WHERE id = 1 - FOR UPDATE - """ - ) - if capacity_row is None: - return + # Read the litellm state before opening the app_attest transaction: doing + # it inside would hold the FOR UPDATE lock idle-in-transaction across a + # cross-pool await, where idle_in_transaction_session_timeout could reap + # it and abort the release, leaking the claim (mirrors ensure_capacity_state). + has_managed_user_rows = await self.litellm_pg.has_managed_user_rows( + base_identity, + managed_service_types, + ) + if has_managed_user_rows: + return - claimed = await conn.fetchval( - """ - SELECT 1 - FROM mlpa_user_capacity_identities - WHERE base_identity = $1 - """, - base_identity, - ) - if not claimed: - return + async with self.admission_transaction() as conn: + capacity_row = await conn.fetchrow( + """ + SELECT max_identities, current_identities + FROM mlpa_user_capacity + WHERE id = 1 + FOR UPDATE + """ + ) + if capacity_row is None: + return - has_managed_user_rows = await self.litellm_pg.has_managed_user_rows( - base_identity, - managed_service_types, - ) - if has_managed_user_rows: - return + claimed = await conn.fetchval( + """ + SELECT 1 + FROM mlpa_user_capacity_identities + WHERE base_identity = $1 + """, + base_identity, + ) + if not claimed: + return - await conn.execute( - """ - DELETE FROM mlpa_user_capacity_identities - WHERE base_identity = $1 - """, - base_identity, - ) - await conn.execute( - """ - UPDATE mlpa_user_capacity - SET current_identities = GREATEST(current_identities - 1, 0), - updated_at = NOW() - WHERE id = 1 - """ - ) + await conn.execute( + """ + DELETE FROM mlpa_user_capacity_identities + WHERE base_identity = $1 + """, + base_identity, + ) + await conn.execute( + """ + UPDATE mlpa_user_capacity + SET current_identities = GREATEST(current_identities - 1, 0), + updated_at = NOW() + WHERE id = 1 + """ + ) async def get_signup_cap_status(self) -> dict: """ diff --git a/src/mlpa/core/pg_services/litellm_pg_service.py b/src/mlpa/core/pg_services/litellm_pg_service.py index 8055bb0..5911664 100644 --- a/src/mlpa/core/pg_services/litellm_pg_service.py +++ b/src/mlpa/core/pg_services/litellm_pg_service.py @@ -77,14 +77,17 @@ async def block_user(self, user_id: str, blocked: bool = True) -> dict: async def list_users(self, limit: int = 50, offset: int = 0) -> dict: try: - total = await self.pool.fetchval( - 'SELECT COUNT(*) FROM "LiteLLM_EndUserTable"' - ) - users = await self.pool.fetch( - 'SELECT * FROM "LiteLLM_EndUserTable" ORDER BY user_id LIMIT $1 OFFSET $2', - limit, - offset, - ) + # COUNT(*) + deep OFFSET scan the full table; admin-read budget + # rather than the tight pool-wide default. + async with self.statement_timeout(env.PG_ADMIN_READ_TIMEOUT_MS) as conn: + total = await conn.fetchval( + 'SELECT COUNT(*) FROM "LiteLLM_EndUserTable"' + ) + users = await conn.fetch( + 'SELECT * FROM "LiteLLM_EndUserTable" ORDER BY user_id LIMIT $1 OFFSET $2', + limit, + offset, + ) return { "users": [dict(user) for user in users], @@ -106,16 +109,19 @@ async def count_users_by_service_type(self) -> dict: `{base_user_id}:{service_type}`. """ try: - rows = await self.pool.fetch( - """ - SELECT - split_part(user_id, ':', 2) AS service_type, - COUNT(*)::int AS total_users - FROM "LiteLLM_EndUserTable" - WHERE position(':' in user_id) > 0 - GROUP BY service_type - """ - ) + # GROUP BY split_part(...) is unindexable, so always a full-table + # scan; admin-read budget rather than the tight pool-wide default. + async with self.statement_timeout(env.PG_ADMIN_READ_TIMEOUT_MS) as conn: + rows = await conn.fetch( + """ + SELECT + split_part(user_id, ':', 2) AS service_type, + COUNT(*)::int AS total_users + FROM "LiteLLM_EndUserTable" + WHERE position(':' in user_id) > 0 + GROUP BY service_type + """ + ) service_type_counts: dict[str, int] = {} for row in rows: @@ -142,11 +148,13 @@ async def list_managed_base_identities( """ Return distinct base identities for cap-managed service types. - Runs under the maintenance timeout: the DISTINCT scan over the full - end-user table is startup reconciliation work that can exceed the tight - pool-wide statement_timeout on a large user base. + The DISTINCT scan over the full end-user table can exceed the tight + pool-wide statement_timeout on a large user base, so it runs under the + maintenance budget (startup reconciliation work). """ - async with self.maintenance_transaction() as conn: + async with self.statement_timeout( + env.PG_MAINTENANCE_STATEMENT_TIMEOUT_MS + ) as conn: rows = await conn.fetch( """ SELECT DISTINCT split_part(user_id, ':', 1) AS base_identity @@ -192,32 +200,29 @@ async def create_budget(self): for service_type, budget_config in user_feature_budgets.items(): try: - # Each upsert runs in its own maintenance transaction so a single - # failure rolls back only that budget (and is logged below) while - # the loop continues, and so a slow cold-start upsert is not - # cancelled by the tight pool-wide statement_timeout. - async with self.maintenance_transaction() as conn: - await conn.fetchrow( - """ - INSERT INTO "LiteLLM_BudgetTable" - (budget_id, max_budget, rpm_limit, tpm_limit, budget_duration, created_at, updated_at, created_by, updated_by) - VALUES ($1, $2, $3, $4, $5, NOW(), NOW(), $6, $6) - ON CONFLICT (budget_id) DO UPDATE SET - max_budget = EXCLUDED.max_budget, - rpm_limit = EXCLUDED.rpm_limit, - tpm_limit = EXCLUDED.tpm_limit, - budget_duration = EXCLUDED.budget_duration, - updated_at = NOW(), - updated_by = EXCLUDED.updated_by - RETURNING * - """, - budget_config["budget_id"], - budget_config["max_budget"], - budget_config["rpm_limit"], - budget_config["tpm_limit"], - budget_config["budget_duration"], - "default_user_id", - ) + # Fast single-row PK upsert: stays a plain autocommit call (it + # cannot realistically hit the pool-wide statement_timeout). + await self.pool.fetchrow( + """ + INSERT INTO "LiteLLM_BudgetTable" + (budget_id, max_budget, rpm_limit, tpm_limit, budget_duration, created_at, updated_at, created_by, updated_by) + VALUES ($1, $2, $3, $4, $5, NOW(), NOW(), $6, $6) + ON CONFLICT (budget_id) DO UPDATE SET + max_budget = EXCLUDED.max_budget, + rpm_limit = EXCLUDED.rpm_limit, + tpm_limit = EXCLUDED.tpm_limit, + budget_duration = EXCLUDED.budget_duration, + updated_at = NOW(), + updated_by = EXCLUDED.updated_by + RETURNING * + """, + budget_config["budget_id"], + budget_config["max_budget"], + budget_config["rpm_limit"], + budget_config["tpm_limit"], + budget_config["budget_duration"], + "default_user_id", + ) logger.info( f"Budget created/updated: budget_id={budget_config['budget_id']}, " f"service_type={service_type}, max_budget={budget_config['max_budget']}" diff --git a/src/mlpa/core/pg_services/pg_service.py b/src/mlpa/core/pg_services/pg_service.py index f813fd7..e8fabc4 100644 --- a/src/mlpa/core/pg_services/pg_service.py +++ b/src/mlpa/core/pg_services/pg_service.py @@ -14,8 +14,6 @@ class PGService: def __init__(self, db_name: str, statement_timeout_ms: int | None = None): self.db_name = db_name self.db_url = f"{cast(str, env.PG_DB_URL).rstrip('/')}/{db_name}" - # Pool-wide server-enforced statement timeout. Subclasses (or a future - # third pool) may override per-DB; defaults to the global config value. self.statement_timeout_ms = ( statement_timeout_ms if statement_timeout_ms is not None @@ -30,9 +28,8 @@ def pool(self) -> asyncpg.Pool: async def connect(self): try: - # Applied at connect and automatically re-applied on every reconnect, - # so the timeout is durable across the pool's lifetime without - # touching call sites. Values are passed as bare-integer ms strings. + # asyncpg re-applies server_settings on every reconnect, so these + # are durable for the pool's lifetime. Values are ms-integer strings. server_settings = { "statement_timeout": str(self.statement_timeout_ms), "idle_in_transaction_session_timeout": str( @@ -61,26 +58,55 @@ async def disconnect(self): self.connected = False @asynccontextmanager - async def maintenance_transaction(self): + async def _timed_transaction( + self, + statement_timeout_ms: int, + idle_in_tx_timeout_ms: int | None = None, + lock_timeout_ms: int | None = None, + ): """ - Yield a connection inside a transaction whose statement and - idle-in-transaction timeouts are raised to the maintenance value. - - For known-heavy startup work (capacity reconciliation, budget upsert) - that legitimately exceeds the tight pool-wide statement_timeout. Both - GUCs are raised via SET LOCAL: statement_timeout for slow statements, - and idle_in_transaction_session_timeout because such work may await - other queries between statements without the session being reaped. + Yield a connection in a transaction with statement_timeout (and + optionally idle_in_transaction_session_timeout / lock_timeout) set via + SET LOCAL, scoped to the transaction so the connection reverts to the + pool-wide defaults on release. Timeout values are config ints, not input. """ - timeout_ms = env.PG_MAINTENANCE_STATEMENT_TIMEOUT_MS async with self.pool.acquire() as conn: async with conn.transaction(): - await conn.execute(f"SET LOCAL statement_timeout = '{timeout_ms}'") await conn.execute( - f"SET LOCAL idle_in_transaction_session_timeout = '{timeout_ms}'" + f"SET LOCAL statement_timeout = '{statement_timeout_ms}'" ) + if idle_in_tx_timeout_ms is not None: + await conn.execute( + f"SET LOCAL idle_in_transaction_session_timeout = '{idle_in_tx_timeout_ms}'" + ) + if lock_timeout_ms is not None: + await conn.execute( + f"SET LOCAL lock_timeout = '{lock_timeout_ms}ms'" + ) yield conn + @asynccontextmanager + async def statement_timeout(self, timeout_ms: int): + """ + Raise statement_timeout for statements that legitimately exceed the + tight pool-wide default (e.g. unindexable full-table scans). + """ + async with self._timed_transaction(timeout_ms) as conn: + yield conn + + @asynccontextmanager + async def admission_transaction(self): + """ + Signup-capacity admission path: a bounded lock_timeout for the FOR UPDATE + on the singleton capacity row, plus a statement_timeout set above it so + the lock wait is governed by lock_timeout rather than silently capped by + the pool-wide statement_timeout (Postgres counts lock-wait toward it). + """ + lock_ms = env.MLPA_ADMISSION_LOCK_TIMEOUT_MS + stmt_ms = lock_ms + env.PG_STATEMENT_TIMEOUT_MS + async with self._timed_transaction(stmt_ms, lock_timeout_ms=lock_ms) as conn: + yield conn + def check_status(self): if self.pg is None or not self.connected: return False diff --git a/src/mlpa/run.py b/src/mlpa/run.py index b700e13..6a98a06 100644 --- a/src/mlpa/run.py +++ b/src/mlpa/run.py @@ -70,19 +70,8 @@ async def lifespan(app: FastAPI): await app_attest_pg.connect() app_attest_connected = True - # Startup maintenance (budget upsert, capacity reconciliation) is - # best-effort: it runs under a raised maintenance statement_timeout, but - # if it still fails (e.g. timeout on a very large/slow DB) we log and - # continue rather than crash-loop the whole app. Both paths leave prior - # state intact on failure, so the app serves with last-known config. - try: - await litellm_pg.create_budget() - except Exception as e: - logger.error(f"Startup budget creation failed; continuing: {e}") - try: - await app_attest_pg.ensure_capacity_state() - except Exception as e: - logger.error(f"Startup capacity reconciliation failed; continuing: {e}") + await litellm_pg.create_budget() + await app_attest_pg.ensure_capacity_state() yield finally: diff --git a/src/tests/unit/test_pg_timeouts.py b/src/tests/unit/test_pg_timeouts.py index 2cd0c82..e8be16d 100644 --- a/src/tests/unit/test_pg_timeouts.py +++ b/src/tests/unit/test_pg_timeouts.py @@ -12,6 +12,7 @@ def test_pg_timeout_config_from_env(): "PG_STATEMENT_TIMEOUT_MS": "5000", "PG_IDLE_IN_TX_TIMEOUT_MS": "15000", "PG_MAINTENANCE_STATEMENT_TIMEOUT_MS": "45000", + "PG_ADMIN_READ_TIMEOUT_MS": "20000", "PG_COMMAND_TIMEOUT_S": "6.5", } @@ -21,6 +22,7 @@ def test_pg_timeout_config_from_env(): assert test_env.PG_STATEMENT_TIMEOUT_MS == 5000 assert test_env.PG_IDLE_IN_TX_TIMEOUT_MS == 15000 assert test_env.PG_MAINTENANCE_STATEMENT_TIMEOUT_MS == 45000 + assert test_env.PG_ADMIN_READ_TIMEOUT_MS == 20000 assert test_env.PG_COMMAND_TIMEOUT_S == 6.5 @@ -29,6 +31,7 @@ def test_pg_timeout_defaults(): assert env.PG_STATEMENT_TIMEOUT_MS == 3000 assert env.PG_IDLE_IN_TX_TIMEOUT_MS == 10000 assert env.PG_MAINTENANCE_STATEMENT_TIMEOUT_MS == 30000 + assert env.PG_ADMIN_READ_TIMEOUT_MS == 15000 assert env.PG_COMMAND_TIMEOUT_S is None @@ -95,17 +98,68 @@ def _set_local_calls(conn, guc): ] -async def test_maintenance_transaction_raises_both_timeouts(mocker): - """The helper lifts statement_timeout AND idle_in_transaction_session_timeout.""" +async def test_admission_transaction_sets_lock_and_statement_timeout(mocker): + """admission_transaction lifts lock_timeout and a statement_timeout above it.""" conn = _mock_maintenance_conn() mocker.patch.object(PGService, "pool", new=_mock_pool(conn)) service = PGService("some_db") - async with service.maintenance_transaction() as yielded: + async with service.admission_transaction() as yielded: assert yielded is conn - assert _set_local_calls(conn, "statement_timeout") - assert _set_local_calls(conn, "idle_in_transaction_session_timeout") + assert _set_local_calls(conn, "lock_timeout") + stmt_calls = _set_local_calls(conn, "statement_timeout") + assert stmt_calls + # statement_timeout must exceed lock_timeout so the lock wait is governed by + # lock_timeout, not silently capped by the tight pool-wide statement_timeout. + expected = env.MLPA_ADMISSION_LOCK_TIMEOUT_MS + env.PG_STATEMENT_TIMEOUT_MS + assert str(expected) in stmt_calls[0] + + +async def test_statement_timeout_sets_only_statement_timeout(mocker): + """statement_timeout() lifts statement_timeout but not idle-in-tx (read-only helper).""" + conn = _mock_maintenance_conn() + mocker.patch.object(PGService, "pool", new=_mock_pool(conn)) + + service = PGService("some_db") + async with service.statement_timeout(15000) as yielded: + assert yielded is conn + + stmt_calls = _set_local_calls(conn, "statement_timeout") + assert stmt_calls + assert "15000" in stmt_calls[0] + assert not _set_local_calls(conn, "idle_in_transaction_session_timeout") + + +async def test_count_users_by_service_type_uses_admin_read_timeout(mocker): + """The unindexable full-table GROUP BY runs under the admin-read timeout, not 3s.""" + from mlpa.core.pg_services.litellm_pg_service import LiteLLMPGService + + conn = _mock_maintenance_conn() + mocker.patch.object(PGService, "pool", new=_mock_pool(conn)) + + service = LiteLLMPGService() + await service.count_users_by_service_type() + + timeout_calls = _set_local_calls(conn, "statement_timeout") + assert timeout_calls + assert str(env.PG_ADMIN_READ_TIMEOUT_MS) in timeout_calls[0] + conn.fetch.assert_awaited_once() + + +async def test_list_users_uses_admin_read_timeout(mocker): + """The full-table COUNT(*) + deep OFFSET page run under the admin-read timeout.""" + from mlpa.core.pg_services.litellm_pg_service import LiteLLMPGService + + conn = _mock_maintenance_conn() + mocker.patch.object(PGService, "pool", new=_mock_pool(conn)) + + service = LiteLLMPGService() + await service.list_users() + + timeout_calls = _set_local_calls(conn, "statement_timeout") + assert timeout_calls + assert str(env.PG_ADMIN_READ_TIMEOUT_MS) in timeout_calls[0] async def test_list_managed_base_identities_uses_maintenance_timeout(mocker): @@ -124,8 +178,15 @@ async def test_list_managed_base_identities_uses_maintenance_timeout(mocker): conn.fetch.assert_awaited_once() -async def test_ensure_capacity_state_reads_identities_before_transaction(mocker): - """The cross-pool read happens before the app_attest transaction is opened.""" +async def test_ensure_capacity_state_reads_identities_before_reconcile_transaction( + mocker, +): + """The cross-pool read must not be issued inside the reconcile transaction. + + The session must never sit idle-in-transaction across the cross-pool await, + so the litellm read has to land AFTER the (self-contained) seed transaction + and BEFORE the destructive claim rebuild (DELETE ...). + """ order: list[str] = [] litellm_pg = MagicMock() @@ -139,7 +200,7 @@ async def _list(*_args, **_kwargs): conn = _mock_maintenance_conn() async def _execute(sql, *_args, **_kwargs): - order.append(f"exec:{sql.strip().split()[0]}") + order.append(f"exec:{sql.strip()}") conn.execute = AsyncMock(side_effect=_execute) mocker.patch.object(PGService, "pool", new=_mock_pool(conn)) @@ -147,11 +208,21 @@ async def _execute(sql, *_args, **_kwargs): service = AppAttestPGService(litellm_pg) await service.ensure_capacity_state() - # The litellm read must precede every statement issued inside the app_attest - # transaction, so the session is never idle-in-transaction across that await. - assert order[0] == "read" - assert all(step == "read" or step.startswith("exec:") for step in order) assert order.count("read") == 1 + read_idx = order.index("read") + + # The claim rebuild (DELETE) — the first statement of the reconcile + # transaction — must come strictly after the cross-pool read. + delete_idx = next( + i for i, step in enumerate(order) if step.startswith("exec:DELETE") + ) + assert read_idx < delete_idx + + # The seed (INSERT/UPDATE on mlpa_user_capacity) runs in its own transaction + # and commits before the read, so no cross-pool await spans an open tx. + assert any( + step.startswith("exec:INSERT INTO mlpa_user_capacity ") for step in order + ) async def test_ensure_capacity_state_raises_maintenance_timeout(mocker):