Skip to content

Commit a8c18eb

Browse files
author
Mateusz
committed
fix: harden composite routing and error handling
1 parent d710094 commit a8c18eb

12 files changed

Lines changed: 264 additions & 80 deletions

docs/user_guide/features/auto-continue-removal.md

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,29 @@ session:
5555
- Your workflow uses custom continue-like keywords that should reach the model.
5656
- You are testing context window behavior and need every message forwarded verbatim.
5757

58+
## Usage Examples
59+
60+
### Resume After A Network Interruption
61+
62+
1. A coding session is interrupted mid-stream.
63+
2. The user sends `continue`.
64+
3. The proxy keeps the message in local history but does not forward it to the remote model.
65+
4. The next backend request reuses the meaningful prior context without wasting tokens on the mechanical resume prompt.
66+
67+
### Preserve Literal Continue Prompts
68+
69+
If you intentionally want the remote model to see `continue` or `proceed`, disable the feature first:
70+
71+
```bash
72+
python -m src.core.cli --disable-auto-continue-removal
73+
```
74+
75+
## Use Cases
76+
77+
- Agentic coding sessions where transient disconnects are common and users resume by typing `continue`.
78+
- Long-running terminal workflows where preserving context window capacity matters.
79+
- Local debugging of session history, where the proxy should remember the resume command without sending it upstream.
80+
5881
## Logging
5982

6083
When a message is tagged, the proxy logs at INFO level:

src/connectors/openai_codex/contracts.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -55,12 +55,15 @@ class ProcessedMessage(BaseModel):
5555
metadata: dict[str, object] | None = None
5656

5757

58-
class MessagePart(BaseModel):
59-
"""Content part in a multimodal message."""
60-
61-
type: str
62-
text: str | None = None
63-
data: object | None = None
58+
class MessagePart(BaseModel):
59+
"""Content part in a multimodal message."""
60+
61+
type: str
62+
text: str | None = None
63+
data: object | None = None
64+
65+
66+
ProcessedMessage.model_rebuild()
6467

6568

6669
class CodexInputItem(BaseModel):

src/core/common/exceptions.py

Lines changed: 44 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,35 @@
1010
import time
1111
from typing import TYPE_CHECKING, Any
1212

13-
if TYPE_CHECKING:
14-
from src.core.domain.client_termination import ClientTerminationReason
15-
from src.core.domain.session_key import SessionKey
16-
17-
18-
class LLMProxyError(Exception):
13+
if TYPE_CHECKING:
14+
from src.core.domain.client_termination import ClientTerminationReason
15+
from src.core.domain.session_key import SessionKey
16+
17+
18+
def _is_json_scalar(value: Any) -> bool:
19+
return isinstance(value, str | int | float | bool) or value is None
20+
21+
22+
def _to_json_safe_error_value(value: Any) -> Any:
23+
if _is_json_scalar(value):
24+
return value
25+
if isinstance(value, dict):
26+
return {
27+
str(key): _to_json_safe_error_value(nested_value)
28+
for key, nested_value in value.items()
29+
if isinstance(key, str | int | float | bool)
30+
}
31+
if isinstance(value, list | tuple | set):
32+
return [_to_json_safe_error_value(item) for item in value]
33+
34+
model_dump = getattr(value, "model_dump", None)
35+
if callable(model_dump):
36+
return _to_json_safe_error_value(model_dump(mode="json"))
37+
38+
return str(value)
39+
40+
41+
class LLMProxyError(Exception):
1942
"""Base exception class for all LLM proxy errors."""
2043

