Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 29 additions & 2 deletions src/mlpa/core/classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,34 @@ class PlayIntegrityTokenResponse(BaseModel):
expires_in: int


class AuthorizedChatRequest(ChatRequest):
class AuthorizedRequestLogMixin:
    """Mixin exposing the structured log fields of an authorized request.

    The proxy handlers bind ``log_fields`` onto the loguru contextvar with
    ``logger.contextualize(**log_fields)``. Every log line emitted while the
    request is being served (mid-stream errors included) then carries the
    values as queryable ``record.extra.*`` fields instead of having them
    concatenated into the message string.
    """

    user: str
    service_type: str
    purpose: str

    @property
    def log_fields(self) -> dict[str, str]:
        """Return the request-identifying fields to bind on the log context.

        ``purpose`` falls back to ``"-"`` when it is empty, and ``model`` is
        included only when the request has a truthy ``model`` attribute.
        """
        base_fields = {
            "user": self.user,
            "service_type": self.service_type,
            "purpose": self.purpose or "-",
        }
        # Chat requests carry a `model`; search requests do not.
        model_name = getattr(self, "model", None)
        if not model_name:
            return base_fields
        return {**base_fields, "model": model_name}


class AuthorizedChatRequest(ChatRequest, AuthorizedRequestLogMixin):
user: str
service_type: str
purpose: str = (
Expand All @@ -101,7 +128,7 @@ class SearchRequest(BaseModel):
max_results: int = Field(ge=1, le=10)


class AuthorizedSearchRequest(SearchRequest):
class AuthorizedSearchRequest(SearchRequest, AuthorizedRequestLogMixin):
user: str
service_type: str
purpose: str = (
Expand Down
26 changes: 26 additions & 0 deletions src/mlpa/core/completions.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,26 @@ async def get_or_create_user_for_completion(

async def stream_completion(
    authorized_chat_request: AuthorizedChatRequest, request: Request
):
    """Stream a completion with the request's log fields bound for its duration.

    The loguru context is entered *inside* this generator rather than in the
    route handler so that it is still active while Starlette iterates the SSE
    body; otherwise errors raised mid-stream would be logged without the
    request fields (the streaming blind spot).
    """
    log_context = logger.contextualize(**authorized_chat_request.log_fields)
    with log_context:
        inner_stream = _stream_completion(authorized_chat_request, request)
        try:
            async for piece in inner_stream:
                yield piece
        finally:
            # Propagate close/GeneratorExit to the wrapped generator so its
            # client-disconnect handling executes before the fields unbind.
            await inner_stream.aclose()


async def _stream_completion(
authorized_chat_request: AuthorizedChatRequest, request: Request
):
"""
Proxies a streaming request to LiteLLM.
Expand Down Expand Up @@ -297,6 +317,12 @@ async def _read_next_chunk(


async def get_completion(authorized_chat_request: AuthorizedChatRequest):
    """Proxy a non-streaming completion with the request log fields bound."""
    bound_fields = authorized_chat_request.log_fields
    with logger.contextualize(**bound_fields):
        completion = await _get_completion(authorized_chat_request)
    return completion


async def _get_completion(authorized_chat_request: AuthorizedChatRequest):
"""
Proxies a non-streaming request to LiteLLM.
"""
Expand Down
9 changes: 7 additions & 2 deletions src/mlpa/core/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,8 +129,13 @@ async def _wrapper(self, *args, **kwargs):
)
try:
response = await original(self, *args, **kwargs)
except Exception:
logger.error(f"HTTPX {method_name.upper()=} request failed for {url=}")
except Exception as exc:
# Include the exception type + repr: transport failures often
# have an empty str(), so the bare URL alone was undiagnosable.
logger.error(
f"HTTPX {method_name.upper()=} request failed for {url=}: "
f"{type(exc).__name__}: {exc!r}"
)
raise
logger.debug(
f"HTTPX {method_name.upper()} response <- {url=} {response.status_code=}",
Expand Down
6 changes: 6 additions & 0 deletions src/mlpa/core/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,12 @@


async def get_search(authorized_search_request: AuthorizedSearchRequest):
    """Proxy a search request with the request log fields bound."""
    bound_fields = authorized_search_request.log_fields
    with logger.contextualize(**bound_fields):
        search_result = await _get_search(authorized_search_request)
    return search_result


async def _get_search(authorized_search_request: AuthorizedSearchRequest):
start_time = time.perf_counter()
body = sanitize_request_body(
authorized_search_request.model_dump(exclude_none=True)
Expand Down
16 changes: 15 additions & 1 deletion src/mlpa/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,11 @@ def raise_and_log(
HTTPException with the chosen status code and a sanitized error message.
If the upstream error body contains a nested error message, it is extracted
so clients receive the actual upstream detail in debug mode. (dev environment only)

Request-identifying fields (user / service_type / model / purpose) are not
passed here; they are bound on the loguru contextvar by the proxy handler
(``logger.contextualize(**req.log_fields)``) so they ride along in
``record.extra`` automatically.
"""
response = getattr(e, "response", None)
error_text = response.text if response is not None else ""
Expand All @@ -264,7 +269,16 @@ def raise_and_log(
except (json.JSONDecodeError, AttributeError, TypeError):
pass
status_code = response_code or getattr(response, "status_code", None) or 500
logger.error(f"{response_text_prefix or GENERIC_UPSTREAM_ERROR}: {detail_text}")
# Transport errors (httpx ConnectError / RemoteProtocolError / ReadError /
# timeouts) carry no `.response` and frequently stringify to "", which is
# why these logs used to read "Failed to proxy request: " with no detail.
# Fall back to the exception class name + repr, and attach the traceback
# (logger.opt(exception=...)) so jsonPayload.record.exception is populated.
exc_type = type(e).__name__
logged_detail = detail_text or repr(e)
logger.opt(exception=e).error(
f"{response_text_prefix or GENERIC_UPSTREAM_ERROR}: {exc_type}: {logged_detail}"
)
if stream:
error_msg = detail_text if env.MLPA_DEBUG else GENERIC_UPSTREAM_ERROR
payload = {"code": status_code, "error": error_msg}
Expand Down
102 changes: 99 additions & 3 deletions src/tests/unit/test_completions.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,36 @@
LITELLM_HEADER_RESPONSE_DURATION_MS,
env,
)
from mlpa.core.logger import logger as loguru_logger
from mlpa.core.prometheus_metrics import PrometheusRejectionReason, PrometheusResult
from tests.consts import SAMPLE_REQUEST, SUCCESSFUL_CHAT_RESPONSE


@contextlib.contextmanager
def _capture_logs():
    """Collect the loguru messages emitted while the ``with`` block runs.

    A temporary DEBUG-level sink appends every emitted loguru ``Message`` to
    the yielded list. Each item's ``.record`` dict exposes ``message`` /
    ``level`` / ``exception`` / ``extra``, so tests can assert on log content,
    attached tracebacks, and contextvar-bound fields. The sink is removed on
    exit, even if the block raises.
    """
    captured = []
    handler_id = loguru_logger.add(
        captured.append, level="DEBUG", format="{message}"
    )
    try:
        yield captured
    finally:
        loguru_logger.remove(handler_id)


def _proxy_error_records(records):
return [
item.record
for item in records
if item.record["level"].name == "ERROR"
and "Failed to proxy request" in item.record["message"]
]


def _latency_count(spy, result: PrometheusResult, req=SAMPLE_REQUEST) -> float:
return spy.histogram_count(
"chat_completion_latency",
Expand Down Expand Up @@ -940,7 +966,7 @@ async def test_stream_completion_400_non_rate_limit_error(
received_chunks[0]
== b'data: {"code": 400, "error": "Upstream service returned an error"}\n\n'
)
mock_logger.error.assert_called_once()
mock_logger.opt.return_value.error.assert_called_once()
metrics_spy.assert_only({"chat_completion_latency"})
assert _latency_count(metrics_spy, PrometheusResult.ERROR) == 1

Expand Down Expand Up @@ -970,7 +996,7 @@ async def test_stream_completion_429_non_rate_limit_error(
received_chunks[0]
== b'data: {"code": 429, "error": "Upstream service returned an error"}\n\n'
)
mock_logger.error.assert_called_once()
mock_logger.opt.return_value.error.assert_called_once()
metrics_spy.assert_only({"chat_completion_latency"})
assert _latency_count(metrics_spy, PrometheusResult.ERROR) == 1

Expand Down Expand Up @@ -1022,7 +1048,7 @@ async def test_stream_completion_429_invalid_json(
received_chunks[0]
== b'data: {"code": 429, "error": "Upstream service returned an error"}\n\n'
)
mock_logger.error.assert_called_once()
mock_logger.opt.return_value.error.assert_called_once()
metrics_spy.assert_only({"chat_completion_latency"})
assert _latency_count(metrics_spy, PrometheusResult.ERROR) == 1

Expand Down Expand Up @@ -1373,3 +1399,73 @@ async def test_get_completion_sanitizes_response_surrogates(mocker):
assert "\ud83e" not in data["choices"][0]["message"]["content"]
assert data["choices"][0]["message"]["content"].startswith("done ")
_httpx_encode_json(data) # must not raise


async def test_get_completion_empty_message_transport_error_is_diagnosable(mocker):
    """Guard against the prod 502s that logged a bare ``Failed to proxy request:``.

    When the upstream call fails with a transport error that has no
    ``.response`` and an empty ``str()`` (e.g. ``RemoteProtocolError("")``),
    the resulting ERROR line must remain diagnosable: it names the exception
    type + repr, carries the traceback, and has the request-identifying
    fields bound via ``contextualize(**log_fields)``.
    """
    failing_client = AsyncMock()
    failing_client.post.side_effect = httpx.RemoteProtocolError("")
    mocker.patch(
        "mlpa.core.completions.get_http_client", return_value=failing_client
    )
    mocker.patch.object(env, "MLPA_DEBUG", False)

    with _capture_logs() as captured, pytest.raises(HTTPException) as raised:
        await get_completion(SAMPLE_REQUEST)

    assert raised.value.status_code == 502

    matching = _proxy_error_records(captured)
    assert len(matching) == 1
    error_record = matching[0]

    # The exception type is named and the message is not the old blank form.
    logged_message = error_record["message"]
    assert "RemoteProtocolError" in logged_message
    assert not logged_message.rstrip().endswith("Failed to proxy request:")

    # The traceback is attached through logger.opt(exception=e).
    logged_exception = error_record["exception"]
    assert logged_exception is not None
    assert logged_exception.type is httpx.RemoteProtocolError

    # Request fields are bound on the record (queryable as record.extra.*).
    bound = error_record["extra"]
    assert bound["user"] == SAMPLE_REQUEST.user
    assert bound["model"] == SAMPLE_REQUEST.model
    assert bound["service_type"] == SAMPLE_REQUEST.service_type


async def test_stream_mid_stream_error_binds_request_fields(
    mocker, mock_request, metrics_spy
):
    """Regression test for the streaming blind spot.

    A failure raised part-way through the SSE stream (after MLPA has already
    answered 200) must still be logged with the request fields bound. This
    shows that the ``contextualize`` scope opened inside ``stream_completion``
    survives generator iteration, whereas the middleware scope has already
    exited by the time the body is iterated.
    """
    first_chunk = (
        b'data: {"choices":[{"delta":{"role":"assistant","content":null}}]}\n\n'
    )

    async def _failing_aiter_bytes():
        # Emit one valid chunk, then fail with an empty-message transport error.
        yield first_chunk
        raise httpx.RemoteProtocolError("")

    _patch_mock_stream_client(mocker, _failing_aiter_bytes)
    mocker.patch.object(env, "MLPA_DEBUG", False)

    received = []
    with _capture_logs() as captured:
        async for chunk in stream_completion(SAMPLE_REQUEST, mock_request):
            received.append(chunk)

    assert any(b'"error"' in chunk for chunk in received)

    matching = _proxy_error_records(captured)
    assert len(matching) == 1
    error_record = matching[0]
    assert "RemoteProtocolError" in error_record["message"]
    assert error_record["exception"] is not None

    bound = error_record["extra"]
    assert bound["user"] == SAMPLE_REQUEST.user
    assert bound["model"] == SAMPLE_REQUEST.model
    assert bound["service_type"] == SAMPLE_REQUEST.service_type
58 changes: 58 additions & 0 deletions src/tests/unit/test_logger.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import contextlib

import httpx
import pytest
from loguru import logger as loguru_logger

from mlpa.core.config import env
from mlpa.core.logger import _enable_httpx_logging


@contextlib.contextmanager
def _capture_logs():
    """Yield a list that receives every loguru message emitted in the block.

    A temporary DEBUG-level sink is registered for the duration of the
    ``with`` block and removed on exit, even if the block raises.
    """
    captured = []
    handler_id = loguru_logger.add(
        captured.append, level="DEBUG", format="{message}"
    )
    try:
        yield captured
    finally:
        loguru_logger.remove(handler_id)


async def test_httpx_wrapper_logs_exc_type_on_transport_failure(mocker):
    """The HTTPX logging wrapper names the exception type + repr on failure.

    Transport errors often stringify to ``""``, so the old log line (URL only)
    was undiagnosable. This is the first line of the 502 "triple".
    """
    mocker.patch.object(env, "HTTPX_LOGGING", True)

    target_url = "http://litellm:8000/v1/chat/completions"

    def _always_connect_error(request: httpx.Request) -> httpx.Response:
        raise httpx.ConnectError("", request=request)

    # Snapshot the real httpx methods and put them back afterwards, so the
    # global patch cannot leak into other tests whether or not logging had
    # already been enabled.
    original_get, original_post = httpx.AsyncClient.get, httpx.AsyncClient.post
    try:
        _enable_httpx_logging()

        failing_transport = httpx.MockTransport(_always_connect_error)
        with _capture_logs() as captured:
            async with httpx.AsyncClient(transport=failing_transport) as client:
                with pytest.raises(httpx.ConnectError):
                    await client.post(target_url)
    finally:
        httpx.AsyncClient.get = original_get
        httpx.AsyncClient.post = original_post

    error_messages = []
    for item in captured:
        entry = item.record
        if entry["level"].name == "ERROR" and "request failed" in entry["message"]:
            error_messages.append(entry["message"])

    assert len(error_messages) == 1
    logged = error_messages[0]
    assert "ConnectError" in logged
    assert target_url in logged
Loading