Skip to content

Commit 9562522

Browse files
committed
feat(server): validate fields presence according to google.api.field_behavior annotations
1 parent 24f5f1e commit 9562522

5 files changed

Lines changed: 200 additions & 20 deletions

File tree

src/a2a/server/request_handlers/default_request_handler.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,10 @@
1818
InMemoryQueueManager,
1919
QueueManager,
2020
)
21-
from a2a.server.request_handlers.request_handler import RequestHandler
21+
from a2a.server.request_handlers.request_handler import (
22+
RequestHandler,
23+
validate_request_params,
24+
)
2225
from a2a.server.tasks import (
2326
PushNotificationConfigStore,
2427
PushNotificationEvent,
@@ -118,6 +121,7 @@ def __init__( # noqa: PLR0913
118121
# asyncio tasks and to surface unexpected exceptions.
119122
self._background_tasks = set()
120123

124+
@validate_request_params
121125
async def on_get_task(
122126
self,
123127
params: GetTaskRequest,
@@ -133,6 +137,7 @@ async def on_get_task(
133137

134138
return apply_history_length(task, params)
135139

140+
@validate_request_params
136141
async def on_list_tasks(
137142
self,
138143
params: ListTasksRequest,
@@ -154,6 +159,7 @@ async def on_list_tasks(
154159

155160
return page
156161

162+
@validate_request_params
157163
async def on_cancel_task(
158164
self,
159165
params: CancelTaskRequest,
@@ -317,6 +323,7 @@ async def _send_push_notification_if_needed(
317323
):
318324
await self._push_sender.send_notification(task_id, event)
319325

326+
@validate_request_params
320327
async def on_message_send(
321328
self,
322329
params: SendMessageRequest,
@@ -386,6 +393,7 @@ async def push_notification_callback(event: Event) -> None:
386393

387394
return result
388395

396+
@validate_request_params
389397
async def on_message_send_stream(
390398
self,
391399
params: SendMessageRequest,
@@ -474,6 +482,7 @@ async def _cleanup_producer(
474482
async with self._running_agents_lock:
475483
self._running_agents.pop(task_id, None)
476484

485+
@validate_request_params
477486
async def on_create_task_push_notification_config(
478487
self,
479488
params: TaskPushNotificationConfig,
@@ -499,6 +508,7 @@ async def on_create_task_push_notification_config(
499508

500509
return params
501510

511+
@validate_request_params
502512
async def on_get_task_push_notification_config(
503513
self,
504514
params: GetTaskPushNotificationConfigRequest,
@@ -530,6 +540,7 @@ async def on_get_task_push_notification_config(
530540

531541
raise InternalError(message='Push notification config not found')
532542

543+
@validate_request_params
533544
async def on_subscribe_to_task(
534545
self,
535546
params: SubscribeToTaskRequest,
@@ -572,6 +583,7 @@ async def on_subscribe_to_task(
572583
async for event in result_aggregator.consume_and_emit(consumer):
573584
yield event
574585

586+
@validate_request_params
575587
async def on_list_task_push_notification_configs(
576588
self,
577589
params: ListTaskPushNotificationConfigsRequest,
@@ -597,6 +609,7 @@ async def on_list_task_push_notification_configs(
597609
configs=push_notification_config_list
598610
)
599611

612+
@validate_request_params
600613
async def on_delete_task_push_notification_config(
601614
self,
602615
params: DeleteTaskPushNotificationConfigRequest,

src/a2a/server/request_handlers/request_handler.py

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,11 @@
1+
import functools
2+
import inspect
3+
14
from abc import ABC, abstractmethod
2-
from collections.abc import AsyncGenerator
5+
from collections.abc import AsyncGenerator, Callable
6+
from typing import Any
7+
8+
from google.protobuf.message import Message as ProtoMessage
39

410
from a2a.server.context import ServerCallContext
511
from a2a.server.events.event_queue import Event
@@ -19,6 +25,7 @@
1925
TaskPushNotificationConfig,
2026
)
2127
from a2a.utils.errors import UnsupportedOperationError
28+
from a2a.utils.proto_utils import validate_proto_required_fields
2229

2330

2431
class RequestHandler(ABC):
@@ -218,3 +225,37 @@ async def on_delete_task_push_notification_config(
218225
Returns:
219226
None
220227
"""
228+
229+
230+
def validate_request_params(method: Callable) -> Callable:
231+
"""Decorator for RequestHandler methods to validate required fields on incoming requests."""
232+
if inspect.isasyncgenfunction(method):
233+
234+
@functools.wraps(method)
235+
async def async_generator_wrapper(
236+
self: RequestHandler,
237+
params: ProtoMessage,
238+
context: ServerCallContext,
239+
*args: Any,
240+
**kwargs: Any,
241+
) -> AsyncGenerator:
242+
if params is not None:
243+
validate_proto_required_fields(params)
244+
async for item in method(self, params, context, *args, **kwargs):
245+
yield item
246+
247+
return async_generator_wrapper
248+
249+
@functools.wraps(method)
250+
async def async_wrapper(
251+
self: RequestHandler,
252+
params: ProtoMessage,
253+
context: ServerCallContext,
254+
*args: Any,
255+
**kwargs: Any,
256+
) -> Any:
257+
if params is not None:
258+
validate_proto_required_fields(params)
259+
return await method(self, params, context, *args, **kwargs)
260+
261+
return async_wrapper

src/a2a/utils/errors.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,10 @@ class A2AError(Exception):
2121
message: str = 'A2A Error'
2222
data: dict | None = None
2323

24-
def __init__(self, message: str | None = None):
24+
def __init__(self, message: str | None = None, data: dict | None = None):
2525
if message:
2626
self.message = message
27+
self.data = data
2728
super().__init__(self.message)
2829

2930

src/a2a/utils/proto_utils.py

Lines changed: 108 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,15 @@
1717
This module provides helper functions for common proto type operations.
1818
"""
1919

20-
from typing import TYPE_CHECKING, Any
20+
from typing import TYPE_CHECKING, Any, TypedDict
2121

22+
from google.api.field_behavior_pb2 import FieldBehavior, field_behavior
23+
from google.protobuf.descriptor import FieldDescriptor
2224
from google.protobuf.json_format import ParseDict
2325
from google.protobuf.message import Message as ProtobufMessage
2426

27+
from a2a.utils.errors import InvalidParamsError
28+
2529

2630
if TYPE_CHECKING:
2731
from starlette.datastructures import QueryParams
@@ -189,3 +193,106 @@ def parse_params(params: QueryParams, message: ProtobufMessage) -> None:
189193
processed[k] = parsed_val
190194

191195
ParseDict(processed, message, ignore_unknown_fields=True)
196+
197+
198+
class ValidationDetail(TypedDict):
199+
"""Structured validation error detail."""
200+
201+
field: str
202+
message: str
203+
204+
205+
def _check_required_field_violation(
206+
msg: ProtobufMessage, field: FieldDescriptor
207+
) -> ValidationDetail | None:
208+
"""Check if a required field is missing or invalid."""
209+
val = getattr(msg, field.name)
210+
if field.is_repeated:
211+
if not val:
212+
return ValidationDetail(
213+
field=field.name,
214+
message='Field must contain at least one element.',
215+
)
216+
elif field.has_presence:
217+
if not msg.HasField(field.name):
218+
return ValidationDetail(
219+
field=field.name, message='Field is required.'
220+
)
221+
elif val == field.default_value:
222+
return ValidationDetail(field=field.name, message='Field is required.')
223+
return None
224+
225+
226+
def _append_nested_errors(
227+
errors: list[ValidationDetail],
228+
prefix: str,
229+
sub_errs: list[ValidationDetail],
230+
) -> None:
231+
"""Format nested validation errors and append to errors list."""
232+
for sub in sub_errs:
233+
sub_field = sub['field']
234+
errors.append(
235+
ValidationDetail(
236+
field=f'{prefix}.{sub_field}' if sub_field else prefix,
237+
message=sub['message'],
238+
)
239+
)
240+
241+
242+
def _recurse_validation(
243+
msg: ProtobufMessage, field: FieldDescriptor
244+
) -> list[ValidationDetail]:
245+
"""Recurse validation for nested messages and map fields."""
246+
errors: list[ValidationDetail] = []
247+
if field.type != FieldDescriptor.TYPE_MESSAGE:
248+
return errors
249+
250+
val = getattr(msg, field.name)
251+
if not field.is_repeated:
252+
if msg.HasField(field.name):
253+
sub_errs = _validate_proto_required_fields_internal(val)
254+
_append_nested_errors(errors, field.name, sub_errs)
255+
elif field.message_type.GetOptions().map_entry:
256+
for k, v in val.items():
257+
if isinstance(v, ProtobufMessage):
258+
sub_errs = _validate_proto_required_fields_internal(v)
259+
_append_nested_errors(errors, f'{field.name}[{k}]', sub_errs)
260+
else:
261+
for i, item in enumerate(val):
262+
sub_errs = _validate_proto_required_fields_internal(item)
263+
_append_nested_errors(errors, f'{field.name}[{i}]', sub_errs)
264+
return errors
265+
266+
267+
def _validate_proto_required_fields_internal(
268+
msg: ProtobufMessage,
269+
) -> list[ValidationDetail]:
270+
"""Internal validation that returns a list of error dictionaries."""
271+
desc = msg.DESCRIPTOR
272+
errors: list[ValidationDetail] = []
273+
274+
for field in desc.fields:
275+
options = field.GetOptions()
276+
if FieldBehavior.REQUIRED in options.Extensions[field_behavior]:
277+
violation = _check_required_field_violation(msg, field)
278+
if violation:
279+
errors.append(violation)
280+
errors.extend(_recurse_validation(msg, field))
281+
return errors
282+
283+
284+
def validate_proto_required_fields(msg: ProtobufMessage) -> None:
285+
"""Validate that all fields marked as REQUIRED are present on the proto message.
286+
287+
Args:
288+
msg: The Protobuf message to validate.
289+
290+
Raises:
291+
InvalidParamsError: If a required field is missing or empty.
292+
"""
293+
errors = _validate_proto_required_fields_internal(msg)
294+
295+
if errors:
296+
raise InvalidParamsError(
297+
message='Validation failed', data={'errors': errors}
298+
)

0 commit comments

Comments
 (0)