2144
__resilience_context__: dict[str, Any] | None
@@ -44,23 +67,25 @@ def __init__(
4467
for key, value in (kwargs or {}).items():
4568
setattr(self, key, value)
4669

47-
def to_dict(self) -> dict[str, Any]:
48-
error_dict: dict[str, Any] = {
49-
"message": self.message,
50-
"type": self.__class__.__name__,
51-
"details": self.details,
52-
}
70+
def to_dict(self) -> dict[str, Any]:
71+
error_dict: dict[str, Any] = {
72+
"message": self.message,
73+
"type": self.__class__.__name__,
74+
"details": _to_json_safe_error_value(self.details),
75+
}
5376

5477
# Include any additional attributes that were set via kwargs
5578
for attr_name in dir(self):
5679
if (
57-
not attr_name.startswith("_")
58-
and attr_name not in ["message", "details", "status_code", "args"]
59-
and not callable(getattr(self, attr_name))
60-
):
61-
error_dict[attr_name] = getattr(self, attr_name)
62-
63-
return {"error": error_dict}
80+
not attr_name.startswith("_")
81+
and attr_name not in ["message", "details", "status_code", "args"]
82+
and not callable(getattr(self, attr_name))
83+
):
84+
error_dict[attr_name] = _to_json_safe_error_value(
85+
getattr(self, attr_name)
86+
)
87+
88+
return {"error": error_dict}
6489

6590

6691
class AuthenticationError(LLMProxyError):

src/core/services/backend_model_resolver.py

Lines changed: 16 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -34,12 +34,12 @@
3434
from src.core.interfaces.planning_phase_manager_interface import IPlanningPhaseManager
3535
from src.core.interfaces.session_service_interface import ISessionService
3636
from src.core.services.backend_routing_service import BackendRoutingService
37-
from src.core.services.composite_routing_state import (
38-
COMPOSITE_LEAF_RESOLUTION_EXTRA_BODY_KEY,
39-
COMPOSITE_LEAF_RESOLUTION_FLAG,
40-
is_composite_selector,
41-
resolve_composite_routing_surface,
42-
)
37+
from src.core.services.composite_routing_state import (
38+
COMPOSITE_LEAF_RESOLUTION_EXTRA_BODY_KEY,
39+
COMPOSITE_LEAF_RESOLUTION_FLAG,
40+
COMPOSITE_LEAF_SELECTOR_EXTRA_BODY_KEY,
41+
resolve_composite_routing_surface,
42+
)
4343
from src.core.services.replacement_compatibility_bridge import (
4444
ReplacementCompatibilityBridge,
4545
)
@@ -476,21 +476,18 @@ def _is_composite_leaf_resolution(
476476
context: RequestContext | None,
477477
request: ChatRequest,
478478
) -> bool:
479-
model = request.model
480-
if is_composite_selector(model):
481-
return False
479+
if context is not None and bool(
480+
context.extensions.get(COMPOSITE_LEAF_RESOLUTION_FLAG)
481+
):
482+
return True
482483

483-
if context is None:
484-
extra_body = request.extra_body
485-
if isinstance(extra_body, dict):
486-
return bool(extra_body.get(COMPOSITE_LEAF_RESOLUTION_EXTRA_BODY_KEY))
484+
extra_body = request.extra_body
485+
if not isinstance(extra_body, dict):
487486
return False
488-
if bool(context.extensions.get(COMPOSITE_LEAF_RESOLUTION_FLAG)):
489-
return True
490-
extra_body = request.extra_body
491-
if isinstance(extra_body, dict):
492-
return bool(extra_body.get(COMPOSITE_LEAF_RESOLUTION_EXTRA_BODY_KEY))
493-
return False
487+
if not bool(extra_body.get(COMPOSITE_LEAF_RESOLUTION_EXTRA_BODY_KEY)):
488+
return False
489+
leaf_selector = extra_body.get(COMPOSITE_LEAF_SELECTOR_EXTRA_BODY_KEY)
490+
return isinstance(leaf_selector, str) and leaf_selector == request.model
494491

495492
@staticmethod
496493
def _normalize_uri_params(raw_value: Any) -> dict[str, JsonValue]:

src/core/services/composite_leaf_target_resolver_adapter.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from src.core.services.composite_routing_state import (
1414
COMPOSITE_LEAF_RESOLUTION_EXTRA_BODY_KEY,
1515
COMPOSITE_LEAF_RESOLUTION_FLAG,
16+
COMPOSITE_LEAF_SELECTOR_EXTRA_BODY_KEY,
1617
)
1718

