Skip to content
61 changes: 46 additions & 15 deletions src/mlpa/core/auth/authorize.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
ServiceType,
)
from mlpa.core.config import env
from mlpa.core.metrics import record_chat_availability_for
from mlpa.core.prometheus_metrics import AvailabilityReason
from mlpa.core.routers.appattest import app_attest_auth
from mlpa.core.utils import extract_user_from_play_integrity_jwt, parse_app_attest_jwt

Expand Down Expand Up @@ -123,27 +125,56 @@ async def authorize_chat_request(
)

if not is_service_type_valid:
record_chat_availability_for(
AvailabilityReason.INVALID_SERVICE_TYPE_FOR_MODEL,
model=chat_request.model,
service_type=service_type.value,
purpose=(purpose or "").strip(),
)
raise HTTPException(
status_code=400,
detail=f"Invalid service-type value for model {chat_request.model}. Should be one of {env.forced_model_service_type_pairs.get(chat_request.model)}",
)

return await _authorize_common_request(
request=request,
build_authorized_request=lambda user, purpose_value: AuthorizedChatRequest(
user=user,
try:
return await _authorize_common_request(
request=request,
build_authorized_request=lambda user, purpose_value: AuthorizedChatRequest(
user=user,
service_type=service_type.value,
purpose=purpose_value,
**chat_request.model_dump(exclude_unset=True),
),
authorization=authorization,
service_type=service_type,
purpose=purpose,
x_dev_authorization=x_dev_authorization,
use_app_attest=use_app_attest,
use_qa_certificates=use_qa_certificates,
use_play_integrity=use_play_integrity,
)
except HTTPException as exc:
# Only record terminal HTTP failures from the shared auth call:
# - 401: expected or normalized auth rejection (bad creds, expired token, etc.)
# - 400: client error to the auth layer (invalid purpose, or malformed App
# Attest base64 decoded in app_attest_auth before its try block)
# - anything else (e.g. App Attest's explicit 500): re-raised unrecorded;
# auth-system-failure capture is left to a follow-on auth backend change
# Non-HTTPException errors are not caught here and propagate unrecorded.
# Purpose is unknown at this point, so "" is always used as a placeholder.
if exc.status_code == 401:
reason = AvailabilityReason.AUTH_REJECTED
elif exc.status_code == 400:
reason = AvailabilityReason.INVALID_AUTH_REQUEST
else:
raise
record_chat_availability_for(
reason,
model=chat_request.model,
service_type=service_type.value,
purpose=purpose_value,
**chat_request.model_dump(exclude_unset=True),
),
authorization=authorization,
service_type=service_type,
purpose=purpose,
x_dev_authorization=x_dev_authorization,
use_app_attest=use_app_attest,
use_qa_certificates=use_qa_certificates,
use_play_integrity=use_play_integrity,
)
purpose="",
)
raise


