diff --git a/src/shade/errors.py b/src/shade/errors.py index 710bf34..2670540 100644 --- a/src/shade/errors.py +++ b/src/shade/errors.py @@ -1,6 +1,9 @@ from __future__ import annotations -from typing import Optional +import json +from typing import Any, Optional + +INVALID_REQUEST_STATUS_CODES = (400, 422) class ShadeError(Exception): @@ -28,7 +31,47 @@ class AuthenticationError(ShadeError): class InvalidRequestError(ShadeError): - """Raised when a request is malformed or rejected by validation.""" + """Raised on HTTP 400/422 responses for malformed or invalid parameters.""" + + def __init__( + self, + message: str, + status_code: Optional[int] = None, + response_body: Optional[str] = None, + param: Optional[str] = None, + field_errors: Optional[dict[str, Any]] = None, + ) -> None: + super().__init__(message, status_code, response_body) + parsed = _parse_error_response(response_body) + self.param: Optional[str] = param if param is not None else parsed.get("param") + self.field_errors: dict[str, Any] = ( + field_errors if field_errors is not None else parsed.get("field_errors", {}) + ) + + def __str__(self) -> str: + message = self.message + if self.param: + message = f"{message} (param: {self.param})" + if self.status_code is None: + return message + return f"{message} (status code: {self.status_code})" + + @classmethod + def from_response( + cls, + status_code: int, + response_body: Optional[str] = None, + ) -> "InvalidRequestError": + """Construct from a raw 400/422 API response body.""" + parsed = _parse_error_response(response_body) + message = parsed.get("message") or "Invalid request" + return cls( + message, + status_code=status_code, + response_body=response_body, + param=parsed.get("param"), + field_errors=parsed.get("field_errors", {}), + ) class NotFoundError(ShadeError): @@ -37,3 +80,45 @@ class NotFoundError(ShadeError): class NetworkError(ShadeError): """Raised when the SDK cannot complete a network request.""" + + +def raise_for_invalid_request( + status_code: int, + response_body: Optional[str] = None, +) -> None: + """Raise InvalidRequestError when the API returns 400 or 422.""" + if status_code in INVALID_REQUEST_STATUS_CODES: + raise InvalidRequestError.from_response(status_code, response_body) + + +def _parse_body(response_body: Optional[str]) -> dict[str, Any]: + if not response_body: + return {} + try: + data = json.loads(response_body) + return data if isinstance(data, dict) else {} + except (json.JSONDecodeError, ValueError): + return {} + + +def _parse_error_response(response_body: Optional[str]) -> dict[str, Any]: + data = _parse_body(response_body) + error = data.get("error", {}) + if not isinstance(error, dict): + error = {} + + field_errors = error.get("field_errors") + if not isinstance(field_errors, dict): + field_errors = data.get("field_errors") + if not isinstance(field_errors, dict): + field_errors = {} + + message = error.get("message") + if not message: + message = data.get("message") + + return { + "message": message, + "param": error.get("param"), + "field_errors": field_errors, + } diff --git a/tests/test_errors.py b/tests/test_errors.py index 4339d2b..7539489 100644 --- a/tests/test_errors.py +++ b/tests/test_errors.py @@ -1,3 +1,4 @@ +import pytest import shade from shade import ( AuthenticationError, @@ -6,6 +7,7 @@ NotFoundError, ShadeError, ) +from shade.errors import raise_for_invalid_request def test_shade_error_can_be_raised_standalone(): @@ -49,3 +51,87 @@ def test_package_root_exports_error_classes(): assert shade.InvalidRequestError is InvalidRequestError assert shade.NetworkError is NetworkError assert shade.NotFoundError is NotFoundError + + +def test_invalid_request_error_parses_param_from_body(): + body = ( + '{"error": {"code": "invalid_param", "param": "amount", ' + '"message": "Amount must be greater than zero"}}' + ) + error = InvalidRequestError("invalid request", status_code=400, response_body=body) + + assert error.param == "amount" + assert error.message == "invalid request" + + +def test_invalid_request_error_parses_field_errors_from_body(): + body = ( + '{"error": {"code": "invalid_param", "param": "amount", "message": "Validation failed", ' + '"field_errors": {"amount": "must be positive", "currency": "is required"}}}' + ) + error = InvalidRequestError.from_response(400, body) + + assert error.param == "amount" + assert error.field_errors == { + "amount": "must be positive", + "currency": "is required", + } + + +def test_invalid_request_error_str_includes_param(): + body = ( + '{"error": {"code": "invalid_param", "param": "amount", ' + '"message": "Amount must be greater than zero"}}' + ) + error = InvalidRequestError.from_response(400, body) + + assert "amount" in str(error) + assert "Amount must be greater than zero" in str(error) + assert str(error) == "Amount must be greater than zero (param: amount) (status code: 400)" + + +def test_invalid_request_error_explicit_attrs_override_body(): + body = ( + '{"error": {"param": "currency", "field_errors": {"currency": "invalid"}}, ' + '"field_errors": {"amount": "too small"}}' + ) + error = InvalidRequestError( + "invalid request", + status_code=422, + response_body=body, + param="amount", + field_errors={"amount": "required"}, + ) + + assert error.param == "amount" + assert error.field_errors == {"amount": "required"} + + +def test_400_response_raises_invalid_request_error(): + body = ( + '{"error": {"code": "invalid_param", "param": "amount", ' + '"message": "Amount must be greater than zero"}}' + ) + + with pytest.raises(InvalidRequestError) as exc_info: + raise_for_invalid_request(400, body) + + error = exc_info.value + assert error.status_code == 400 + assert error.param == "amount" + assert isinstance(error, ShadeError) + + +def test_422_response_raises_invalid_request_error(): + body = '{"error": {"param": "email", "message": "Invalid email format"}}' + + with pytest.raises(InvalidRequestError) as exc_info: + raise_for_invalid_request(422, body) + + assert exc_info.value.param == "email" + assert exc_info.value.status_code == 422 + + +def test_raise_for_invalid_request_ignores_other_status_codes(): + raise_for_invalid_request(404, '{"error": {"message": "not found"}}') + raise_for_invalid_request(500, '{"error": {"message": "server error"}}')