Skip to content

Commit ba6deaf

Browse files
committed
WIP
1 parent be01a4e commit ba6deaf

3 files changed

Lines changed: 34 additions & 13 deletions

File tree

src/a2a/server/request_handlers/__init__.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,10 @@
66
DefaultRequestHandler,
77
)
88
from a2a.server.request_handlers.jsonrpc_handler import JSONRPCHandler
9-
from a2a.server.request_handlers.request_handler import RequestHandler
9+
from a2a.server.request_handlers.request_handler import (
10+
RequestHandler,
11+
validate_request_params,
12+
)
1013
from a2a.server.request_handlers.response_helpers import (
1114
build_error_response,
1215
prepare_response_object,
@@ -45,4 +48,5 @@ def __init__(self, *args, **kwargs):
4548
'RequestHandler',
4649
'build_error_response',
4750
'prepare_response_object',
51+
'validate_request_params',
4852
]

src/a2a/server/request_handlers/default_request_handler.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
)
2121
from a2a.server.request_handlers.request_handler import (
2222
RequestHandler,
23+
validate_request_params,
2324
)
2425
from a2a.server.tasks import (
2526
PushNotificationConfigStore,
@@ -58,7 +59,6 @@
5859
validate_page_size,
5960
)
6061
from a2a.utils.telemetry import SpanKind, trace_class
61-
from a2a.utils.proto_utils import validate_proto_required_fields
6262

6363