async def authorize_search_request(
Comment thread
noahpodgurski marked this conversation as resolved.
Expand Down
46 changes: 35 additions & 11 deletions src/mlpa/core/completions.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,15 @@
from mlpa.core.logger import logger
from mlpa.core.metrics import (
extract_tool_names,
record_chat_availability,
record_chat_request_rejection,
record_completion_latency,
record_completion_success,
record_request_with_tools,
record_ttft,
)
from mlpa.core.prometheus_metrics import (
AvailabilityReason,
PrometheusRejectionReason,
PrometheusResult,
)
Expand All @@ -52,20 +54,31 @@ def _build_litellm_body(req: AuthorizedChatRequest, *, stream: bool) -> dict:
async def get_or_create_user_for_completion(
user_id: str, req: AuthorizedChatRequest | AuthorizedSearchRequest
):
"""Wraps get_or_create_user and records a signup-cap rejection metric if applicable."""
"""
Wraps get_or_create_user and records availability for chat requests:
- signup cap (403 + MAX_USERS_REACHED): excluded, alongside the existing rejection metric
- user-resolution server or system failure (status >= 500): failure
- search requests and non-signup-cap, non-5xx failures: not recorded
"""
try:
return await get_or_create_user(user_id)
except HTTPException as exc:
if (
exc.status_code == 403
and isinstance(exc.detail, dict)
and exc.detail.get("error") == ERROR_CODE_MAX_USERS_REACHED
and isinstance(req, AuthorizedChatRequest)
Comment thread
noahpodgurski marked this conversation as resolved.
):
record_chat_request_rejection(
req,
PrometheusRejectionReason.SIGNUP_CAP_EXCEEDED,
)
if isinstance(req, AuthorizedChatRequest):
if (
exc.status_code == 403
and isinstance(exc.detail, dict)
and exc.detail.get("error") == ERROR_CODE_MAX_USERS_REACHED
):
record_chat_request_rejection(
req,
PrometheusRejectionReason.SIGNUP_CAP_EXCEEDED,
)
record_chat_availability(req, AvailabilityReason.SIGNUP_CAP_EXCEEDED)
elif exc.status_code >= 500:
# User-resolution server or system failure. Non-signup-cap 4xx errors
# are not recorded; a client-side 4xx should get its own classification
# rather than counting as an availability failure.
record_chat_availability(req, AvailabilityReason.PROVISIONING_FAILURE)
raise


Expand All @@ -80,6 +93,7 @@ async def stream_completion(
record_request_with_tools(authorized_chat_request)
body = _build_litellm_body(authorized_chat_request, stream=True)
result = PrometheusResult.ERROR
availability_reason = AvailabilityReason.UPSTREAM_ERROR
is_first_token = True
prompt_tokens = 0
completion_tokens = 0
Expand Down Expand Up @@ -139,6 +153,7 @@ async def _read_next_chunk(
if match.log_message:
logger.warning(match.log_message)
record_chat_request_rejection(authorized_chat_request, match.reason)
availability_reason = match.availability_reason()
yield f'data: {{"error": {match.error_code}}}\n\n'.encode()
return

Expand Down Expand Up @@ -237,6 +252,7 @@ async def _read_next_chunk(
return

if not streaming_started:
availability_reason = AvailabilityReason.EMPTY_RESPONSE
yield raise_and_log(
RuntimeError("LiteLLM returned an empty response"),
True,
Expand All @@ -256,6 +272,7 @@ async def _read_next_chunk(
snapshot=litellm_routing_snapshot,
)
result = PrometheusResult.SUCCESS
availability_reason = AvailabilityReason.VALID_RESPONSE
except (GeneratorExit, asyncio.CancelledError):
# Client went away mid-stream: Starlette tears the generator down by
# throwing GeneratorExit (or cancelling the task) at the paused
Expand Down Expand Up @@ -291,9 +308,12 @@ async def _read_next_chunk(
if result == PrometheusResult.ERROR and disconnect_event.is_set():
result = PrometheusResult.ABORT
logger.info(_client_disconnected_msg)
if result == PrometheusResult.ABORT:
availability_reason = AvailabilityReason.CLIENT_DISCONNECT
record_completion_latency(
authorized_chat_request, result, time.perf_counter() - start_time
)
record_chat_availability(authorized_chat_request, availability_reason)


async def get_completion(authorized_chat_request: AuthorizedChatRequest):
Expand All @@ -304,6 +324,7 @@ async def get_completion(authorized_chat_request: AuthorizedChatRequest):
record_request_with_tools(authorized_chat_request)
body = _build_litellm_body(authorized_chat_request, stream=False)
result = PrometheusResult.ERROR
availability_reason = AvailabilityReason.UPSTREAM_ERROR
logger.debug(
f"Starting a non-stream completion using {authorized_chat_request.model}, for user {authorized_chat_request.user}",
)
Expand All @@ -326,6 +347,7 @@ async def get_completion(authorized_chat_request: AuthorizedChatRequest):
if match.log_message:
logger.warning(match.log_message)
record_chat_request_rejection(authorized_chat_request, match.reason)
availability_reason = match.availability_reason()
headers = (
{"Retry-After": match.retry_after} if match.retry_after else None
)
Expand Down Expand Up @@ -362,6 +384,7 @@ async def get_completion(authorized_chat_request: AuthorizedChatRequest):
snapshot=litellm_routing_snapshot,
)
result = PrometheusResult.SUCCESS
availability_reason = AvailabilityReason.VALID_RESPONSE
return data
except HTTPException:
raise
Expand All @@ -371,3 +394,4 @@ async def get_completion(authorized_chat_request: AuthorizedChatRequest):
record_completion_latency(
authorized_chat_request, result, time.perf_counter() - start_time
)
record_chat_availability(authorized_chat_request, availability_reason)
20 changes: 19 additions & 1 deletion src/mlpa/core/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
ERROR_CODE_REQUEST_TOO_LARGE,
ERROR_CODE_UPSTREAM_RATE_LIMIT_EXCEEDED,
)
from mlpa.core.prometheus_metrics import PrometheusRejectionReason
from mlpa.core.prometheus_metrics import AvailabilityReason, PrometheusRejectionReason
from mlpa.core.utils import (
is_context_window_error,
is_invalid_model_name_error,
Expand All @@ -18,6 +18,15 @@
is_rate_limit_error,
)

_REJECTION_TO_AVAILABILITY_REASON: dict[
PrometheusRejectionReason, AvailabilityReason
] = {
PrometheusRejectionReason.BUDGET_EXCEEDED: AvailabilityReason.BUDGET_EXCEEDED,
PrometheusRejectionReason.PAYLOAD_TOO_LARGE: AvailabilityReason.PAYLOAD_TOO_LARGE,
Comment thread
noahpodgurski marked this conversation as resolved.
PrometheusRejectionReason.INVALID_MODEL_NAME: AvailabilityReason.INVALID_MODEL_NAME,
PrometheusRejectionReason.INVALID_REQUEST: AvailabilityReason.INVALID_REQUEST,
}


@dataclass(frozen=True)
class RejectionMatch:
Expand All @@ -27,6 +36,15 @@ class RejectionMatch:
retry_after: str | None = None
log_message: str = ""

def availability_reason(self) -> AvailabilityReason:
# SIGNUP_CAP_EXCEEDED is recorded pre-completion, not via classify_upstream_error,
# so it is not in the mapping below.
if self.reason == PrometheusRejectionReason.RATE_LIMITED:
if self.error_code == ERROR_CODE_UPSTREAM_RATE_LIMIT_EXCEEDED:
return AvailabilityReason.RATE_LIMITED_UPSTREAM
return AvailabilityReason.RATE_LIMITED_PLATFORM
return _REJECTION_TO_AVAILABILITY_REASON[self.reason]


_RATE_LIMIT_REJECTION: dict[int, tuple[PrometheusRejectionReason, str, str]] = {
ERROR_CODE_BUDGET_LIMIT_EXCEEDED: (
Expand Down
29 changes: 29 additions & 0 deletions src/mlpa/core/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@
LitellmRoutingSnapshot,
)
from mlpa.core.prometheus_metrics import (
AvailabilityReason,
PrometheusRejectionReason,
PrometheusResult,
TokenType,
availability_outcome_for,
metrics,
)

Expand Down Expand Up @@ -41,6 +43,33 @@ def record_search_request_rejection(
metrics.search_request_rejections.labels(reason=reason, **_search_labels(req)).inc()


def record_chat_availability_for(
reason: AvailabilityReason,
*,
model: str,
service_type: str,
purpose: str,
) -> None:
metrics.chat_availability.labels(
outcome=availability_outcome_for(reason),
reason=reason,
model=model,
service_type=service_type,
purpose=purpose,
).inc()


def record_chat_availability(
req: AuthorizedChatRequest, reason: AvailabilityReason
) -> None:
record_chat_availability_for(
reason,
model=req.model,
service_type=req.service_type,
purpose=req.purpose,
)


def record_completion_latency(
req: AuthorizedChatRequest,
result: PrometheusResult,
Expand Down
12 changes: 12 additions & 0 deletions src/mlpa/core/middleware/request_size.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

from mlpa.core.config import ERROR_CODE_REQUEST_TOO_LARGE, env
from mlpa.core.logger import logger
from mlpa.core.metrics import record_chat_availability_for
from mlpa.core.prometheus_metrics import AvailabilityReason


async def check_request_size_middleware(request: Request, call_next):
Expand All @@ -19,6 +21,16 @@ async def check_request_size_middleware(request: Request, call_next):
logger.warning(
f"Request size {size} bytes exceeds maximum {env.MAX_REQUEST_SIZE_BYTES} bytes"
)
# `model` is in the request body, which we don't read here.
# We reject on the Content-Length header without parsing it.
record_chat_availability_for(
AvailabilityReason.PAYLOAD_TOO_LARGE,
model="",
service_type=(
request.headers.get("service-type") or ""
).strip(),
purpose=(request.headers.get("purpose") or "").strip(),
)
return JSONResponse(
status_code=413,
content={"error": ERROR_CODE_REQUEST_TOO_LARGE},
Expand Down
Loading
Loading