Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 87 additions & 2 deletions src/shade/errors.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from __future__ import annotations

from typing import Optional
import json
from typing import Any, Optional

INVALID_REQUEST_STATUS_CODES = (400, 422)


class ShadeError(Exception):
Expand Down Expand Up @@ -28,7 +31,47 @@ class AuthenticationError(ShadeError):


class InvalidRequestError(ShadeError):
"""Raised when a request is malformed or rejected by validation."""
"""Raised on HTTP 400/422 responses for malformed or invalid parameters."""

def __init__(
self,
message: str,
status_code: Optional[int] = None,
response_body: Optional[str] = None,
param: Optional[str] = None,
field_errors: Optional[dict[str, Any]] = None,
) -> None:
super().__init__(message, status_code, response_body)
parsed = _parse_error_response(response_body)
self.param: Optional[str] = param if param is not None else parsed.get("param")
self.field_errors: dict[str, Any] = (
field_errors if field_errors is not None else parsed.get("field_errors", {})
)

def __str__(self) -> str:
message = self.message
if self.param:
message = f"{message} (param: {self.param})"
if self.status_code is None:
return message
return f"{message} (status code: {self.status_code})"

@classmethod
def from_response(
cls,
status_code: int,
response_body: Optional[str] = None,
) -> "InvalidRequestError":
"""Construct from a raw 400/422 API response body."""
parsed = _parse_error_response(response_body)
message = parsed.get("message") or "Invalid request"
return cls(
message,
status_code=status_code,
response_body=response_body,
param=parsed.get("param"),
field_errors=parsed.get("field_errors", {}),
)


class NotFoundError(ShadeError):
Expand All @@ -37,3 +80,45 @@ class NotFoundError(ShadeError):

class NetworkError(ShadeError):
"""Raised when the SDK cannot complete a network request."""


def raise_for_invalid_request(
status_code: int,
response_body: Optional[str] = None,
) -> None:
"""Raise InvalidRequestError when the API returns 400 or 422."""
if status_code in INVALID_REQUEST_STATUS_CODES:
raise InvalidRequestError.from_response(status_code, response_body)


def _parse_body(response_body: Optional[str]) -> dict[str, Any]:
if not response_body:
return {}
try:
data = json.loads(response_body)
return data if isinstance(data, dict) else {}
except (json.JSONDecodeError, ValueError):
return {}


def _parse_error_response(response_body: Optional[str]) -> dict[str, Any]:
data = _parse_body(response_body)
error = data.get("error", {})
if not isinstance(error, dict):
error = {}

field_errors = error.get("field_errors")
if not isinstance(field_errors, dict):
field_errors = data.get("field_errors")
if not isinstance(field_errors, dict):
field_errors = {}

message = error.get("message")
if not message:
message = data.get("message")

return {
"message": message,
"param": error.get("param"),
"field_errors": field_errors,
}
86 changes: 86 additions & 0 deletions tests/test_errors.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import pytest
import shade
from shade import (
AuthenticationError,
Expand All @@ -6,6 +7,7 @@
NotFoundError,
ShadeError,
)
from shade.errors import raise_for_invalid_request


def test_shade_error_can_be_raised_standalone():
Expand Down Expand Up @@ -49,3 +51,87 @@ def test_package_root_exports_error_classes():
assert shade.InvalidRequestError is InvalidRequestError
assert shade.NetworkError is NetworkError
assert shade.NotFoundError is NotFoundError


def test_invalid_request_error_parses_param_from_body():
body = (
'{"error": {"code": "invalid_param", "param": "amount", '
'"message": "Amount must be greater than zero"}}'
)
error = InvalidRequestError("invalid request", status_code=400, response_body=body)

assert error.param == "amount"
assert error.message == "invalid request"


def test_invalid_request_error_parses_field_errors_from_body():
body = (
'{"error": {"code": "invalid_param", "param": "amount", "message": "Validation failed", '
'"field_errors": {"amount": "must be positive", "currency": "is required"}}}'
)
error = InvalidRequestError.from_response(400, body)

assert error.param == "amount"
assert error.field_errors == {
"amount": "must be positive",
"currency": "is required",
}


def test_invalid_request_error_str_includes_param():
body = (
'{"error": {"code": "invalid_param", "param": "amount", '
'"message": "Amount must be greater than zero"}}'
)
error = InvalidRequestError.from_response(400, body)

assert "amount" in str(error)
assert "Amount must be greater than zero" in str(error)
assert str(error) == "Amount must be greater than zero (param: amount) (status code: 400)"


def test_invalid_request_error_explicit_attrs_override_body():
body = (
'{"error": {"param": "currency", "field_errors": {"currency": "invalid"}}, '
'"field_errors": {"amount": "too small"}}'
)
error = InvalidRequestError(
"invalid request",
status_code=422,
response_body=body,
param="amount",
field_errors={"amount": "required"},
)

assert error.param == "amount"
assert error.field_errors == {"amount": "required"}


def test_400_response_raises_invalid_request_error():
body = (
'{"error": {"code": "invalid_param", "param": "amount", '
'"message": "Amount must be greater than zero"}}'
)

with pytest.raises(InvalidRequestError) as exc_info:
raise_for_invalid_request(400, body)

error = exc_info.value
assert error.status_code == 400
assert error.param == "amount"
assert isinstance(error, ShadeError)


def test_422_response_raises_invalid_request_error():
body = '{"error": {"param": "email", "message": "Invalid email format"}}'

with pytest.raises(InvalidRequestError) as exc_info:
raise_for_invalid_request(422, body)

assert exc_info.value.param == "email"
assert exc_info.value.status_code == 422


def test_raise_for_invalid_request_ignores_other_status_codes():
raise_for_invalid_request(404, '{"error": {"message": "not found"}}')
raise_for_invalid_request(500, '{"error": {"message": "server error"}}')
Loading