6464
logger = logging.getLogger(__name__)
@@ -121,13 +121,13 @@ def __init__( # noqa: PLR0913
121121
# asyncio tasks and to surface unexpected exceptions.
122122
self._background_tasks = set()
123123

124+
@validate_request_params
124125
async def on_get_task(
125126
self,
126127
params: GetTaskRequest,
127128
context: ServerCallContext,
128129
) -> Task | None:
129130
"""Default handler for 'tasks/get'."""
130-
validate_proto_required_fields(params)
131131
validate_history_length(params)
132132

133133
task_id = params.id
@@ -137,13 +137,13 @@ async def on_get_task(
137137

138138
return apply_history_length(task, params)
139139

140+
@validate_request_params
140141
async def on_list_tasks(
141142
self,
142143
params: ListTasksRequest,
143144
context: ServerCallContext,
144145
) -> ListTasksResponse:
145146
"""Default handler for 'tasks/list'."""
146-
validate_proto_required_fields(params)
147147
validate_history_length(params)
148148
if params.HasField('page_size'):
149149
validate_page_size(params.page_size)
@@ -159,6 +159,7 @@ async def on_list_tasks(
159159

160160
return page
161161

162+
@validate_request_params
162163
async def on_cancel_task(
163164
self,
164165
params: CancelTaskRequest,
@@ -168,7 +169,6 @@ async def on_cancel_task(
168169
169170
Attempts to cancel the task managed by the `AgentExecutor`.
170171
"""
171-
validate_proto_required_fields(params)
172172
task_id = params.id
173173
task: Task | None = await self.task_store.get(task_id, context)
174174
if not task:
@@ -323,6 +323,7 @@ async def _send_push_notification_if_needed(
323323
):
324324
await self._push_sender.send_notification(task_id, event)
325325

326+
@validate_request_params
326327
async def on_message_send(
327328
self,
328329
params: SendMessageRequest,
@@ -333,7 +334,6 @@ async def on_message_send(
333334
Starts the agent execution for the message and waits for the final
334335
result (Task or Message).
335336
"""
336-
validate_proto_required_fields(params)
337337
validate_history_length(params.configuration)
338338

339339
(
@@ -393,6 +393,7 @@ async def push_notification_callback(event: Event) -> None:
393393

394394
return result
395395

396+
@validate_request_params
396397
async def on_message_send_stream(
397398
self,
398399
params: SendMessageRequest,
@@ -403,7 +404,6 @@ async def on_message_send_stream(
403404
Starts the agent execution and yields events as they are produced
404405
by the agent.
405406
"""
406-
validate_proto_required_fields(params)
407407
(
408408
_task_manager,
409409
task_id,
@@ -482,6 +482,7 @@ async def _cleanup_producer(
482482
async with self._running_agents_lock:
483483
self._running_agents.pop(task_id, None)
484484

485+
@validate_request_params
485486
async def on_create_task_push_notification_config(
486487
self,
487488
params: TaskPushNotificationConfig,
@@ -491,7 +492,6 @@ async def on_create_task_push_notification_config(
491492
492493
Requires a `PushNotifier` to be configured.
493494
"""
494-
validate_proto_required_fields(params)
495495
if not self._push_config_store:
496496
raise UnsupportedOperationError
497497

@@ -508,6 +508,7 @@ async def on_create_task_push_notification_config(
508508

509509
return params
510510

511+
@validate_request_params
511512
async def on_get_task_push_notification_config(
512513
self,
513514
params: GetTaskPushNotificationConfigRequest,
@@ -517,7 +518,6 @@ async def on_get_task_push_notification_config(
517518
518519
Requires a `PushConfigStore` to be configured.
519520
"""
520-
validate_proto_required_fields(params)
521521
if not self._push_config_store:
522522
raise UnsupportedOperationError
523523

@@ -540,6 +540,7 @@ async def on_get_task_push_notification_config(
540540

541541
raise InternalError(message='Push notification config not found')
542542

543+
@validate_request_params
543544
async def on_subscribe_to_task(
544545
self,
545546
params: SubscribeToTaskRequest,
@@ -550,7 +551,6 @@ async def on_subscribe_to_task(
550551
Allows a client to re-attach to a running streaming task's event stream.
551552
Requires the task and its queue to still be active.
552553
"""
553-
validate_proto_required_fields(params)
554554
task_id = params.id
555555
task: Task | None = await self.task_store.get(task_id, context)
556556
if not task:
@@ -583,6 +583,7 @@ async def on_subscribe_to_task(
583583
async for event in result_aggregator.consume_and_emit(consumer):
584584
yield event
585585

586+
@validate_request_params
586587
async def on_list_task_push_notification_configs(
587588
self,
588589
params: ListTaskPushNotificationConfigsRequest,
@@ -592,7 +593,6 @@ async def on_list_task_push_notification_configs(
592593
593594
Requires a `PushConfigStore` to be configured.
594595
"""
595-
validate_proto_required_fields(params)
596596
if not self._push_config_store:
597597
raise UnsupportedOperationError
598598

@@ -609,6 +609,7 @@ async def on_list_task_push_notification_configs(
609609
configs=push_notification_config_list
610610
)
611611

612+
@validate_request_params
612613
async def on_delete_task_push_notification_config(
613614
self,
614615
params: DeleteTaskPushNotificationConfigRequest,
@@ -618,7 +619,6 @@ async def on_delete_task_push_notification_config(
618619
619620
Requires a `PushConfigStore` to be configured.
620621
"""
621-
validate_proto_required_fields(params)
622622
if not self._push_config_store:
623623
raise UnsupportedOperationError
624624

src/a2a/server/request_handlers/request_handler.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -227,8 +227,25 @@ async def on_delete_task_push_notification_config(
227227
"""
228228

229229

230-
def _validate_request_params(method: Callable) -> Callable:
230+
def validate_request_params(method: Callable) -> Callable:
231231
"""Decorator for RequestHandler methods to validate required fields on incoming requests."""
232+
if inspect.isasyncgenfunction(method):
233+
234+
@functools.wraps(method)
235+
async def async_gen_wrapper(
236+
self: RequestHandler,
237+
params: ProtoMessage,
238+
context: ServerCallContext,
239+
*args: Any,
240+
**kwargs: Any,
241+
) -> Any:
242+
if params is not None:
243+
validate_proto_required_fields(params)
244+
async for item in method(self, params, context, *args, **kwargs):
245+
yield item
246+
247+
return async_gen_wrapper
248+
232249
if inspect.iscoroutinefunction(method):
233250

234251
@functools.wraps(method)

0 commit comments

Comments
 (0)