Skip to content

Commit 542e668

Browse files
committed
wip
1 parent 7a9aec7 commit 542e668

4 files changed

Lines changed: 364 additions & 149 deletions

File tree

src/a2a/server/request_handlers/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
build_error_response,
1212
prepare_response_object,
1313
)
14-
from a2a.server.request_handlers.rest_handler import RESTHandler
1514

1615

1716
logger = logging.getLogger(__name__)
@@ -41,7 +40,6 @@ def __init__(self, *args, **kwargs):
4140
'DefaultRequestHandler',
4241
'GrpcHandler',
4342
'JSONRPCHandler',
44-
'RESTHandler',
4543
'RequestHandler',
4644
'build_error_response',
4745
'prepare_response_object',
Lines changed: 292 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,292 @@
1+
import json
2+
import logging
3+
from collections.abc import AsyncIterator, Awaitable, Callable
4+
from typing import TYPE_CHECKING, Any
5+
6+
from google.protobuf.json_format import MessageToDict, Parse
7+
8+
from a2a.server.context import ServerCallContext
9+
from a2a.server.request_handlers.request_handler import RequestHandler
10+
from a2a.server.routes import CallContextBuilder, DefaultCallContextBuilder
11+
from a2a.types import a2a_pb2
12+
from a2a.types.a2a_pb2 import (
13+
AgentCard,
14+
CancelTaskRequest,
15+
GetTaskPushNotificationConfigRequest,
16+
SubscribeToTaskRequest,
17+
)
18+
from a2a.utils import constants, proto_utils
19+
from a2a.utils.error_handlers import (
20+
rest_error_handler,
21+
rest_stream_error_handler,
22+
)
23+
from a2a.utils.errors import (
24+
ExtendedAgentCardNotConfiguredError,
25+
InvalidRequestError,
26+
TaskNotFoundError,
27+
)
28+
from a2a.utils.helpers import maybe_await, validate, validate_version
29+
from a2a.utils.telemetry import SpanKind, trace_class
30+
31+
32+
if TYPE_CHECKING:
33+
from sse_starlette.sse import EventSourceResponse
34+
from starlette.requests import Request
35+
from starlette.responses import JSONResponse, Response
36+
37+
_package_starlette_installed = True
38+
else:
39+
try:
40+
from sse_starlette.sse import EventSourceResponse
41+
from starlette.requests import Request
42+
from starlette.responses import JSONResponse, Response
43+
44+
_package_starlette_installed = True
45+
except ImportError:
46+
EventSourceResponse = Any
47+
Request = Any
48+
JSONResponse = Any
49+
Response = Any
50+
51+
_package_starlette_installed = False
52+
53+
logger = logging.getLogger(__name__)
54+
55+
@trace_class(kind=SpanKind.SERVER)
56+
class RestDispatcher:
57+
"""Dispatches incoming REST requests to the appropriate handler methods.
58+
59+
Handles context building, routing to RequestHandler directly, and response formatting (JSON/SSE).
60+
"""
61+
62+
def __init__( # noqa: PLR0913
63+
self,
64+
agent_card: AgentCard,
65+
request_handler: RequestHandler,
66+
extended_agent_card: AgentCard | None = None,
67+
context_builder: CallContextBuilder | None = None,
68+
card_modifier: Callable[[AgentCard], Awaitable[AgentCard] | AgentCard]
69+
| None = None,
70+
extended_card_modifier: Callable[
71+
[AgentCard, ServerCallContext], Awaitable[AgentCard] | AgentCard
72+
]
73+
| None = None,
74+
) -> None:
75+
"""Initializes the RestDispatcher.
76+
77+
Args:
78+
agent_card: The AgentCard describing the agent's capabilities.
79+
request_handler: The underlying `RequestHandler` instance to delegate requests to.
80+
extended_agent_card: An optional, distinct AgentCard to be served
81+
at the authenticated extended card endpoint.
82+
context_builder: The CallContextBuilder used to construct the
83+
ServerCallContext passed to the request_handler. If None, no
84+
ServerCallContext is passed.
85+
card_modifier: An optional callback to dynamically modify the public
86+
agent card before it is served.
87+
extended_card_modifier: An optional callback to dynamically modify
88+
the extended agent card before it is served. It receives the
89+
call context.
90+
"""
91+
if not _package_starlette_installed:
92+
raise ImportError(
93+
'Packages `starlette` and `sse-starlette` are required to use the'
94+
' `RestDispatcher`. They can be added as a part of `a2a-sdk` '
95+
'optional dependencies, `a2a-sdk[http-server]`.'
96+
)
97+
98+
self.agent_card = agent_card
99+
self.extended_agent_card = extended_agent_card
100+
self.card_modifier = card_modifier
101+
self.extended_card_modifier = extended_card_modifier
102+
self._context_builder = context_builder or DefaultCallContextBuilder()
103+
self.request_handler = request_handler
104+
105+
def _build_call_context(self, request: Request) -> ServerCallContext:
106+
call_context = self._context_builder.build(request)
107+
if 'tenant' in request.path_params:
108+
call_context.tenant = request.path_params['tenant']
109+
return call_context
110+
111+
@rest_error_handler
112+
@validate_version(constants.PROTOCOL_VERSION_1_0)
113+
async def on_message_send(self, request: Request) -> Response:
114+
"""Handles the 'message/send' REST method."""
115+
context = self._build_call_context(request)
116+
body = await request.body()
117+
params = a2a_pb2.SendMessageRequest()
118+
Parse(body, params)
119+
task_or_message = await self.request_handler.on_message_send(params, context)
120+
if isinstance(task_or_message, a2a_pb2.Task):
121+
response = a2a_pb2.SendMessageResponse(task=task_or_message)
122+
else:
123+
response = a2a_pb2.SendMessageResponse(message=task_or_message)
124+
return JSONResponse(content=MessageToDict(response))
125+
126+
@rest_stream_error_handler
127+
@validate_version(constants.PROTOCOL_VERSION_1_0)
128+
@validate(
129+
lambda self: self.agent_card.capabilities.streaming,
130+
'Streaming is not supported by the agent',
131+
)
132+
async def on_message_send_stream(self, request: Request) -> EventSourceResponse:
133+
"""Handles the 'message/stream' REST method."""
134+
try:
135+
await request.body()
136+
except (ValueError, RuntimeError, OSError) as e:
137+
raise InvalidRequestError(
138+
message=f'Failed to pre-consume request body: {e}'
139+
) from e
140+
141+
context = self._build_call_context(request)
142+
body = await request.body()
143+
params = a2a_pb2.SendMessageRequest()
144+
Parse(body, params)
145+
146+
stream = aiter(self.request_handler.on_message_send_stream(params, context))
147+
try:
148+
first_event = await anext(stream)
149+
except StopAsyncIteration:
150+
return EventSourceResponse(iter([]))
151+
152+
async def event_generator() -> AsyncIterator[str]:
153+
yield json.dumps(MessageToDict(proto_utils.to_stream_response(first_event)))
154+
async for event in stream:
155+
yield json.dumps(MessageToDict(proto_utils.to_stream_response(event)))
156+
157+
return EventSourceResponse(event_generator())
158+
159+
@rest_error_handler
160+
@validate_version(constants.PROTOCOL_VERSION_1_0)
161+
async def on_cancel_task(self, request: Request) -> Response:
162+
"""Handles the 'tasks/cancel' REST method."""
163+
context = self._build_call_context(request)
164+
task_id = request.path_params['id']
165+
task = await self.request_handler.on_cancel_task(CancelTaskRequest(id=task_id), context)
166+
if task:
167+
return JSONResponse(content=MessageToDict(task))
168+
raise TaskNotFoundError
169+
170+
@rest_stream_error_handler
171+
@validate_version(constants.PROTOCOL_VERSION_1_0)
172+
@validate(
173+
lambda self: self.agent_card.capabilities.streaming,
174+
'Streaming is not supported by the agent',
175+
)
176+
async def on_subscribe_to_task(self, request: Request) -> EventSourceResponse:
177+
"""Handles the 'SubscribeToTask' REST method."""
178+
try:
179+
await request.body()
180+
except (ValueError, RuntimeError, OSError) as e:
181+
raise InvalidRequestError(
182+
message=f'Failed to pre-consume request body: {e}'
183+
) from e
184+
185+
context = self._build_call_context(request)
186+
task_id = request.path_params['id']
187+
188+
stream = aiter(self.request_handler.on_subscribe_to_task(SubscribeToTaskRequest(id=task_id), context))
189+
try:
190+
first_event = await anext(stream)
191+
except StopAsyncIteration:
192+
return EventSourceResponse(iter([]))
193+
194+
async def event_generator() -> AsyncIterator[str]:
195+
yield json.dumps(MessageToDict(proto_utils.to_stream_response(first_event)))
196+
async for event in stream:
197+
yield json.dumps(MessageToDict(proto_utils.to_stream_response(event)))
198+
199+
return EventSourceResponse(event_generator())
200+
201+
@rest_error_handler
202+
@validate_version(constants.PROTOCOL_VERSION_1_0)
203+
async def on_get_task(self, request: Request) -> Response:
204+
"""Handles the 'tasks/{id}' REST method."""
205+
context = self._build_call_context(request)
206+
params = a2a_pb2.GetTaskRequest()
207+
proto_utils.parse_params(request.query_params, params)
208+
params.id = request.path_params['id']
209+
task = await self.request_handler.on_get_task(params, context)
210+
if task:
211+
return JSONResponse(content=MessageToDict(task))
212+
raise TaskNotFoundError
213+
214+
@rest_error_handler
215+
@validate_version(constants.PROTOCOL_VERSION_1_0)
216+
async def get_push_notification(self, request: Request) -> Response:
217+
"""Handles the 'tasks/pushNotificationConfig/get' REST method."""
218+
context = self._build_call_context(request)
219+
task_id = request.path_params['id']
220+
push_id = request.path_params['push_id']
221+
params = GetTaskPushNotificationConfigRequest(task_id=task_id, id=push_id)
222+
config = await self.request_handler.on_get_task_push_notification_config(params, context)
223+
return JSONResponse(content=MessageToDict(config))
224+
225+
@rest_error_handler
226+
@validate_version(constants.PROTOCOL_VERSION_1_0)
227+
async def delete_push_notification(self, request: Request) -> Response:
228+
"""Handles the 'tasks/pushNotificationConfig/delete' REST method."""
229+
context = self._build_call_context(request)
230+
task_id = request.path_params['id']
231+
push_id = request.path_params['push_id']
232+
params = a2a_pb2.DeleteTaskPushNotificationConfigRequest(task_id=task_id, id=push_id)
233+
await self.request_handler.on_delete_task_push_notification_config(params, context)
234+
return JSONResponse(content={})
235+
236+
@rest_error_handler
237+
@validate_version(constants.PROTOCOL_VERSION_1_0)
238+
@validate(
239+
lambda self: self.agent_card.capabilities.push_notifications,
240+
'Push notifications are not supported by the agent',
241+
)
242+
async def set_push_notification(self, request: Request) -> Response:
243+
"""Handles the 'tasks/pushNotificationConfig/set' REST method."""
244+
context = self._build_call_context(request)
245+
body = await request.body()
246+
params = a2a_pb2.TaskPushNotificationConfig()
247+
Parse(body, params)
248+
params.task_id = request.path_params['id']
249+
config = await self.request_handler.on_create_task_push_notification_config(params, context)
250+
return JSONResponse(content=MessageToDict(config))
251+
252+
@rest_error_handler
253+
@validate_version(constants.PROTOCOL_VERSION_1_0)
254+
async def list_push_notifications(self, request: Request) -> Response:
255+
"""Handles the 'tasks/pushNotificationConfig/list' REST method."""
256+
context = self._build_call_context(request)
257+
params = a2a_pb2.ListTaskPushNotificationConfigsRequest()
258+
proto_utils.parse_params(request.query_params, params)
259+
params.task_id = request.path_params['id']
260+
result = await self.request_handler.on_list_task_push_notification_configs(params, context)
261+
return JSONResponse(content=MessageToDict(result))
262+
263+
@rest_error_handler
264+
@validate_version(constants.PROTOCOL_VERSION_1_0)
265+
async def list_tasks(self, request: Request) -> Response:
266+
"""Handles the 'tasks/list' REST method."""
267+
context = self._build_call_context(request)
268+
params = a2a_pb2.ListTasksRequest()
269+
proto_utils.parse_params(request.query_params, params)
270+
result = await self.request_handler.on_list_tasks(params, context)
271+
return JSONResponse(content=MessageToDict(result, always_print_fields_with_no_presence=True))
272+
273+
@rest_error_handler
274+
async def handle_authenticated_agent_card(self, request: Request) -> Response:
275+
"""Handles the 'extendedAgentCard' REST method."""
276+
if not self.agent_card.capabilities.extended_agent_card:
277+
raise ExtendedAgentCardNotConfiguredError(
278+
message='Authenticated card not supported'
279+
)
280+
card_to_serve = self.extended_agent_card or self.agent_card
281+
282+
if self.extended_card_modifier:
283+
context = self._build_call_context(request)
284+
card_to_serve = await maybe_await(
285+
self.extended_card_modifier(card_to_serve, context)
286+
)
287+
elif self.card_modifier:
288+
card_to_serve = await maybe_await(self.card_modifier(card_to_serve))
289+
290+
return JSONResponse(
291+
content=MessageToDict(card_to_serve, preserving_proto_field_name=True)
292+
)

0 commit comments

Comments
 (0)