diff --git a/src/mlpa/core/classes.py b/src/mlpa/core/classes.py index b112f6d..d5d1437 100644 --- a/src/mlpa/core/classes.py +++ b/src/mlpa/core/classes.py @@ -88,7 +88,34 @@ class PlayIntegrityTokenResponse(BaseModel): expires_in: int -class AuthorizedChatRequest(ChatRequest): +class AuthorizedRequestLogMixin: + """Shared structured log fields for authorized requests. + + Bound into the loguru contextvar via ``logger.contextualize(**log_fields)`` + in the proxy handlers so every log line emitted while serving the request + (including mid-stream errors) carries them as queryable ``record.extra.*`` + fields, rather than concatenated into the message string. + """ + + user: str + service_type: str + purpose: str + + @property + def log_fields(self) -> dict[str, str]: + # `model` only exists on chat requests, not search requests. + fields = { + "user": self.user, + "service_type": self.service_type, + "purpose": self.purpose or "-", + } + model = getattr(self, "model", None) + if model: + fields["model"] = model + return fields + + +class AuthorizedChatRequest(ChatRequest, AuthorizedRequestLogMixin): user: str service_type: str purpose: str = ( @@ -101,7 +128,7 @@ class SearchRequest(BaseModel): max_results: int = Field(ge=1, le=10) -class AuthorizedSearchRequest(SearchRequest): +class AuthorizedSearchRequest(SearchRequest, AuthorizedRequestLogMixin): user: str service_type: str purpose: str = ( diff --git a/src/mlpa/core/completions.py b/src/mlpa/core/completions.py index 7a5a8ec..fbd01a6 100644 --- a/src/mlpa/core/completions.py +++ b/src/mlpa/core/completions.py @@ -71,6 +71,26 @@ async def get_or_create_user_for_completion( async def stream_completion( authorized_chat_request: AuthorizedChatRequest, request: Request +): + """Bind request log fields onto the loguru contextvar, then stream. + + The contextvar must be held *inside* the generator (not the route handler) + so it stays active while Starlette iterates the SSE body — otherwise + mid-stream errors would lose the fields (the streaming blind spot). + """ + with logger.contextualize(**authorized_chat_request.log_fields): + gen = _stream_completion(authorized_chat_request, request) + try: + async for chunk in gen: + yield chunk + finally: + # Forward close/GeneratorExit into the inner generator so its + # client-disconnect handling runs while the fields are still bound. + await gen.aclose() + + +async def _stream_completion( + authorized_chat_request: AuthorizedChatRequest, request: Request ): """ Proxies a streaming request to LiteLLM. @@ -297,6 +317,12 @@ async def _read_next_chunk( async def get_completion(authorized_chat_request: AuthorizedChatRequest): + """Bind request log fields onto the loguru contextvar, then proxy.""" + with logger.contextualize(**authorized_chat_request.log_fields): + return await _get_completion(authorized_chat_request) + + +async def _get_completion(authorized_chat_request: AuthorizedChatRequest): """ Proxies a non-streaming request to LiteLLM. """ diff --git a/src/mlpa/core/logger.py b/src/mlpa/core/logger.py index ce708bc..d6fc58f 100644 --- a/src/mlpa/core/logger.py +++ b/src/mlpa/core/logger.py @@ -129,8 +129,13 @@ async def _wrapper(self, *args, **kwargs): ) try: response = await original(self, *args, **kwargs) - except Exception: - logger.error(f"HTTPX {method_name.upper()=} request failed for {url=}") + except Exception as exc: + # Include the exception type + repr: transport failures often + # have an empty str(), so the bare URL alone was undiagnosable. + logger.error( + f"HTTPX {method_name.upper()=} request failed for {url=}: " + f"{type(exc).__name__}: {exc!r}" + ) raise logger.debug( f"HTTPX {method_name.upper()} response <- {url=} {response.status_code=}", diff --git a/src/mlpa/core/search.py b/src/mlpa/core/search.py index b2f8abf..f76f3ec 100644 --- a/src/mlpa/core/search.py +++ b/src/mlpa/core/search.py @@ -18,6 +18,12 @@ async def get_search(authorized_search_request: AuthorizedSearchRequest): + """Bind request log fields onto the loguru contextvar, then proxy.""" + with logger.contextualize(**authorized_search_request.log_fields): + return await _get_search(authorized_search_request) + + +async def _get_search(authorized_search_request: AuthorizedSearchRequest): start_time = time.perf_counter() body = sanitize_request_body( authorized_search_request.model_dump(exclude_none=True) diff --git a/src/mlpa/core/utils.py b/src/mlpa/core/utils.py index 477d6e2..df4376f 100644 --- a/src/mlpa/core/utils.py +++ b/src/mlpa/core/utils.py @@ -245,6 +245,11 @@ def raise_and_log( HTTPException with the chosen status code and a sanitized error message. If the upstream error body contains a nested error message, it is extracted so clients receive the actual upstream detail in debug mode. (dev environment only) + + Request-identifying fields (user / service_type / model / purpose) are not + passed here; they are bound on the loguru contextvar by the proxy handler + (``logger.contextualize(**req.log_fields)``) so they ride along in + ``record.extra`` automatically. """ response = getattr(e, "response", None) error_text = response.text if response is not None else "" @@ -264,7 +269,16 @@ def raise_and_log( except (json.JSONDecodeError, AttributeError, TypeError): pass status_code = response_code or getattr(response, "status_code", None) or 500 - logger.error(f"{response_text_prefix or GENERIC_UPSTREAM_ERROR}: {detail_text}") + # Transport errors (httpx ConnectError / RemoteProtocolError / ReadError / + # timeouts) carry no `.response` and frequently stringify to "", which is + # why these logs used to read "Failed to proxy request: " with no detail. + # Fall back to the exception class name + repr, and attach the traceback + # (logger.opt(exception=...)) so jsonPayload.record.exception is populated. + exc_type = type(e).__name__ + logged_detail = detail_text or repr(e) + logger.opt(exception=e).error( + f"{response_text_prefix or GENERIC_UPSTREAM_ERROR}: {exc_type}: {logged_detail}" + ) if stream: error_msg = detail_text if env.MLPA_DEBUG else GENERIC_UPSTREAM_ERROR payload = {"code": status_code, "error": error_msg} diff --git a/src/tests/unit/test_completions.py b/src/tests/unit/test_completions.py index fb87613..c46d57b 100644 --- a/src/tests/unit/test_completions.py +++ b/src/tests/unit/test_completions.py @@ -28,10 +28,36 @@ LITELLM_HEADER_RESPONSE_DURATION_MS, env, ) +from mlpa.core.logger import logger as loguru_logger from mlpa.core.prometheus_metrics import PrometheusRejectionReason, PrometheusResult from tests.consts import SAMPLE_REQUEST, SUCCESSFUL_CHAT_RESPONSE +@contextlib.contextmanager +def _capture_logs(): + """Capture raw loguru records emitted within the block. + + Each captured item is a loguru ``Message`` whose ``.record`` dict exposes + ``message`` / ``level`` / ``exception`` / ``extra`` — lets tests assert on + log content, attached tracebacks, and contextvar-bound fields. + """ + records = [] + sink_id = loguru_logger.add(records.append, level="DEBUG", format="{message}") + try: + yield records + finally: + loguru_logger.remove(sink_id) + + +def _proxy_error_records(records): + return [ + item.record + for item in records + if item.record["level"].name == "ERROR" + and "Failed to proxy request" in item.record["message"] + ] + + def _latency_count(spy, result: PrometheusResult, req=SAMPLE_REQUEST) -> float: return spy.histogram_count( "chat_completion_latency", @@ -940,7 +966,7 @@ async def test_stream_completion_400_non_rate_limit_error( received_chunks[0] == b'data: {"code": 400, "error": "Upstream service returned an error"}\n\n' ) - mock_logger.error.assert_called_once() + mock_logger.opt.return_value.error.assert_called_once() metrics_spy.assert_only({"chat_completion_latency"}) assert _latency_count(metrics_spy, PrometheusResult.ERROR) == 1 @@ -970,7 +996,7 @@ async def test_stream_completion_429_non_rate_limit_error( received_chunks[0] == b'data: {"code": 429, "error": "Upstream service returned an error"}\n\n' ) - mock_logger.error.assert_called_once() + mock_logger.opt.return_value.error.assert_called_once() metrics_spy.assert_only({"chat_completion_latency"}) assert _latency_count(metrics_spy, PrometheusResult.ERROR) == 1 @@ -1022,7 +1048,7 @@ async def test_stream_completion_429_invalid_json( received_chunks[0] == b'data: {"code": 429, "error": "Upstream service returned an error"}\n\n' ) - mock_logger.error.assert_called_once() + mock_logger.opt.return_value.error.assert_called_once() metrics_spy.assert_only({"chat_completion_latency"}) assert _latency_count(metrics_spy, PrometheusResult.ERROR) == 1 @@ -1373,3 +1399,73 @@ async def test_get_completion_sanitizes_response_surrogates(mocker): assert "\ud83e" not in data["choices"][0]["message"]["content"] assert data["choices"][0]["message"]["content"].startswith("done ") _httpx_encode_json(data) # must not raise + + +async def test_get_completion_empty_message_transport_error_is_diagnosable(mocker): + """Regression for the prod 502s that logged a bare ``Failed to proxy request:``. + + A transport error with no ``.response`` and an empty ``str()`` (e.g. + ``RemoteProtocolError("")``) must still produce a diagnosable ERROR line: + the exception type + repr in the message, the traceback attached, and the + request-identifying fields bound via ``contextualize(**log_fields)``. + """ + mock_client = AsyncMock() + mock_client.post.side_effect = httpx.RemoteProtocolError("") + mocker.patch("mlpa.core.completions.get_http_client", return_value=mock_client) + mocker.patch.object(env, "MLPA_DEBUG", False) + + with _capture_logs() as records: + with pytest.raises(HTTPException) as exc_info: + await get_completion(SAMPLE_REQUEST) + + assert exc_info.value.status_code == 502 + + proxy_errors = _proxy_error_records(records) + assert len(proxy_errors) == 1 + rec = proxy_errors[0] + # Exception type is named, and the message is NOT the old blank form. + assert "RemoteProtocolError" in rec["message"] + assert not rec["message"].rstrip().endswith("Failed to proxy request:") + # Traceback attached via logger.opt(exception=e). + assert rec["exception"] is not None + assert rec["exception"].type is httpx.RemoteProtocolError + # Request fields bound on the record (queryable as record.extra.*). + assert rec["extra"]["user"] == SAMPLE_REQUEST.user + assert rec["extra"]["model"] == SAMPLE_REQUEST.model + assert rec["extra"]["service_type"] == SAMPLE_REQUEST.service_type + + +async def test_stream_mid_stream_error_binds_request_fields( + mocker, mock_request, metrics_spy +): + """Streaming blind-spot regression. + + An error raised mid-SSE-stream (after MLPA already returned 200) must still + log with the request fields bound — proving the ``contextualize`` scope set + inside ``stream_completion`` survives generator iteration, unlike the + middleware scope which has already exited by the time the body iterates. + """ + role_chunk = ( + b'data: {"choices":[{"delta":{"role":"assistant","content":null}}]}\n\n' + ) + + async def _failing_aiter_bytes(): + yield role_chunk + raise httpx.RemoteProtocolError("") + + _patch_mock_stream_client(mocker, _failing_aiter_bytes) + mocker.patch.object(env, "MLPA_DEBUG", False) + + with _capture_logs() as records: + received = [c async for c in stream_completion(SAMPLE_REQUEST, mock_request)] + + assert any(b'"error"' in chunk for chunk in received) + + proxy_errors = _proxy_error_records(records) + assert len(proxy_errors) == 1 + rec = proxy_errors[0] + assert "RemoteProtocolError" in rec["message"] + assert rec["exception"] is not None + assert rec["extra"]["user"] == SAMPLE_REQUEST.user + assert rec["extra"]["model"] == SAMPLE_REQUEST.model + assert rec["extra"]["service_type"] == SAMPLE_REQUEST.service_type diff --git a/src/tests/unit/test_logger.py b/src/tests/unit/test_logger.py new file mode 100644 index 0000000..b54be47 --- /dev/null +++ b/src/tests/unit/test_logger.py @@ -0,0 +1,58 @@ +import contextlib + +import httpx +import pytest +from loguru import logger as loguru_logger + +from mlpa.core.config import env +from mlpa.core.logger import _enable_httpx_logging + + +@contextlib.contextmanager +def _capture_logs(): + records = [] + sink_id = loguru_logger.add(records.append, level="DEBUG", format="{message}") + try: + yield records + finally: + loguru_logger.remove(sink_id) + + +async def test_httpx_wrapper_logs_exc_type_on_transport_failure(mocker): + """The HTTPX logging wrapper must name the exception type + repr on failure. + + Transport errors often stringify to ``""``, so the bare URL alone (the old + log) was undiagnosable. This is the first line of the 502 "triple". + """ + mocker.patch.object(env, "HTTPX_LOGGING", True) + + # Save and restore the real httpx methods so the global patch never leaks + # into other tests, regardless of whether logging was already enabled. + before_get = httpx.AsyncClient.get + before_post = httpx.AsyncClient.post + try: + _enable_httpx_logging() + + def _raise(request: httpx.Request) -> httpx.Response: + raise httpx.ConnectError("", request=request) + + transport = httpx.MockTransport(_raise) + url = "http://litellm:8000/v1/chat/completions" + with _capture_logs() as records: + async with httpx.AsyncClient(transport=transport) as client: + with pytest.raises(httpx.ConnectError): + await client.post(url) + finally: + httpx.AsyncClient.get = before_get + httpx.AsyncClient.post = before_post + + failures = [ + item.record["message"] + for item in records + if item.record["level"].name == "ERROR" + and "request failed" in item.record["message"] + ] + assert len(failures) == 1 + msg = failures[0] + assert "ConnectError" in msg + assert url in msg