1819
__all__ = ["CompositeLeafTargetResolverAdapter"]
@@ -35,6 +36,7 @@ async def resolve_leaf(
3536
) -> BackendTarget:
3637
leaf_extra_body = dict(request.extra_body or {})
3738
leaf_extra_body[COMPOSITE_LEAF_RESOLUTION_EXTRA_BODY_KEY] = True
39+
leaf_extra_body[COMPOSITE_LEAF_SELECTOR_EXTRA_BODY_KEY] = leaf_selector
3840

3941
parsed_leaf = parse_model_with_params(leaf_selector, default_backend="")
4042
if has_explicit_backend_selector(leaf_selector):

src/core/services/composite_routing_state.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
__all__ = [
1212
"COMPOSITE_LEAF_RESOLUTION_EXTRA_BODY_KEY",
1313
"COMPOSITE_LEAF_RESOLUTION_FLAG",
14+
"COMPOSITE_LEAF_SELECTOR_EXTRA_BODY_KEY",
1415
"COMPOSITE_ROUTING_STATE_KEY",
1516
"COMPOSITE_ROUTING_SURFACE_KEY",
1617
"FAILOVER_MODE",
@@ -29,6 +30,7 @@
2930
COMPOSITE_ROUTING_SURFACE_KEY = "composite_routing_surface"
3031
COMPOSITE_LEAF_RESOLUTION_FLAG = "composite_leaf_resolution"
3132
COMPOSITE_LEAF_RESOLUTION_EXTRA_BODY_KEY = "_composite_leaf_resolution"
33+
COMPOSITE_LEAF_SELECTOR_EXTRA_BODY_KEY = "_composite_leaf_selector"
3234
FAILOVER_MODE = "failover"
3335
WEIGHTED_RETRY_MODE = "weighted_retry"
3436

src/core/services/composite_selector_parser.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def parse(self, routing_input: CompositeRoutingInput) -> CompositeRoutePlan:
6262
parts = self._split_top_level(selector, operator)
6363
other_operator = "^" if operator == "|" else "|"
6464
has_mixed_operator = any(
65-
self._contains_operator_in_route_segment(segment, other_operator)
65+
self._contains_operator_outside_brackets(segment, other_operator)
6666
for segment in parts
6767
)
6868
if has_mixed_operator:
@@ -132,11 +132,6 @@ def _detect_primary_operator(selector: str) -> str | None:
132132
return char
133133
return None
134134

135-
@classmethod
136-
def _contains_operator_in_route_segment(cls, segment: str, operator: str) -> bool:
137-
route_segment, _, _ = segment.partition("?")
138-
return cls._contains_operator_outside_brackets(route_segment, operator)
139-
140135
@staticmethod
141136
def _contains_operator_outside_brackets(selector: str, operator: str) -> bool:
142137
bracket_depth = 0

src/core/services/streaming/error_mapping.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,29 @@
2121
LLMProxyError,
2222
ParsingError,
2323
RateLimitExceededError,
24+
RoutingError,
2425
)
2526
from src.core.domain.streaming.contracts import StreamingErrorInfo
2627

2728
logger = logging.getLogger(__name__)
2829

2930

