Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions src/mlpa/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,25 @@ def valid_service_type_for_model(self, service_type: str, model: str) -> bool:
PG_POOL_MIN_SIZE: int = 1
PG_POOL_MAX_SIZE: int = 10
PG_PREPARED_STMT_CACHE_MAX_SIZE: int = 100
# Postgres query timeouts. Server-enforced (statement_timeout) so a runaway
# query is killed even if the client/event loop hangs, preventing connection
# pile-up. Values are milliseconds; 0 = unlimited (Postgres semantics).
PG_STATEMENT_TIMEOUT_MS: int = 3000
# Reaps transactions left idle between statements (releases held locks).
# Should be >= statement_timeout since a tx legitimately spans round-trips.
PG_IDLE_IN_TX_TIMEOUT_MS: int = 10000
# Raised budget for heavy startup work (capacity reconciliation), applied
# per-transaction via SET LOCAL. 0 = unlimited.
PG_MAINTENANCE_STATEMENT_TIMEOUT_MS: int = 30000
# Raised budget for client-facing admin reads that do unindexable full-table
# scans (user listing, counts-by-service-type). 0 = unlimited.
PG_ADMIN_READ_TIMEOUT_MS: int = 15000
# Optional asyncpg client-side backstop (seconds). None = disabled.
# WARNING: this is a pool-level client-side cancel that is NOT relaxed by the
# per-transaction SET LOCAL statement_timeout. If enabled, set it above the
# largest per-statement budget. Note the units differ: this value is in
# seconds while the budgets are in milliseconds, so it must exceed
# PG_MAINTENANCE_STATEMENT_TIMEOUT_MS / 1000. Otherwise it will silently
# cancel the maintenance/admin-read queries.
PG_COMMAND_TIMEOUT_S: float | None = None

# LLM request default values
TEMPERATURE: float = 0.1
Expand Down
275 changes: 146 additions & 129 deletions src/mlpa/core/pg_services/app_attest_pg_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,14 +102,14 @@ async def delete_key(self, key_id_b64: str):

async def ensure_capacity_state(self) -> None:
"""
Ensure the singleton capacity row and base-identity claim table exist.
Seed the singleton capacity row, then reconcile the claim table.

Reconciles the claim table with current LiteLLM end-user rows for
cap-managed service types on every startup (blocked rows included) so
the counter reflects reality after external writes and config changes.
The seed is critical and fatal on failure: without the row every
admission 500s, so a failure should crash startup rather than serve
broken. Reconciliation is best-effort (see _reconcile_capacity_claims):
if it fails the row still exists with a stale count and admissions work.
"""
managed_service_types = list(env.MLPA_CAPPED_SERVICE_TYPES)

# Seed the singleton row (fatal on failure).
async with self.pool.acquire() as conn:
async with conn.transaction():
await conn.execute(
Expand All @@ -120,12 +120,6 @@ async def ensure_capacity_state(self) -> None:
""",
env.MLPA_MAX_SIGNED_IN_USERS,
)

# Serialize seeding and reconciliation so concurrent app startups
# do not race on the claim table.
await conn.fetchrow(
"SELECT 1 FROM mlpa_user_capacity WHERE id = 1 FOR UPDATE"
)
await conn.execute(
"""
UPDATE mlpa_user_capacity
Expand All @@ -136,35 +130,62 @@ async def ensure_capacity_state(self) -> None:
env.MLPA_MAX_SIGNED_IN_USERS,
)

# Rebuild claims from LiteLLM so the counter matches reality after deletes
# or manual DB edits. Blocked rows still count toward capacity.
await conn.execute("DELETE FROM mlpa_user_capacity_identities")
# Reconcile the claim table (best-effort).
try:
await self._reconcile_capacity_claims()
except Exception as e:
logger.error(
f"Capacity reconciliation failed; serving with last-known "
f"current_identities count: {e}"
)

base_identities = await self.litellm_pg.list_managed_base_identities(
managed_service_types
)
if base_identities:
await conn.executemany(
"""
INSERT INTO mlpa_user_capacity_identities (base_identity)
VALUES ($1)
""",
[(base_identity,) for base_identity in base_identities],
)

seeded_claims = await conn.fetchval(
"SELECT COUNT(*) FROM mlpa_user_capacity_identities"
)
await conn.execute(
async def _reconcile_capacity_claims(self) -> None:
"""Rebuild the claim table from LiteLLM and refresh current_identities."""
managed_service_types = list(env.MLPA_CAPPED_SERVICE_TYPES)

# Read from the litellm pool before opening the app_attest transaction:
# doing it inside would leave the session idle-in-transaction across a
# cross-pool await, where idle_in_transaction_session_timeout could reap it.
base_identities = await self.litellm_pg.list_managed_base_identities(
managed_service_types
)

# Bulk delete + insert scales with the user base and can exceed the tight
# pool-wide statement_timeout. The statements below run back-to-back on this
# connection (nothing else, in particular no cross-pool call, is awaited
# between them), so the raised statement_timeout alone suffices.
async with self.statement_timeout(
env.PG_MAINTENANCE_STATEMENT_TIMEOUT_MS
) as conn:
# Serialize so concurrent app startups do not race on the claim table.
await conn.fetchrow(
"SELECT 1 FROM mlpa_user_capacity WHERE id = 1 FOR UPDATE"
)

# Blocked rows still count toward capacity.
await conn.execute("DELETE FROM mlpa_user_capacity_identities")

if base_identities:
await conn.executemany(
"""
UPDATE mlpa_user_capacity
SET current_identities = $1,
updated_at = NOW()
WHERE id = 1
INSERT INTO mlpa_user_capacity_identities (base_identity)
VALUES ($1)
""",
seeded_claims,
[(base_identity,) for base_identity in base_identities],
)

seeded_claims = await conn.fetchval(
"SELECT COUNT(*) FROM mlpa_user_capacity_identities"
)
await conn.execute(
"""
UPDATE mlpa_user_capacity
SET current_identities = $1,
updated_at = NOW()
WHERE id = 1
""",
seeded_claims,
)

async def admit_managed_base_identity(
self, base_identity: str
) -> tuple[bool, bool]:
Expand All @@ -177,57 +198,53 @@ async def admit_managed_base_identity(
if not env.MLPA_ENFORCE_SIGNIN_CAP:
return True, False

async with self.pool.acquire() as conn:
async with conn.transaction():
await conn.execute(
f"SET LOCAL lock_timeout = '{env.MLPA_ADMISSION_LOCK_TIMEOUT_MS}ms'"
)
capacity_row = await conn.fetchrow(
"""
SELECT max_identities, current_identities
FROM mlpa_user_capacity
WHERE id = 1
FOR UPDATE
"""
async with self.admission_transaction() as conn:
capacity_row = await conn.fetchrow(
"""
SELECT max_identities, current_identities
FROM mlpa_user_capacity
WHERE id = 1
FOR UPDATE
"""
)
if capacity_row is None:
raise HTTPException(
status_code=500,
detail="Capacity state not initialized",
)
if capacity_row is None:
raise HTTPException(
status_code=500,
detail="Capacity state not initialized",
)

already_claimed = await conn.fetchval(
"""
SELECT 1
FROM mlpa_user_capacity_identities
WHERE base_identity = $1
""",
base_identity,
)
if already_claimed:
return True, False
already_claimed = await conn.fetchval(
"""
SELECT 1
FROM mlpa_user_capacity_identities
WHERE base_identity = $1
""",
base_identity,
)
if already_claimed:
return True, False

max_identities = int(capacity_row["max_identities"])
current_identities = int(capacity_row["current_identities"])
if current_identities >= max_identities:
return False, False
max_identities = int(capacity_row["max_identities"])
current_identities = int(capacity_row["current_identities"])
if current_identities >= max_identities:
return False, False

await conn.execute(
"""
INSERT INTO mlpa_user_capacity_identities (base_identity)
VALUES ($1)
""",
base_identity,
)
await conn.execute(
"""
UPDATE mlpa_user_capacity
SET current_identities = current_identities + 1,
updated_at = NOW()
WHERE id = 1
"""
)
return True, True
await conn.execute(
"""
INSERT INTO mlpa_user_capacity_identities (base_identity)
VALUES ($1)
""",
base_identity,
)
await conn.execute(
"""
UPDATE mlpa_user_capacity
SET current_identities = current_identities + 1,
updated_at = NOW()
WHERE id = 1
"""
)
return True, True

async def maybe_release_managed_base_identity_if_no_managed_users(
self, base_identity: str
Expand All @@ -241,55 +258,55 @@ async def maybe_release_managed_base_identity_if_no_managed_users(

managed_service_types = list(env.MLPA_CAPPED_SERVICE_TYPES)

async with self.pool.acquire() as conn:
async with conn.transaction():
await conn.execute(
f"SET LOCAL lock_timeout = '{env.MLPA_ADMISSION_LOCK_TIMEOUT_MS}ms'"
)
capacity_row = await conn.fetchrow(
"""
SELECT max_identities, current_identities
FROM mlpa_user_capacity
WHERE id = 1
FOR UPDATE
"""
)
if capacity_row is None:
return
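# Read the litellm state before opening the app_attest transaction: doing
# it inside would hold the FOR UPDATE lock idle-in-transaction across a
# cross-pool await, where idle_in_transaction_session_timeout could reap
# it and abort the release, leaking the claim (mirrors the pre-read in
# _reconcile_capacity_claims).
# NOTE(review): the check now runs before the capacity row is locked, so a
# managed user row created between this read and the DELETE below is not
# seen. Confirm that window is acceptable.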
has_managed_user_rows = await self.litellm_pg.has_managed_user_rows(
base_identity,
managed_service_types,
)
if has_managed_user_rows:
return

claimed = await conn.fetchval(
"""
SELECT 1
FROM mlpa_user_capacity_identities
WHERE base_identity = $1
""",
base_identity,
)
if not claimed:
return
async with self.admission_transaction() as conn:
capacity_row = await conn.fetchrow(
"""
SELECT max_identities, current_identities
FROM mlpa_user_capacity
WHERE id = 1
FOR UPDATE
"""
)
if capacity_row is None:
return

has_managed_user_rows = await self.litellm_pg.has_managed_user_rows(
base_identity,
managed_service_types,
)
if has_managed_user_rows:
return
claimed = await conn.fetchval(
"""
SELECT 1
FROM mlpa_user_capacity_identities
WHERE base_identity = $1
""",
base_identity,
)
if not claimed:
return

await conn.execute(
"""
DELETE FROM mlpa_user_capacity_identities
WHERE base_identity = $1
""",
base_identity,
)
await conn.execute(
"""
UPDATE mlpa_user_capacity
SET current_identities = GREATEST(current_identities - 1, 0),
updated_at = NOW()
WHERE id = 1
"""
)
await conn.execute(
"""
DELETE FROM mlpa_user_capacity_identities
WHERE base_identity = $1
""",
base_identity,
)
await conn.execute(
"""
UPDATE mlpa_user_capacity
SET current_identities = GREATEST(current_identities - 1, 0),
updated_at = NOW()
WHERE id = 1
"""
)

async def get_signup_cap_status(self) -> dict:
"""
Expand Down
Loading