Skip to content

Commit 027637d

Browse files
committed
Initial version of A2A version header validation.
1 parent 24f5f1e commit 027637d

16 files changed

Lines changed: 747 additions & 74 deletions

File tree

src/a2a/compat/v0_3/jsonrpc_adapter.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,9 @@
3838
from a2a.server.jsonrpc_models import (
3939
JSONRPCError as CoreJSONRPCError,
4040
)
41+
from a2a.utils import constants
4142
from a2a.utils.errors import ExtendedAgentCardNotConfiguredError
42-
from a2a.utils.helpers import maybe_await
43+
from a2a.utils.helpers import maybe_await, validate_version
4344

4445

4546
logger = logging.getLogger(__name__)
@@ -152,6 +153,7 @@ async def handle_request(
152153
request_id, CoreInternalError(message=str(e))
153154
)
154155

156+
@validate_version(constants.PROTOCOL_VERSION_0_3)
155157
async def _process_non_streaming_request(
156158
self,
157159
request_id: 'str | int | None',
@@ -266,6 +268,7 @@ async def get_authenticated_extended_card(
266268

267269
return conversions.to_compat_agent_card(card_to_serve)
268270

271+
@validate_version(constants.PROTOCOL_VERSION_0_3)
269272
async def _process_streaming_request(
270273
self,
271274
request_id: 'str | int | None',

src/a2a/compat/v0_3/rest_handler.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,12 @@
2828
from a2a.compat.v0_3 import types as types_v03
2929
from a2a.compat.v0_3.request_handler import RequestHandler03
3030
from a2a.server.context import ServerCallContext
31-
from a2a.utils.helpers import validate, validate_async_generator
31+
from a2a.utils import constants
32+
from a2a.utils.helpers import (
33+
validate,
34+
validate_async_generator,
35+
validate_version,
36+
)
3237
from a2a.utils.telemetry import SpanKind, trace_class
3338

3439

@@ -53,6 +58,7 @@ def __init__(
5358
self.agent_card = agent_card
5459
self.handler03 = RequestHandler03(request_handler=request_handler)
5560

61+
@validate_version(constants.PROTOCOL_VERSION_0_3)
5662
async def on_message_send(
5763
self,
5864
request: Request,
@@ -78,6 +84,7 @@ async def on_message_send(
7884
pb2_v03_resp = proto_utils.ToProto.task_or_message(v03_resp)
7985
return MessageToDict(pb2_v03_resp)
8086

87+
@validate_version(constants.PROTOCOL_VERSION_0_3)
8188
@validate_async_generator(
8289
lambda self: self.agent_card.capabilities.streaming,
8390
'Streaming is not supported by the agent',

src/a2a/server/apps/jsonrpc/jsonrpc_app.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -254,7 +254,7 @@ def __init__( # noqa: PLR0913
254254
agent_card=agent_card,
255255
http_handler=http_handler,
256256
extended_agent_card=extended_agent_card,
257-
context_builder=context_builder,
257+
context_builder=self._context_builder,
258258
card_modifier=card_modifier,
259259
extended_card_modifier=extended_card_modifier,
260260
)

src/a2a/server/request_handlers/jsonrpc_handler.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
Task,
3232
TaskPushNotificationConfig,
3333
)
34-
from a2a.utils import proto_utils
34+
from a2a.utils import constants, proto_utils
3535
from a2a.utils.errors import (
3636
JSON_RPC_ERROR_CODE_MAP,
3737
A2AError,
@@ -49,7 +49,12 @@
4949
UnsupportedOperationError,
5050
VersionNotSupportedError,
5151
)
52-
from a2a.utils.helpers import maybe_await, validate, validate_async_generator
52+
from a2a.utils.helpers import (
53+
maybe_await,
54+
validate,
55+
validate_async_generator,
56+
validate_version,
57+
)
5358
from a2a.utils.telemetry import SpanKind, trace_class
5459

5560

@@ -142,6 +147,7 @@ def _get_request_id(
142147
return None
143148
return context.state.get('request_id')
144149

150+
@validate_version(constants.PROTOCOL_VERSION_1_0)
145151
async def on_message_send(
146152
self,
147153
request: SendMessageRequest,
@@ -171,6 +177,11 @@ async def on_message_send(
171177
except A2AError as e:
172178
return _build_error_response(request_id, e)
173179

180+
@validate_async_generator(
181+
lambda self: self.agent_card.capabilities.streaming,
182+
'Streaming is not supported by the agent',
183+
)
184+
@validate_version(constants.PROTOCOL_VERSION_1_0)
174185
@validate_async_generator(
175186
lambda self: self.agent_card.capabilities.streaming,
176187
'Streaming is not supported by the agent',

src/a2a/server/request_handlers/rest_handler.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,13 @@
2727
GetTaskPushNotificationConfigRequest,
2828
SubscribeToTaskRequest,
2929
)
30-
from a2a.utils import proto_utils
30+
from a2a.utils import constants, proto_utils
3131
from a2a.utils.errors import TaskNotFoundError
32-
from a2a.utils.helpers import validate, validate_async_generator
32+
from a2a.utils.helpers import (
33+
validate,
34+
validate_async_generator,
35+
validate_version,
36+
)
3337
from a2a.utils.telemetry import SpanKind, trace_class
3438

3539

@@ -61,6 +65,7 @@ def __init__(
6165
self.agent_card = agent_card
6266
self.request_handler = request_handler
6367

68+
@validate_version(constants.PROTOCOL_VERSION_1_0)
6469
async def on_message_send(
6570
self,
6671
request: Request,
@@ -87,6 +92,7 @@ async def on_message_send(
8792
response = a2a_pb2.SendMessageResponse(message=task_or_message)
8893
return MessageToDict(response)
8994

95+
@validate_version(constants.PROTOCOL_VERSION_1_0)
9096
@validate_async_generator(
9197
lambda self: self.agent_card.capabilities.streaming,
9298
'Streaming is not supported by the agent',

src/a2a/utils/error_handlers.py

Lines changed: 29 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import functools
2+
import inspect
23
import logging
34

45
from collections.abc import Awaitable, Callable, Coroutine
@@ -130,11 +131,8 @@ def rest_stream_error_handler(
130131
) -> Callable[..., Coroutine[Any, Any, Any]]:
131132
"""Decorator to catch A2AError for a streaming method, log it and then rethrow it to be handled by framework."""
132133

133-
@functools.wraps(func)
134-
async def wrapper(*args: Any, **kwargs: Any) -> Any:
135-
try:
136-
return await func(*args, **kwargs)
137-
except A2AError as error:
134+
def _log_error(error: Exception) -> None:
135+
if isinstance(error, A2AError):
138136
log_level = (
139137
logging.ERROR
140138
if isinstance(error, InternalError)
@@ -147,14 +145,33 @@ async def wrapper(*args: Any, **kwargs: Any) -> Any:
147145
getattr(error, 'message', str(error)),
148146
f', Data={error.data}' if error.data else '',
149147
)
150-
# Since the stream has started, we can't return a JSONResponse.
151-
# Instead, we run the error handling logic (provides logging)
152-
# and reraise the error and let server framework manage
153-
raise error
148+
else:
149+
logger.exception('Unknown streaming error occurred')
150+
151+
@functools.wraps(func)
152+
async def wrapper(*args: Any, **kwargs: Any) -> Any:
153+
try:
154+
response = await func(*args, **kwargs)
155+
156+
# If the response has an async generator body (like EventSourceResponse),
157+
# we must wrap it to catch errors that occur during stream execution.
158+
if hasattr(response, 'body_iterator') and inspect.isasyncgen(response.body_iterator):
159+
original_iterator = response.body_iterator
160+
161+
async def error_catching_iterator():
162+
try:
163+
async for item in original_iterator:
164+
yield item
165+
except Exception as stream_error:
166+
_log_error(stream_error)
167+
raise stream_error
168+
169+
response.body_iterator = error_catching_iterator()
170+
171+
return response
172+
154173
except Exception as e:
155-
# Since the stream has started, we can't return a JSONResponse.
156-
# Instead, we run the error handling logic (provides logging)
157-
# and reraise the error and let server framework manage
174+
_log_error(e)
158175
raise e
159176

160177
return wrapper

src/a2a/utils/helpers.py

Lines changed: 121 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,14 @@
55
import json
66
import logging
77

8-
from collections.abc import Awaitable, Callable
9-
from typing import Any, TypeVar
8+
from collections.abc import AsyncIterator, Awaitable, Callable
9+
from typing import Any, TypeVar, cast
1010
from uuid import uuid4
1111

1212
from google.protobuf.json_format import MessageToDict
13+
from packaging.version import InvalidVersion, Version
1314

15+
from a2a.server.context import ServerCallContext
1416
from a2a.types.a2a_pb2 import (
1517
AgentCard,
1618
Artifact,
@@ -21,11 +23,13 @@
2123
TaskState,
2224
TaskStatus,
2325
)
24-
from a2a.utils.errors import UnsupportedOperationError
26+
from a2a.utils import constants
27+
from a2a.utils.errors import UnsupportedOperationError, VersionNotSupportedError
2528
from a2a.utils.telemetry import trace_function
2629

2730

2831
T = TypeVar('T')
32+
F = TypeVar('F', bound=Callable[..., Any])
2933

3034

3135
logger = logging.getLogger(__name__)
@@ -297,7 +301,6 @@ def validate_async_generator(
297301
This decorator is specifically for async generator methods (async def with yield).
298302
The validation happens before the generator starts yielding values.
299303
"""
300-
301304
def decorator(function):
302305
@functools.wraps(function)
303306
async def wrapper(self, *args, **kwargs):
@@ -378,3 +381,117 @@ async def maybe_await(value: T | Awaitable[T]) -> T:
378381
if inspect.isawaitable(value):
379382
return await value
380383
return value
384+
385+
386+
def validate_version(expected_version: str) -> Callable[[F], F]:
387+
"""Decorator that validates the A2A-Version header in the request context.
388+
389+
The header name is defined by `constants.VERSION_HEADER` ('A2A-Version').
390+
If the header is missing or empty, it is interpreted as `constants.PROTOCOL_VERSION_0_3` ('0.3').
391+
If the version in the header does not match the `expected_version` (major and minor parts),
392+
a `VersionNotSupportedError` is raised. Patch version is ignored.
393+
394+
This decorator supports both async methods and async generator methods. It
395+
expects a `ServerCallContext` to be present either in the arguments or
396+
keyword arguments of the decorated method.
397+
398+
Args:
399+
expected_version: The A2A protocol version string expected by the method.
400+
401+
Returns:
402+
The decorated function.
403+
404+
Raises:
405+
VersionNotSupportedError: If the version in the request does not match `expected_version`.
406+
"""
407+
try:
408+
expected_v = Version(expected_version)
409+
except InvalidVersion:
410+
# If the expected version is not a valid semver, we can't do major/minor comparison.
411+
# This shouldn't happen with our constants.
412+
expected_v = None
413+
414+
def decorator(func: F) -> F:
415+
def _get_actual_version(
416+
args: tuple[Any, ...], kwargs: dict[str, Any]
417+
) -> str:
418+
context = kwargs.get('context')
419+
if context is None:
420+
for arg in args:
421+
if isinstance(arg, ServerCallContext):
422+
context = arg
423+
break
424+
425+
if context is None:
426+
# If no context is found, we can't validate the version.
427+
# In a real scenario, this shouldn't happen for properly routed requests.
428+
# We default to the expected version to allow test call to proceed.
429+
return expected_version
430+
431+
headers = context.state.get('headers', {})
432+
# Header names are usually case-insensitive in most frameworks, but dict lookup is case-sensitive.
433+
# We check both standard and lowercase versions.
434+
actual_version = headers.get(
435+
constants.VERSION_HEADER
436+
) or headers.get(constants.VERSION_HEADER.lower())
437+
438+
if not actual_version:
439+
return constants.PROTOCOL_VERSION_0_3
440+
441+
return str(actual_version)
442+
443+
def _is_version_compatible(actual: str) -> bool:
444+
if actual == expected_version:
445+
return True
446+
if not expected_v:
447+
return False
448+
try:
449+
actual_v = Version(actual)
450+
except InvalidVersion:
451+
return False
452+
else:
453+
return (
454+
actual_v.major == expected_v.major
455+
and actual_v.minor == expected_v.minor
456+
)
457+
458+
if inspect.isasyncgenfunction(inspect.unwrap(func)):
459+
460+
@functools.wraps(func)
461+
async def async_gen_wrapper(
462+
self: Any, *args: Any, **kwargs: Any
463+
) -> AsyncIterator[Any]:
464+
actual_version = _get_actual_version(args, kwargs)
465+
if not _is_version_compatible(actual_version):
466+
logger.warning(
467+
"Version mismatch: actual='%s', expected='%s'",
468+
actual_version,
469+
expected_version,
470+
)
471+
raise VersionNotSupportedError(
472+
message=f"A2A version '{actual_version}' is not supported by this handler. "
473+
f"Expected version '{expected_version}'."
474+
)
475+
async for item in func(self, *args, **kwargs):
476+
yield item
477+
478+
return cast('F', async_gen_wrapper)
479+
480+
@functools.wraps(func)
481+
async def async_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
482+
actual_version = _get_actual_version(args, kwargs)
483+
if not _is_version_compatible(actual_version):
484+
logger.warning(
485+
"Version mismatch: actual='%s', expected='%s'",
486+
actual_version,
487+
expected_version,
488+
)
489+
raise VersionNotSupportedError(
490+
message=f"A2A version '{actual_version}' is not supported by this handler. "
491+
f"Expected version '{expected_version}'."
492+
)
493+
return await func(self, *args, **kwargs)
494+
495+
return cast('F', async_wrapper)
496+
497+
return decorator

tests/compat/v0_3/test_rest_handler.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,9 @@ def rest_handler(agent_card, mock_core_handler):
3737

3838
@pytest.fixture
3939
def mock_context():
40-
return MagicMock(spec=ServerCallContext)
40+
m = MagicMock(spec=ServerCallContext)
41+
m.state = {'headers': {'A2A-Version': '0.3'}}
42+
return m
4143

4244

4345
@pytest.fixture

0 commit comments

Comments
 (0)