diff --git a/src/mlpa/core/auth/authorize.py b/src/mlpa/core/auth/authorize.py index 50eae01..9c25c00 100644 --- a/src/mlpa/core/auth/authorize.py +++ b/src/mlpa/core/auth/authorize.py @@ -13,6 +13,8 @@ ServiceType, ) from mlpa.core.config import env +from mlpa.core.metrics import record_chat_availability_for +from mlpa.core.prometheus_metrics import AvailabilityReason from mlpa.core.routers.appattest import app_attest_auth from mlpa.core.utils import extract_user_from_play_integrity_jwt, parse_app_attest_jwt @@ -123,27 +125,55 @@ async def authorize_chat_request( ) if not is_service_type_valid: + record_chat_availability_for( + AvailabilityReason.INVALID_SERVICE_TYPE_FOR_MODEL, + model=chat_request.model, + service_type=service_type.value, + purpose=purpose or "", + ) raise HTTPException( status_code=400, detail=f"Invalid service-type value for model {chat_request.model}. Should be one of {env.forced_model_service_type_pairs.get(chat_request.model)}", ) - return await _authorize_common_request( - request=request, - build_authorized_request=lambda user, purpose_value: AuthorizedChatRequest( - user=user, + try: + return await _authorize_common_request( + request=request, + build_authorized_request=lambda user, purpose_value: AuthorizedChatRequest( + user=user, + service_type=service_type.value, + purpose=purpose_value, + **chat_request.model_dump(exclude_unset=True), + ), + authorization=authorization, + service_type=service_type, + purpose=purpose, + x_dev_authorization=x_dev_authorization, + use_app_attest=use_app_attest, + use_qa_certificates=use_qa_certificates, + use_play_integrity=use_play_integrity, + ) + except HTTPException as exc: + # Only record terminal HTTP failures from the shared auth call: + # - 401: expected or normalized auth rejection (bad creds, expired token, etc.) + # - 400: client error to the auth layer (invalid purpose, or malformed App + # Attest base64 decoded in app_attest_auth before its try block) + # - anything else (e.g. App Attest's explicit 500): re-raised unrecorded; + # auth-system-failure capture is left to a follow-on auth backend change + # Non-HTTPException errors are not caught here and propagate unrecorded. + if exc.status_code == 401: + reason = AvailabilityReason.AUTH_REJECTED + elif exc.status_code == 400: + reason = AvailabilityReason.INVALID_AUTH_REQUEST + else: + raise + record_chat_availability_for( + reason, + model=chat_request.model, service_type=service_type.value, - purpose=purpose_value, - **chat_request.model_dump(exclude_unset=True), - ), - authorization=authorization, - service_type=service_type, - purpose=purpose, - x_dev_authorization=x_dev_authorization, - use_app_attest=use_app_attest, - use_qa_certificates=use_qa_certificates, - use_play_integrity=use_play_integrity, - ) + purpose=purpose or "", + ) + raise async def authorize_search_request( diff --git a/src/mlpa/core/completions.py b/src/mlpa/core/completions.py index 7a5a8ec..971a586 100644 --- a/src/mlpa/core/completions.py +++ b/src/mlpa/core/completions.py @@ -20,6 +20,7 @@ from mlpa.core.logger import logger from mlpa.core.metrics import ( extract_tool_names, + record_chat_availability, record_chat_request_rejection, record_completion_latency, record_completion_success, @@ -27,6 +28,7 @@ record_ttft, ) from mlpa.core.prometheus_metrics import ( + AvailabilityReason, PrometheusRejectionReason, PrometheusResult, ) @@ -52,20 +54,31 @@ def _build_litellm_body(req: AuthorizedChatRequest, *, stream: bool) -> dict: async def get_or_create_user_for_completion( user_id: str, req: AuthorizedChatRequest | AuthorizedSearchRequest ): - """Wraps get_or_create_user and records a signup-cap rejection metric if applicable.""" + """ + Wraps get_or_create_user and records availability for chat requests: + - signup cap (403 + MAX_USERS_REACHED): excluded, alongside the existing rejection metric + - user-resolution server or system failure (status >= 500): failure + - search requests and non-signup-cap, non-5xx failures: not recorded + """ try: return await get_or_create_user(user_id) except HTTPException as exc: - if ( - exc.status_code == 403 - and isinstance(exc.detail, dict) - and exc.detail.get("error") == ERROR_CODE_MAX_USERS_REACHED - and isinstance(req, AuthorizedChatRequest) - ): - record_chat_request_rejection( - req, - PrometheusRejectionReason.SIGNUP_CAP_EXCEEDED, - ) + if isinstance(req, AuthorizedChatRequest): + if ( + exc.status_code == 403 + and isinstance(exc.detail, dict) + and exc.detail.get("error") == ERROR_CODE_MAX_USERS_REACHED + ): + record_chat_request_rejection( + req, + PrometheusRejectionReason.SIGNUP_CAP_EXCEEDED, + ) + record_chat_availability(req, AvailabilityReason.SIGNUP_CAP_EXCEEDED) + elif exc.status_code >= 500: + # User-resolution server or system failure. Non-signup-cap 4xx errors + # are not recorded; a client-side 4xx should get its own classification + # rather than counting as an availability failure. + record_chat_availability(req, AvailabilityReason.PROVISIONING_FAILURE) raise @@ -80,6 +93,7 @@ async def stream_completion( record_request_with_tools(authorized_chat_request) body = _build_litellm_body(authorized_chat_request, stream=True) result = PrometheusResult.ERROR + availability_reason = AvailabilityReason.UPSTREAM_ERROR is_first_token = True prompt_tokens = 0 completion_tokens = 0 @@ -139,6 +153,7 @@ async def _read_next_chunk( if match.log_message: logger.warning(match.log_message) record_chat_request_rejection(authorized_chat_request, match.reason) + availability_reason = match.availability_reason() yield f'data: {{"error": {match.error_code}}}\n\n'.encode() return @@ -237,6 +252,7 @@ async def _read_next_chunk( return if not streaming_started: + availability_reason = AvailabilityReason.EMPTY_RESPONSE yield raise_and_log( RuntimeError("LiteLLM returned an empty response"), True, @@ -256,6 +272,7 @@ async def _read_next_chunk( snapshot=litellm_routing_snapshot, ) result = PrometheusResult.SUCCESS + availability_reason = AvailabilityReason.VALID_RESPONSE except (GeneratorExit, asyncio.CancelledError): # Client went away mid-stream: Starlette tears the generator down by # throwing GeneratorExit (or cancelling the task) at the paused @@ -291,9 +308,12 @@ async def _read_next_chunk( if result == PrometheusResult.ERROR and disconnect_event.is_set(): result = PrometheusResult.ABORT logger.info(_client_disconnected_msg) + if result == PrometheusResult.ABORT: + availability_reason = AvailabilityReason.CLIENT_DISCONNECT record_completion_latency( authorized_chat_request, result, time.perf_counter() - start_time ) + record_chat_availability(authorized_chat_request, availability_reason) async def get_completion(authorized_chat_request: AuthorizedChatRequest): @@ -304,6 +324,7 @@ async def get_completion(authorized_chat_request: AuthorizedChatRequest): record_request_with_tools(authorized_chat_request) body = _build_litellm_body(authorized_chat_request, stream=False) result = PrometheusResult.ERROR + availability_reason = AvailabilityReason.UPSTREAM_ERROR logger.debug( f"Starting a non-stream completion using {authorized_chat_request.model}, for user {authorized_chat_request.user}", ) @@ -326,6 +347,7 @@ async def get_completion(authorized_chat_request: AuthorizedChatRequest): if match.log_message: logger.warning(match.log_message) record_chat_request_rejection(authorized_chat_request, match.reason) + availability_reason = match.availability_reason() headers = ( {"Retry-After": match.retry_after} if match.retry_after else None ) @@ -362,6 +384,7 @@ async def get_completion(authorized_chat_request: AuthorizedChatRequest): snapshot=litellm_routing_snapshot, ) result = PrometheusResult.SUCCESS + availability_reason = AvailabilityReason.VALID_RESPONSE return data except HTTPException: raise @@ -371,3 +394,4 @@ async def get_completion(authorized_chat_request: AuthorizedChatRequest): record_completion_latency( authorized_chat_request, result, time.perf_counter() - start_time ) + record_chat_availability(authorized_chat_request, availability_reason) diff --git a/src/mlpa/core/errors.py b/src/mlpa/core/errors.py index 50b8a68..6ee521d 100644 --- a/src/mlpa/core/errors.py +++ b/src/mlpa/core/errors.py @@ -9,7 +9,7 @@ ERROR_CODE_REQUEST_TOO_LARGE, ERROR_CODE_UPSTREAM_RATE_LIMIT_EXCEEDED, ) -from mlpa.core.prometheus_metrics import PrometheusRejectionReason +from mlpa.core.prometheus_metrics import AvailabilityReason, PrometheusRejectionReason from mlpa.core.utils import ( is_context_window_error, is_invalid_model_name_error, @@ -18,6 +18,15 @@ is_rate_limit_error, ) +_REJECTION_TO_AVAILABILITY_REASON: dict[ + PrometheusRejectionReason, AvailabilityReason +] = { + PrometheusRejectionReason.BUDGET_EXCEEDED: AvailabilityReason.BUDGET_EXCEEDED, + PrometheusRejectionReason.PAYLOAD_TOO_LARGE: AvailabilityReason.PAYLOAD_TOO_LARGE, + PrometheusRejectionReason.INVALID_MODEL_NAME: AvailabilityReason.INVALID_MODEL_NAME, + PrometheusRejectionReason.INVALID_REQUEST: AvailabilityReason.INVALID_REQUEST, +} + @dataclass(frozen=True) class RejectionMatch: @@ -27,6 +36,15 @@ class RejectionMatch: retry_after: str | None = None log_message: str = "" + def availability_reason(self) -> AvailabilityReason: + # SIGNUP_CAP_EXCEEDED is recorded pre-completion, not via classify_upstream_error, + # so it is not in the mapping below. + if self.reason == PrometheusRejectionReason.RATE_LIMITED: + if self.error_code == ERROR_CODE_UPSTREAM_RATE_LIMIT_EXCEEDED: + return AvailabilityReason.RATE_LIMITED_UPSTREAM + return AvailabilityReason.RATE_LIMITED_PLATFORM + return _REJECTION_TO_AVAILABILITY_REASON[self.reason] + _RATE_LIMIT_REJECTION: dict[int, tuple[PrometheusRejectionReason, str, str]] = { ERROR_CODE_BUDGET_LIMIT_EXCEEDED: ( diff --git a/src/mlpa/core/metrics.py b/src/mlpa/core/metrics.py index 5c54846..f81783a 100644 --- a/src/mlpa/core/metrics.py +++ b/src/mlpa/core/metrics.py @@ -6,9 +6,11 @@ LitellmRoutingSnapshot, ) from mlpa.core.prometheus_metrics import ( + AvailabilityReason, PrometheusRejectionReason, PrometheusResult, TokenType, + availability_outcome_for, metrics, ) @@ -41,6 +43,33 @@ def record_search_request_rejection( metrics.search_request_rejections.labels(reason=reason, **_search_labels(req)).inc() +def record_chat_availability_for( + reason: AvailabilityReason, + *, + model: str, + service_type: str, + purpose: str, +) -> None: + metrics.chat_availability.labels( + outcome=availability_outcome_for(reason), + reason=reason, + model=model, + service_type=service_type, + purpose=purpose, + ).inc() + + +def record_chat_availability( + req: AuthorizedChatRequest, reason: AvailabilityReason +) -> None: + record_chat_availability_for( + reason, + model=req.model, + service_type=req.service_type, + purpose=req.purpose, + ) + + def record_completion_latency( req: AuthorizedChatRequest, result: PrometheusResult, diff --git a/src/mlpa/core/middleware/request_size.py b/src/mlpa/core/middleware/request_size.py index b9603e3..8a3d771 100644 --- a/src/mlpa/core/middleware/request_size.py +++ b/src/mlpa/core/middleware/request_size.py @@ -3,6 +3,8 @@ from mlpa.core.config import ERROR_CODE_REQUEST_TOO_LARGE, env from mlpa.core.logger import logger +from mlpa.core.metrics import record_chat_availability_for +from mlpa.core.prometheus_metrics import AvailabilityReason async def check_request_size_middleware(request: Request, call_next): @@ -19,6 +21,14 @@ async def check_request_size_middleware(request: Request, call_next): logger.warning( f"Request size {size} bytes exceeds maximum {env.MAX_REQUEST_SIZE_BYTES} bytes" ) + # `model` is in the request body, which we don't read here. + # We reject on the Content-Length header without parsing it. + record_chat_availability_for( + AvailabilityReason.PAYLOAD_TOO_LARGE, + model="", + service_type=request.headers.get("service-type") or "", + purpose=request.headers.get("purpose") or "", + ) return JSONResponse( status_code=413, content={"error": ERROR_CODE_REQUEST_TOO_LARGE}, diff --git a/src/mlpa/core/prometheus_metrics.py b/src/mlpa/core/prometheus_metrics.py index 745503d..757634d 100644 --- a/src/mlpa/core/prometheus_metrics.py +++ b/src/mlpa/core/prometheus_metrics.py @@ -19,6 +19,69 @@ class PrometheusRejectionReason(StrEnum): INVALID_REQUEST = "invalid_request" +class AvailabilityOutcome(StrEnum): + SUCCESS = "success" + FAILURE = "failure" + EXCLUDED = "excluded" + ABORT = "abort" + + +class AvailabilityReason(StrEnum): + # Strings shared with PrometheusRejectionReason are kept identical so the + # two counters reconcile. Keep them in sync when a rejection reason is added. + + # --- pre-completion reasons (recorded in the auth dependency and route body) --- + AUTH_REJECTED = "auth_rejected" # excluded + INVALID_AUTH_REQUEST = "invalid_auth_request" # excluded + INVALID_SERVICE_TYPE_FOR_MODEL = "invalid_service_type_for_model" # excluded + SIGNUP_CAP_EXCEEDED = "signup_cap_exceeded" # excluded + BLOCKED = "blocked" # excluded + PROVISIONING_FAILURE = "provisioning_failure" # failure + + # Defined but not yet emitted: auth backends normalize system failures to 401, + # making them indistinguishable from expected rejections. Capturing this + # properly requires a follow-on change to the auth backends themselves. + AUTH_SYSTEM_FAILURE = "auth_system_failure" # failure + + # --- completion-stage reasons (recorded inside stream_completion / get_completion) --- + VALID_RESPONSE = "valid_response" # success + UPSTREAM_ERROR = "upstream_error" # failure + EMPTY_RESPONSE = "empty_response" # failure + BUDGET_EXCEEDED = "budget_exceeded" # excluded + RATE_LIMITED_PLATFORM = "rate_limited_platform" # excluded + RATE_LIMITED_UPSTREAM = "rate_limited_upstream" # excluded + PAYLOAD_TOO_LARGE = "payload_too_large" # excluded + INVALID_MODEL_NAME = "invalid_model_name" # excluded + INVALID_REQUEST = "invalid_request" # excluded + CLIENT_DISCONNECT = "client_disconnect" # abort + + +_AVAILABILITY_OUTCOME_BY_REASON: dict[AvailabilityReason, AvailabilityOutcome] = { + AvailabilityReason.AUTH_REJECTED: AvailabilityOutcome.EXCLUDED, + AvailabilityReason.INVALID_AUTH_REQUEST: AvailabilityOutcome.EXCLUDED, + AvailabilityReason.INVALID_SERVICE_TYPE_FOR_MODEL: AvailabilityOutcome.EXCLUDED, + AvailabilityReason.SIGNUP_CAP_EXCEEDED: AvailabilityOutcome.EXCLUDED, + AvailabilityReason.BLOCKED: AvailabilityOutcome.EXCLUDED, + AvailabilityReason.PROVISIONING_FAILURE: AvailabilityOutcome.FAILURE, + AvailabilityReason.AUTH_SYSTEM_FAILURE: AvailabilityOutcome.FAILURE, + AvailabilityReason.VALID_RESPONSE: AvailabilityOutcome.SUCCESS, + AvailabilityReason.UPSTREAM_ERROR: AvailabilityOutcome.FAILURE, + AvailabilityReason.EMPTY_RESPONSE: AvailabilityOutcome.FAILURE, + AvailabilityReason.BUDGET_EXCEEDED: AvailabilityOutcome.EXCLUDED, + AvailabilityReason.RATE_LIMITED_PLATFORM: AvailabilityOutcome.EXCLUDED, + AvailabilityReason.RATE_LIMITED_UPSTREAM: AvailabilityOutcome.EXCLUDED, + AvailabilityReason.PAYLOAD_TOO_LARGE: AvailabilityOutcome.EXCLUDED, + AvailabilityReason.INVALID_MODEL_NAME: AvailabilityOutcome.EXCLUDED, + AvailabilityReason.INVALID_REQUEST: AvailabilityOutcome.EXCLUDED, + AvailabilityReason.CLIENT_DISCONNECT: AvailabilityOutcome.ABORT, +} + + +def availability_outcome_for(reason: AvailabilityReason) -> AvailabilityOutcome: + """Pure classifier: the availability outcome is fully determined by the reason.""" + return _AVAILABILITY_OUTCOME_BY_REASON[reason] + + class TokenType(StrEnum): PROMPT = "prompt" COMPLETION = "completion" @@ -113,6 +176,7 @@ class PrometheusMetrics: chat_tool_calls_per_completion: Histogram chat_requests_with_tools: Counter chat_request_rejections: Counter + chat_availability: Counter # search search_latency: Histogram @@ -274,6 +338,12 @@ def build_metrics(registry: CollectorRegistry = REGISTRY) -> PrometheusMetrics: ["reason", "model", "service_type", "purpose"], registry=registry, ), + chat_availability=Counter( + "mlpa_chat_availability_total", + "Interim availability outcomes for chat completions. outcome is success/failure/excluded/abort; reason is the bounded cause. Availability = success / (success + failure).", + ["outcome", "reason", "model", "service_type", "purpose"], + registry=registry, + ), search_latency=Histogram( "mlpa_search_latency_seconds", "Search latency in seconds.", diff --git a/src/mlpa/run.py b/src/mlpa/run.py index 6a98a06..b7baebf 100644 --- a/src/mlpa/run.py +++ b/src/mlpa/run.py @@ -26,9 +26,11 @@ ) from mlpa.core.http_client import close_http_client, get_http_client from mlpa.core.logger import logger, setup_logger +from mlpa.core.metrics import record_chat_availability from mlpa.core.middleware import register_middleware from mlpa.core.openapi import customize_openapi from mlpa.core.pg_services.services import app_attest_pg, litellm_pg +from mlpa.core.prometheus_metrics import AvailabilityReason from mlpa.core.routers.appattest import appattest_router from mlpa.core.routers.health import health_router from mlpa.core.routers.mock import mock_router @@ -207,6 +209,7 @@ async def chat_completion( ) user, _ = await get_or_create_user_for_completion(user_id, authorized_chat_request) if user.get("blocked"): + record_chat_availability(authorized_chat_request, AvailabilityReason.BLOCKED) raise HTTPException(status_code=403, detail={"error": "User is blocked."}) if authorized_chat_request.stream: diff --git a/src/tests/integration/test_request_size_limit.py b/src/tests/integration/test_request_size_limit.py index 63e5c59..ad4bb25 100644 --- a/src/tests/integration/test_request_size_limit.py +++ b/src/tests/integration/test_request_size_limit.py @@ -23,7 +23,7 @@ def test_request_size_under_limit(mocked_client_integration): assert response.status_code != 413 -def test_request_size_over_limit(mocked_client_integration): +def test_request_size_over_limit(mocked_client_integration, metrics_spy): """Test that requests over the size limit return 413.""" max_size = env.MAX_REQUEST_SIZE_BYTES oversized_size = max_size + 1 @@ -61,6 +61,17 @@ def test_request_size_over_limit(mocked_client_integration): assert response.status_code == 413 assert response.json() == {"error": 3} + assert ( + metrics_spy.value( + "chat_availability", + outcome="excluded", + reason="payload_too_large", + model="", + service_type="ai", + purpose="chat", + ) + == 1 + ) def test_request_size_exactly_at_limit(mocked_client_integration): diff --git a/src/tests/unit/test_availability.py b/src/tests/unit/test_availability.py new file mode 100644 index 0000000..ef794ab --- /dev/null +++ b/src/tests/unit/test_availability.py @@ -0,0 +1,96 @@ +import pytest + +from mlpa.core.config import ( + ERROR_CODE_BUDGET_LIMIT_EXCEEDED, + ERROR_CODE_INVALID_MODEL_NAME, + ERROR_CODE_INVALID_REQUEST, + ERROR_CODE_RATE_LIMIT_EXCEEDED, + ERROR_CODE_REQUEST_TOO_LARGE, + ERROR_CODE_UPSTREAM_RATE_LIMIT_EXCEEDED, +) +from mlpa.core.errors import RejectionMatch +from mlpa.core.prometheus_metrics import ( + AvailabilityOutcome, + AvailabilityReason, + PrometheusRejectionReason, + availability_outcome_for, +) + + +def test_every_availability_reason_maps_to_an_outcome(): + """Guard: a new AvailabilityReason cannot ship without an outcome mapping. + + Future pre-completion reasons added to AvailabilityReason must extend the + map too; this fails loudly (KeyError) if the map is not updated alongside the enum. + """ + for reason in AvailabilityReason: + assert isinstance(availability_outcome_for(reason), AvailabilityOutcome) + + +# SIGNUP_CAP_EXCEEDED is recorded pre-completion, not by classify_upstream_error, +# so it is intentionally outside the completion-stage availability mapping. +_PRE_COMPLETION_REJECTION_REASONS = {PrometheusRejectionReason.SIGNUP_CAP_EXCEEDED} + + +def test_every_completion_stage_rejection_reason_maps_to_excluded(): + """Guard: every rejection reason classify_upstream_error can produce must + resolve through availability_reason() to an excluded outcome. + + Iterating the enum (minus the pre-completion reasons) means a newly added + completion-stage rejection reason fails loudly here until it is mapped, or is + explicitly classified as pre-completion. + """ + for reason in PrometheusRejectionReason: + if reason in _PRE_COMPLETION_REJECTION_REASONS: + continue + match = RejectionMatch(reason=reason, error_code=0, http_status=400) + assert ( + availability_outcome_for(match.availability_reason()) + == AvailabilityOutcome.EXCLUDED + ) + + +# Pins the expected availability reason for each completion-stage rejection, +# including the own-vs-upstream rate-limit split keyed on error_code. This fixes +# the exact mappings; the completeness test above guards that the map covers +# every completion-stage rejection reason. +@pytest.mark.parametrize( + ("reason", "error_code", "expected"), + [ + ( + PrometheusRejectionReason.BUDGET_EXCEEDED, + ERROR_CODE_BUDGET_LIMIT_EXCEEDED, + AvailabilityReason.BUDGET_EXCEEDED, + ), + ( + PrometheusRejectionReason.RATE_LIMITED, + ERROR_CODE_RATE_LIMIT_EXCEEDED, + AvailabilityReason.RATE_LIMITED_PLATFORM, + ), + ( + PrometheusRejectionReason.RATE_LIMITED, + ERROR_CODE_UPSTREAM_RATE_LIMIT_EXCEEDED, + AvailabilityReason.RATE_LIMITED_UPSTREAM, + ), + ( + PrometheusRejectionReason.PAYLOAD_TOO_LARGE, + ERROR_CODE_REQUEST_TOO_LARGE, + AvailabilityReason.PAYLOAD_TOO_LARGE, + ), + ( + PrometheusRejectionReason.INVALID_MODEL_NAME, + ERROR_CODE_INVALID_MODEL_NAME, + AvailabilityReason.INVALID_MODEL_NAME, + ), + ( + PrometheusRejectionReason.INVALID_REQUEST, + ERROR_CODE_INVALID_REQUEST, + AvailabilityReason.INVALID_REQUEST, + ), + ], +) +def test_rejection_match_availability_reason(reason, error_code, expected): + match = RejectionMatch(reason=reason, error_code=error_code, http_status=400) + assert match.availability_reason() == expected + # All policy rejections are excluded from the availability ratio. + assert availability_outcome_for(expected) == AvailabilityOutcome.EXCLUDED diff --git a/src/tests/unit/test_completions.py b/src/tests/unit/test_completions.py index fb87613..c4147fe 100644 --- a/src/tests/unit/test_completions.py +++ b/src/tests/unit/test_completions.py @@ -28,7 +28,12 @@ LITELLM_HEADER_RESPONSE_DURATION_MS, env, ) -from mlpa.core.prometheus_metrics import PrometheusRejectionReason, PrometheusResult +from mlpa.core.prometheus_metrics import ( + AvailabilityOutcome, + AvailabilityReason, + PrometheusRejectionReason, + PrometheusResult, +) from tests.consts import SAMPLE_REQUEST, SUCCESSFUL_CHAT_RESPONSE @@ -54,6 +59,39 @@ def _rejection_count( ) +def _availability_count( + spy, + outcome: AvailabilityOutcome, + reason: AvailabilityReason, + req=SAMPLE_REQUEST, +) -> float: + return spy.value( + "chat_availability", + outcome=outcome, + reason=reason, + model=req.model, + service_type=req.service_type, + purpose=req.purpose, + ) + + +def _availability_total(spy, req=SAMPLE_REQUEST) -> float: + """Sum of all chat_availability samples for the request labels. + + Proves exactly one availability disposition was recorded, regardless of + which (outcome, reason) pair it landed on. Guards the policy-rejection path + against re-introducing a second emission alongside the correct one. + """ + return sum( + s.value + for s in spy.samples("chat_availability") + if s.name.endswith("_total") + and s.labels.get("model") == req.model + and s.labels.get("service_type") == req.service_type + and s.labels.get("purpose") == req.purpose + ) + + def _sample_litellm_response_headers(**overrides: str) -> httpx.Headers: base = { LITELLM_HEADER_MODEL_API_BASE: "https://api.together.xyz/v1", @@ -105,6 +143,7 @@ async def test_get_completion_success(mocker, metrics_spy): metrics_spy.assert_only( { + "chat_availability", "chat_tokens", "chat_tokens_per_request", "chat_completion_latency", @@ -131,6 +170,14 @@ async def test_get_completion_success(mocker, metrics_spy): == SUCCESSFUL_CHAT_RESPONSE["usage"]["completion_tokens"] ) assert _latency_count(metrics_spy, PrometheusResult.SUCCESS) == 1 + assert ( + _availability_count( + metrics_spy, + AvailabilityOutcome.SUCCESS, + AvailabilityReason.VALID_RESPONSE, + ) + == 1 + ) routing = _litellm_routing_label_base() assert ( @@ -256,8 +303,17 @@ async def test_get_completion_http_error(mocker, metrics_spy): assert exc_info.value.status_code == 500 assert exc_info.value.detail["error"] == "Upstream service returned an error" + assert ( + _availability_count( + metrics_spy, + AvailabilityOutcome.FAILURE, + AvailabilityReason.UPSTREAM_ERROR, + ) + == 1 + ) + assert _availability_total(metrics_spy) == 1 - metrics_spy.assert_only({"chat_completion_latency"}) + metrics_spy.assert_only({"chat_completion_latency", "chat_availability"}) assert _latency_count(metrics_spy, PrometheusResult.ERROR) == 1 @@ -273,7 +329,7 @@ async def test_get_completion_network_error(mocker, metrics_spy): assert exc_info.value.status_code == 502 assert exc_info.value.detail["error"] == "Connection timed out" - metrics_spy.assert_only({"chat_completion_latency"}) + metrics_spy.assert_only({"chat_completion_latency", "chat_availability"}) assert _latency_count(metrics_spy, PrometheusResult.ERROR) == 1 @@ -308,6 +364,7 @@ async def test_stream_completion_success( metrics_spy.assert_only( { + "chat_availability", "chat_completion_ttft", "chat_tokens", "chat_tokens_per_request", @@ -329,6 +386,14 @@ async def test_stream_completion_success( assert metrics_spy.value("chat_tokens", type="prompt", **chat_label_base) == 10 assert metrics_spy.value("chat_tokens", type="completion", **chat_label_base) == 25 assert _latency_count(metrics_spy, PrometheusResult.SUCCESS) == 1 + assert ( + _availability_count( + metrics_spy, + AvailabilityOutcome.SUCCESS, + AvailabilityReason.VALID_RESPONSE, + ) + == 1 + ) assert ( metrics_spy.histogram_count("chat_completion_ttft", model=SAMPLE_REQUEST.model) == 1 @@ -419,9 +484,22 @@ async def test_get_completion_budget_limit_exceeded_429(mocker, metrics_spy): assert exc_info.value.detail == {"error": 1} assert exc_info.value.headers == {"Retry-After": "86400"} - metrics_spy.assert_only({"chat_request_rejections", "chat_completion_latency"}) + metrics_spy.assert_only( + {"chat_request_rejections", "chat_completion_latency", "chat_availability"} + ) assert _rejection_count(metrics_spy, PrometheusRejectionReason.BUDGET_EXCEEDED) == 1 assert _latency_count(metrics_spy, PrometheusResult.ERROR) == 1 + # gap 1: same request, recorded as excluded (not failure) on availability + # even though the latency histogram above still counts it as result=error. + assert ( + _availability_count( + metrics_spy, + AvailabilityOutcome.EXCLUDED, + AvailabilityReason.BUDGET_EXCEEDED, + ) + == 1 + ) + assert _availability_total(metrics_spy) == 1 async def test_get_completion_budget_limit_exceeded_400(mocker, metrics_spy): @@ -453,7 +531,9 @@ async def test_get_completion_budget_limit_exceeded_400(mocker, metrics_spy): assert exc_info.value.detail == {"error": 1} assert exc_info.value.headers == {"Retry-After": "86400"} - metrics_spy.assert_only({"chat_request_rejections", "chat_completion_latency"}) + metrics_spy.assert_only( + {"chat_request_rejections", "chat_completion_latency", "chat_availability"} + ) assert _rejection_count(metrics_spy, PrometheusRejectionReason.BUDGET_EXCEEDED) == 1 @@ -485,8 +565,18 @@ async def test_get_completion_rate_limit_exceeded(mocker, metrics_spy): assert exc_info.value.status_code == 429 assert exc_info.value.detail == {"error": 2} assert exc_info.value.headers == {"Retry-After": "60"} + assert ( + _availability_count( + metrics_spy, + AvailabilityOutcome.EXCLUDED, + AvailabilityReason.RATE_LIMITED_PLATFORM, + ) + == 1 + ) - metrics_spy.assert_only({"chat_request_rejections", "chat_completion_latency"}) + metrics_spy.assert_only( + {"chat_request_rejections", "chat_completion_latency", "chat_availability"} + ) assert _rejection_count(metrics_spy, PrometheusRejectionReason.RATE_LIMITED) == 1 @@ -518,7 +608,7 @@ async def test_get_completion_400_non_rate_limit_error(mocker, metrics_spy): assert exc_info.value.status_code == 400 assert exc_info.value.detail == {"error": "Upstream service returned an error"} - metrics_spy.assert_only({"chat_completion_latency"}) + metrics_spy.assert_only({"chat_completion_latency", "chat_availability"}) async def test_get_completion_429_non_rate_limit_error(mocker, metrics_spy): @@ -543,7 +633,7 @@ async def test_get_completion_429_non_rate_limit_error(mocker, metrics_spy): assert exc_info.value.status_code == 429 assert exc_info.value.detail == {"error": "Upstream service returned an error"} - metrics_spy.assert_only({"chat_completion_latency"}) + metrics_spy.assert_only({"chat_completion_latency", "chat_availability"}) async def test_get_completion_upstream_rate_limit_error(mocker, metrics_spy): @@ -568,7 +658,17 @@ async def test_get_completion_upstream_rate_limit_error(mocker, metrics_spy): assert exc_info.value.status_code == 429 assert exc_info.value.detail == {"error": ERROR_CODE_UPSTREAM_RATE_LIMIT_EXCEEDED} - metrics_spy.assert_only({"chat_request_rejections", "chat_completion_latency"}) + assert ( + _availability_count( + metrics_spy, + AvailabilityOutcome.EXCLUDED, + AvailabilityReason.RATE_LIMITED_UPSTREAM, + ) + == 1 + ) + metrics_spy.assert_only( + {"chat_request_rejections", "chat_completion_latency", "chat_availability"} + ) assert _rejection_count(metrics_spy, PrometheusRejectionReason.RATE_LIMITED) == 1 assert _latency_count(metrics_spy, PrometheusResult.ERROR) == 1 @@ -599,7 +699,9 @@ async def test_get_completion_context_window_exceeded(mocker, metrics_spy): assert exc_info.value.detail == {"error": ERROR_CODE_REQUEST_TOO_LARGE} mock_logger.warning.assert_called_once() assert "Context window exceeded" in str(mock_logger.warning.call_args) - metrics_spy.assert_only({"chat_request_rejections", "chat_completion_latency"}) + metrics_spy.assert_only( + {"chat_request_rejections", "chat_completion_latency", "chat_availability"} + ) assert ( _rejection_count(metrics_spy, PrometheusRejectionReason.PAYLOAD_TOO_LARGE) == 1 ) @@ -634,7 +736,9 @@ async def test_get_completion_invalid_model_name(mocker, metrics_spy): assert exc_info.value.detail == {"error": ERROR_CODE_INVALID_MODEL_NAME} mock_logger.warning.assert_called_once() assert "Invalid model name" in str(mock_logger.warning.call_args) - metrics_spy.assert_only({"chat_request_rejections", "chat_completion_latency"}) + metrics_spy.assert_only( + {"chat_request_rejections", "chat_completion_latency", "chat_availability"} + ) assert ( _rejection_count(metrics_spy, PrometheusRejectionReason.INVALID_MODEL_NAME) == 1 ) @@ -668,7 +772,9 @@ async def test_get_completion_invalid_request_vertex(mocker, metrics_spy): assert exc_info.value.detail == {"error": ERROR_CODE_INVALID_REQUEST} mock_logger.warning.assert_called_once() assert "Invalid request" in str(mock_logger.warning.call_args) - metrics_spy.assert_only({"chat_request_rejections", "chat_completion_latency"}) + metrics_spy.assert_only( + {"chat_request_rejections", "chat_completion_latency", "chat_availability"} + ) assert _rejection_count(metrics_spy, PrometheusRejectionReason.INVALID_REQUEST) == 1 assert _latency_count(metrics_spy, PrometheusResult.ERROR) == 1 @@ -693,7 +799,7 @@ async def test_get_completion_429_invalid_json(mocker, metrics_spy): assert exc_info.value.status_code == 429 assert exc_info.value.detail == {"error": "Upstream service returned an error"} - metrics_spy.assert_only({"chat_completion_latency"}) + metrics_spy.assert_only({"chat_completion_latency", "chat_availability"}) async def test_stream_completion_budget_limit_exceeded_429( @@ -727,9 +833,20 @@ async def test_stream_completion_budget_limit_exceeded_429( ) mock_logger.warning.assert_called_once() assert "Budget limit exceeded" in str(mock_logger.warning.call_args) - metrics_spy.assert_only({"chat_request_rejections", "chat_completion_latency"}) + metrics_spy.assert_only( + {"chat_request_rejections", "chat_completion_latency", "chat_availability"} + ) assert _rejection_count(metrics_spy, PrometheusRejectionReason.BUDGET_EXCEEDED) == 1 assert _latency_count(metrics_spy, PrometheusResult.ERROR) == 1 + assert ( + _availability_count( + metrics_spy, + AvailabilityOutcome.EXCLUDED, + AvailabilityReason.BUDGET_EXCEEDED, + ) + == 1 + ) + assert _availability_total(metrics_spy) == 1 async def test_stream_completion_budget_limit_exceeded_400( @@ -764,7 +881,9 @@ async def test_stream_completion_budget_limit_exceeded_400( ) mock_logger.warning.assert_called_once() assert "Budget limit exceeded" in str(mock_logger.warning.call_args) - metrics_spy.assert_only({"chat_request_rejections", "chat_completion_latency"}) + metrics_spy.assert_only( + {"chat_request_rejections", "chat_completion_latency", "chat_availability"} + ) assert _rejection_count(metrics_spy, PrometheusRejectionReason.BUDGET_EXCEEDED) == 1 assert _latency_count(metrics_spy, PrometheusResult.ERROR) == 1 @@ -801,7 +920,9 @@ async def test_stream_completion_rate_limit_exceeded( ) mock_logger.warning.assert_called_once() assert "Rate limit exceeded" in str(mock_logger.warning.call_args) - metrics_spy.assert_only({"chat_request_rejections", "chat_completion_latency"}) + metrics_spy.assert_only( + {"chat_request_rejections", "chat_completion_latency", "chat_availability"} + ) assert _rejection_count(metrics_spy, PrometheusRejectionReason.RATE_LIMITED) == 1 assert _latency_count(metrics_spy, PrometheusResult.ERROR) == 1 @@ -833,7 +954,9 @@ async def test_stream_completion_context_window_exceeded( ) mock_logger.warning.assert_called_once() assert "Context window exceeded" in str(mock_logger.warning.call_args) - metrics_spy.assert_only({"chat_request_rejections", "chat_completion_latency"}) + metrics_spy.assert_only( + {"chat_request_rejections", "chat_completion_latency", "chat_availability"} + ) assert ( _rejection_count(metrics_spy, PrometheusRejectionReason.PAYLOAD_TOO_LARGE) == 1 ) @@ -869,7 +992,9 @@ async def test_stream_completion_invalid_model_name( ) mock_logger.warning.assert_called_once() assert "Invalid model name" in str(mock_logger.warning.call_args) - metrics_spy.assert_only({"chat_request_rejections", "chat_completion_latency"}) + metrics_spy.assert_only( + {"chat_request_rejections", "chat_completion_latency", "chat_availability"} + ) assert ( _rejection_count(metrics_spy, PrometheusRejectionReason.INVALID_MODEL_NAME) == 1 ) @@ -904,7 +1029,9 @@ async def test_stream_completion_invalid_request_vertex( ) mock_logger.warning.assert_called_once() assert "Invalid request" in str(mock_logger.warning.call_args) - metrics_spy.assert_only({"chat_request_rejections", "chat_completion_latency"}) + metrics_spy.assert_only( + {"chat_request_rejections", "chat_completion_latency", "chat_availability"} + ) assert _rejection_count(metrics_spy, PrometheusRejectionReason.INVALID_REQUEST) == 1 assert _latency_count(metrics_spy, PrometheusResult.ERROR) == 1 @@ -941,7 +1068,7 @@ async def test_stream_completion_400_non_rate_limit_error( == b'data: {"code": 400, "error": "Upstream service returned an error"}\n\n' ) mock_logger.error.assert_called_once() - metrics_spy.assert_only({"chat_completion_latency"}) + metrics_spy.assert_only({"chat_completion_latency", "chat_availability"}) assert _latency_count(metrics_spy, PrometheusResult.ERROR) == 1 @@ -971,7 +1098,7 @@ async def test_stream_completion_429_non_rate_limit_error( == b'data: {"code": 429, "error": "Upstream service returned an error"}\n\n' ) mock_logger.error.assert_called_once() - metrics_spy.assert_only({"chat_completion_latency"}) + metrics_spy.assert_only({"chat_completion_latency", "chat_availability"}) assert _latency_count(metrics_spy, PrometheusResult.ERROR) == 1 @@ -995,7 +1122,9 @@ async def test_stream_completion_upstream_rate_limit_error( assert received_chunks == [ f'data: {{"error": {ERROR_CODE_UPSTREAM_RATE_LIMIT_EXCEEDED}}}\n\n'.encode() ] - metrics_spy.assert_only({"chat_request_rejections", "chat_completion_latency"}) + metrics_spy.assert_only( + {"chat_request_rejections", "chat_completion_latency", "chat_availability"} + ) assert _rejection_count(metrics_spy, PrometheusRejectionReason.RATE_LIMITED) == 1 assert _latency_count(metrics_spy, PrometheusResult.ERROR) == 1 @@ -1023,7 +1152,7 @@ async def test_stream_completion_429_invalid_json( == b'data: {"code": 429, "error": "Upstream service returned an error"}\n\n' ) mock_logger.error.assert_called_once() - metrics_spy.assert_only({"chat_completion_latency"}) + metrics_spy.assert_only({"chat_completion_latency", "chat_availability"}) assert _latency_count(metrics_spy, PrometheusResult.ERROR) == 1 @@ -1043,7 +1172,7 @@ async def test_stream_completion_exception_after_streaming_started( assert len(received_chunks) == 1 assert b"error" in received_chunks[0] - metrics_spy.assert_only({"chat_completion_latency"}) + metrics_spy.assert_only({"chat_completion_latency", "chat_availability"}) assert _latency_count(metrics_spy, PrometheusResult.ERROR) == 1 @@ -1248,6 +1377,15 @@ async def test_stream_sends_error_sse_on_empty_200_response( ) assert b'"error"' in received[0], "Chunk must be an error SSE frame" _assert_error_latency(metrics_spy) + assert ( + _availability_count( + metrics_spy, + AvailabilityOutcome.FAILURE, + AvailabilityReason.EMPTY_RESPONSE, + ) + == 1 + ) + assert _availability_total(metrics_spy) == 1 async def test_stream_completion_client_disconnect_records_abort( @@ -1280,6 +1418,15 @@ async def _aiter_bytes(): assert _latency_count(metrics_spy, PrometheusResult.ABORT) == 1 assert _latency_count(metrics_spy, PrometheusResult.ERROR) == 0 + assert ( + _availability_count( + metrics_spy, + AvailabilityOutcome.ABORT, + AvailabilityReason.CLIENT_DISCONNECT, + ) + == 1 + ) + assert _availability_total(metrics_spy) == 1 async def test_stream_uses_httpx_timeout_object_preserving_pool_timeout( diff --git a/src/tests/unit/test_precompletion_availability.py b/src/tests/unit/test_precompletion_availability.py new file mode 100644 index 0000000..454b08d --- /dev/null +++ b/src/tests/unit/test_precompletion_availability.py @@ -0,0 +1,400 @@ +"""Pre-completion availability instrumentation. + +Covers the dispositions the completion-stage counter cannot see: the chat auth +dependency (service-type / model, auth rejection, client-error auth request, and +the statuses that are intentionally not recorded) and the route-body sites +(signup cap, provisioning failure, blocked). The completion tests bypass the auth +dependency, so the wrapper behavior is only exercised here. +""" + +import pytest +from fastapi import HTTPException, Request + +from mlpa import run as run_module +from mlpa.core.auth import authorize as authorize_module +from mlpa.core.classes import ( + AuthorizedChatRequest, + AuthorizedSearchRequest, + ChatRequest, +) +from mlpa.core.completions import get_or_create_user_for_completion +from mlpa.core.config import ERROR_CODE_MAX_USERS_REACHED +from mlpa.core.prometheus_metrics import ( + AvailabilityOutcome, + AvailabilityReason, +) +from tests.consts import SAMPLE_REQUEST + +# A model/service-type pair that is valid together, so the wrapper passes its own +# check and reaches the shared auth call. +_VALID_MODEL = "openai/gpt-4o" +_AI = authorize_module.ServiceType.ai + + +def _make_request() -> Request: + async def receive() -> dict: + return {"type": "http.request", "body": b"", "more_body": False} + + return Request( + {"type": "http", "method": "POST", "path": "/", "headers": []}, receive + ) + + +def _chat_request(model: str = _VALID_MODEL) -> ChatRequest: + return ChatRequest(model=model, messages=[{"role": "user", "content": "hi"}]) + + +def _availability( + spy, + outcome: AvailabilityOutcome, + reason: AvailabilityReason, + *, + model: str, + service_type: str, + purpose: str = "", +) -> float: + return spy.value( + "chat_availability", + outcome=outcome, + reason=reason, + model=model, + service_type=service_type, + purpose=purpose, + ) + + +def _availability_total(spy) -> float: + """Sum of every chat_availability sample. Proves exactly one disposition.""" + return sum( + s.value for s in spy.samples("chat_availability") if s.name.endswith("_total") + ) + + +def _rejection_total(spy) -> float: + return sum( + s.value + for s in spy.samples("chat_request_rejections") + if s.name.endswith("_total") + ) + + +# --- auth dependency (authorize_chat_request) --------------------------------- + + +async def test_wrapper_success_records_no_auth_stage_availability(mocker, metrics_spy): + mocker.patch.object( + authorize_module, + "fxa_auth", + mocker.AsyncMock(return_value={"user": "user-123"}), + ) + + result = await authorize_module.authorize_chat_request( + request=_make_request(), + chat_request=_chat_request(), + authorization="Bearer token", + service_type=_AI, + purpose="chat", + ) + + assert isinstance(result, AuthorizedChatRequest) + # Auth success is finalized later at completion, never at the auth stage. + assert "chat_availability" not in metrics_spy.touched() + + +async def test_wrapper_invalid_service_type_records_excluded(metrics_spy): + with pytest.raises(HTTPException) as exc_info: + await authorize_module.authorize_chat_request( + request=_make_request(), + chat_request=ChatRequest(model="exa", messages=[]), + authorization="Bearer token", + service_type=_AI, + purpose="chat", + ) + + assert exc_info.value.status_code == 400 + assert ( + _availability( + metrics_spy, + AvailabilityOutcome.EXCLUDED, + AvailabilityReason.INVALID_SERVICE_TYPE_FOR_MODEL, + model="exa", + service_type="ai", + purpose="chat", + ) + == 1 + ) + assert _availability_total(metrics_spy) == 1 + + +async def test_wrapper_invalid_purpose_records_invalid_auth_request(metrics_spy): + # A real shared-call 400 from purpose validation maps to the coarse reason. + with pytest.raises(HTTPException) as exc_info: + await authorize_module.authorize_chat_request( + request=_make_request(), + chat_request=_chat_request(), + authorization="Bearer token", + service_type=_AI, + purpose="definitely-not-a-valid-purpose", + ) + + assert exc_info.value.status_code == 400 + assert ( + _availability( + metrics_spy, + AvailabilityOutcome.EXCLUDED, + AvailabilityReason.INVALID_AUTH_REQUEST, + model=_VALID_MODEL, + service_type="ai", + purpose="definitely-not-a-valid-purpose", + ) + == 1 + ) + assert _availability_total(metrics_spy) == 1 + + +async def test_wrapper_shared_call_401_records_auth_rejected(mocker, metrics_spy): + raised = HTTPException(status_code=401, detail="Invalid FxA auth") + mocker.patch.object( + authorize_module, + "_authorize_common_request", + mocker.AsyncMock(side_effect=raised), + ) + + with pytest.raises(HTTPException) as exc_info: + await authorize_module.authorize_chat_request( + request=_make_request(), + chat_request=_chat_request(), + authorization="Bearer token", + service_type=_AI, + purpose="chat", + ) + + assert exc_info.value is raised # re-raised unchanged + assert ( + _availability( + metrics_spy, + AvailabilityOutcome.EXCLUDED, + AvailabilityReason.AUTH_REJECTED, + model=_VALID_MODEL, + service_type="ai", + purpose="chat", + ) + == 1 + ) + assert _availability_total(metrics_spy) == 1 + + +async def test_wrapper_shared_call_400_records_invalid_auth_request( + mocker, metrics_spy +): + # Pins that the wrapper maps any shared-call 400 to invalid_auth_request, + # regardless of source. The non-purpose source (malformed App Attest base64 + # decoded in app_attest_auth, which raises before its try) was confirmed by + # code inspection; this test proves the mapping, not that path's reachability. + raised = HTTPException(status_code=400, detail={"challenge_b64": "Invalid Base64"}) + mocker.patch.object( + authorize_module, + "_authorize_common_request", + mocker.AsyncMock(side_effect=raised), + ) + + with pytest.raises(HTTPException) as exc_info: + await authorize_module.authorize_chat_request( + request=_make_request(), + chat_request=_chat_request(), + authorization="Bearer token", + service_type=_AI, + purpose="chat", + ) + + assert exc_info.value is raised # re-raised unchanged + assert ( + _availability( + metrics_spy, + AvailabilityOutcome.EXCLUDED, + AvailabilityReason.INVALID_AUTH_REQUEST, + model=_VALID_MODEL, + service_type="ai", + purpose="chat", + ) + == 1 + ) + assert _availability_total(metrics_spy) == 1 + + +async def test_wrapper_shared_call_500_records_nothing(mocker, metrics_spy): + # App Attest's explicit 500 is an HTTPException, so it is caught, but the + # wrapper records only 401/400 and re-raises everything else unrecorded. + raised = HTTPException( + status_code=500, detail="Server error during App Attest auth" + ) + mocker.patch.object( + authorize_module, + "_authorize_common_request", + mocker.AsyncMock(side_effect=raised), + ) + + with pytest.raises(HTTPException) as exc_info: + await authorize_module.authorize_chat_request( + request=_make_request(), + chat_request=_chat_request(), + authorization="Bearer token", + service_type=_AI, + purpose="chat", + ) + + assert exc_info.value is raised # re-raised unchanged + assert "chat_availability" not in metrics_spy.touched() + + +async def test_wrapper_shared_call_non_http_exception_records_nothing( + mocker, metrics_spy +): + # Non-HTTPException auth-path errors are not caught and propagate unrecorded. + raised = RuntimeError("bare auth-path failure") + mocker.patch.object( + authorize_module, + "_authorize_common_request", + mocker.AsyncMock(side_effect=raised), + ) + + with pytest.raises(RuntimeError): + await authorize_module.authorize_chat_request( + request=_make_request(), + chat_request=_chat_request(), + authorization="Bearer token", + service_type=_AI, + purpose="chat", + ) + + assert "chat_availability" not in metrics_spy.touched() + + +# --- route body: get_or_create_user_for_completion ---------------------------- + + +async def test_signup_cap_records_excluded_alongside_rejection(mocker, metrics_spy): + mocker.patch( + "mlpa.core.completions.get_or_create_user", + mocker.AsyncMock( + side_effect=HTTPException( + status_code=403, detail={"error": ERROR_CODE_MAX_USERS_REACHED} + ) + ), + ) + + with pytest.raises(HTTPException) as exc_info: + await get_or_create_user_for_completion(SAMPLE_REQUEST.user, SAMPLE_REQUEST) + + assert exc_info.value.status_code == 403 + assert ( + _availability( + metrics_spy, + AvailabilityOutcome.EXCLUDED, + AvailabilityReason.SIGNUP_CAP_EXCEEDED, + model=SAMPLE_REQUEST.model, + service_type=SAMPLE_REQUEST.service_type, + purpose=SAMPLE_REQUEST.purpose, + ) + == 1 + ) + assert _availability_total(metrics_spy) == 1 + # The existing rejection metric is still recorded. + assert _rejection_total(metrics_spy) == 1 + + +async def test_provisioning_failure_records_failure(mocker, metrics_spy): + mocker.patch( + "mlpa.core.completions.get_or_create_user", + mocker.AsyncMock( + side_effect=HTTPException( + status_code=500, detail={"error": "Error fetching user info"} + ) + ), + ) + + with pytest.raises(HTTPException) as exc_info: + await get_or_create_user_for_completion(SAMPLE_REQUEST.user, SAMPLE_REQUEST) + + assert exc_info.value.status_code == 500 + assert ( + _availability( + metrics_spy, + AvailabilityOutcome.FAILURE, + AvailabilityReason.PROVISIONING_FAILURE, + model=SAMPLE_REQUEST.model, + service_type=SAMPLE_REQUEST.service_type, + purpose=SAMPLE_REQUEST.purpose, + ) + == 1 + ) + assert _availability_total(metrics_spy) == 1 + + +async def test_non_signup_non_5xx_records_nothing(mocker, metrics_spy): + # The strict gate leaves a non-signup-cap, non-5xx disposition unrecorded so a + # client-side 4xx is not counted as an availability failure. + mocker.patch( + "mlpa.core.completions.get_or_create_user", + mocker.AsyncMock( + side_effect=HTTPException( + status_code=400, detail={"error": "Invalid user_id format"} + ) + ), + ) + + with pytest.raises(HTTPException) as exc_info: + await get_or_create_user_for_completion(SAMPLE_REQUEST.user, SAMPLE_REQUEST) + + assert exc_info.value.status_code == 400 + assert "chat_availability" not in metrics_spy.touched() + + +async def test_search_request_records_no_chat_availability(mocker, metrics_spy): + search_req = AuthorizedSearchRequest( + user="user-1:search", service_type="search", query="q", max_results=2 + ) + mocker.patch( + "mlpa.core.completions.get_or_create_user", + mocker.AsyncMock( + side_effect=HTTPException( + status_code=500, detail={"error": "Error fetching user info"} + ) + ), + ) + + with pytest.raises(HTTPException): + await get_or_create_user_for_completion(search_req.user, search_req) + + assert "chat_availability" not in metrics_spy.touched() + + +# --- route body: blocked user ------------------------------------------------- + + +async def test_blocked_user_records_blocked(mocker, metrics_spy): + mocker.patch.object( + run_module, + "get_or_create_user_for_completion", + mocker.AsyncMock(return_value=({"blocked": True}, False)), + ) + + with pytest.raises(HTTPException) as exc_info: + await run_module.chat_completion( + request=_make_request(), + authorized_chat_request=SAMPLE_REQUEST, + ) + + assert exc_info.value.status_code == 403 + assert ( + _availability( + metrics_spy, + AvailabilityOutcome.EXCLUDED, + AvailabilityReason.BLOCKED, + model=SAMPLE_REQUEST.model, + service_type=SAMPLE_REQUEST.service_type, + purpose=SAMPLE_REQUEST.purpose, + ) + == 1 + ) + assert _availability_total(metrics_spy) == 1