31+
def _resolve_routing_error_status(error: RoutingError) -> int:
32+
details = getattr(error, "details", None)
33+
if isinstance(details, dict):
34+
code = details.get("code")
35+
if code == "unknown_model":
36+
return 404
37+
if code == "unsupported_on_instance":
38+
return 400
39+
if code == "temporarily_unavailable":
40+
return 503
41+
if code == "policy_rejected":
42+
return 403
43+
status_code = getattr(error, "status_code", None)
44+
return status_code if isinstance(status_code, int) else 500
45+
46+
3047
def _merge_provider_retry_metadata(
3148
details: dict[str, str], detail_payload: dict[str, Any]
3249
) -> None:
@@ -240,7 +257,9 @@ async def handle_streaming_error(
240257

241258
# Extract status_code if available
242259
status_code: int | None = None
243-
if hasattr(mapped_error, "status_code"):
260+
if isinstance(mapped_error, RoutingError):
261+
status_code = _resolve_routing_error_status(mapped_error)
262+
elif hasattr(mapped_error, "status_code"):
244263
status_code = mapped_error.status_code
245264

246265
# For quota_exceeded errors, use 503 instead of 429

tests/unit/app/middleware/test_exception_middleware.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,11 @@
55
from fastapi.testclient import TestClient
66
from src.core.app.middleware.exception_middleware import DomainExceptionMiddleware
77
from src.core.common.exceptions import DuplicateRequestError, RateLimitExceededError
8+
from src.core.domain.composite_routing import (
9+
CompositeSelectorValidationError,
10+
CompositeValidationErrorCode,
11+
CompositeValidationErrorEnvelope,
12+
)
813

914

1015
def test_domain_exception_middleware_sets_retry_after_header(monkeypatch):
@@ -62,6 +67,36 @@ async def duplicate_endpoint() -> None:
6267
assert body["error"]["type"] == "DuplicateRequestError"
6368

6469

70+
def test_domain_exception_middleware_serializes_composite_validation_envelope():
    """A composite validation error must come back as a JSON-serializable 400."""
    selector = "openai:gpt-4|anthropic:claude^gemini:flash"

    app = FastAPI()
    app.add_middleware(DomainExceptionMiddleware)

    @app.get("/composite-validation")
    async def composite_validation_endpoint() -> None:
        # Build the envelope at request time so the error travels through
        # the middleware exactly as a real handler failure would.
        envelope = CompositeValidationErrorEnvelope(
            code=CompositeValidationErrorCode.UNSUPPORTED_CONSTRUCT,
            message="Mixed operators are not supported.",
            selector_echo=selector,
        )
        raise CompositeSelectorValidationError(envelope)

    # The route is registered by the decorator; this reference only marks
    # the handler as intentionally unused.
    _ = composite_validation_endpoint

    with TestClient(app) as client:
        response = client.get("/composite-validation")

    assert response.status_code == 400

    error_body = response.json()["error"]
    assert error_body["type"] == "CompositeSelectorValidationError"
    assert error_body["details"]["composite_validation"]["code"] == "unsupported_construct"
    assert error_body["envelope"]["selector_echo"] == selector
98+
99+
65100
async def test_domain_exception_middleware_reraises_transport_error_on_final_body_send():
66101
"""Transport failures on the final body write must not be swallowed."""
67102

tests/unit/core/services/test_backend_completion_flow_streaming_error_envelope.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,3 +230,33 @@ async def test_streaming_error_envelope_enters_failover_recovery_when_enabled()
230230
"category": "validation",
231231
"retryable": False,
232232
}
233+
234+
235+
@pytest.mark.asyncio
async def test_streaming_terminal_routing_error_uses_unknown_model_http_status() -> None:
    """An ``unknown_model`` routing error must surface as HTTP 404 in the stream envelope."""
    flow = _build_flow_with_erroring_backend()

    error_details = {
        "code": "unknown_model",
        "category": "validation",
        "retryable": False,
    }
    routing_error = RoutingError(
        message="Missing extracted backend",
        details=error_details,
    )

    envelope = await flow._build_terminal_error_stream_envelope(
        error=routing_error,
        provider="gemini-oauth-plan",
    )

    # The envelope itself carries the mapped status ...
    assert envelope.status_code == 404
    assert envelope.content is not None

    # ... and so does the single error chunk emitted on the stream.
    received = []
    async for chunk in envelope.content:
        received.append(chunk)
    assert len(received) == 1

    error_payload = received[0].metadata.get("error")
    assert isinstance(error_payload, dict)
    assert error_payload.get("status_code") == 404

0 commit comments

Comments
 (0)