|
5 | 5 | import json |
6 | 6 | import logging |
7 | 7 |
|
8 | | -from collections.abc import Awaitable, Callable |
9 | | -from typing import Any, TypeVar |
| 8 | +from collections.abc import AsyncIterator, Awaitable, Callable |
| 9 | +from typing import Any, TypeVar, cast |
10 | 10 | from uuid import uuid4 |
11 | 11 |
|
12 | 12 | from google.protobuf.json_format import MessageToDict |
| 13 | +from packaging.version import InvalidVersion, Version |
13 | 14 |
|
| 15 | +from a2a.server.context import ServerCallContext |
14 | 16 | from a2a.types.a2a_pb2 import ( |
15 | 17 | AgentCard, |
16 | 18 | Artifact, |
|
21 | 23 | TaskState, |
22 | 24 | TaskStatus, |
23 | 25 | ) |
24 | | -from a2a.utils.errors import UnsupportedOperationError |
| 26 | +from a2a.utils import constants |
| 27 | +from a2a.utils.errors import UnsupportedOperationError, VersionNotSupportedError |
25 | 28 | from a2a.utils.telemetry import trace_function |
26 | 29 |
|
27 | 30 |
|
28 | 31 | T = TypeVar('T') |
| 32 | +F = TypeVar('F', bound=Callable[..., Any]) |
29 | 33 |
|
30 | 34 |
|
31 | 35 | logger = logging.getLogger(__name__) |
@@ -297,7 +301,6 @@ def validate_async_generator( |
297 | 301 | This decorator is specifically for async generator methods (async def with yield). |
298 | 302 | The validation happens before the generator starts yielding values. |
299 | 303 | """ |
300 | | - |
301 | 304 | def decorator(function): |
302 | 305 | @functools.wraps(function) |
303 | 306 | async def wrapper(self, *args, **kwargs): |
@@ -378,3 +381,117 @@ async def maybe_await(value: T | Awaitable[T]) -> T: |
378 | 381 | if inspect.isawaitable(value): |
379 | 382 | return await value |
380 | 383 | return value |
| 384 | + |
| 385 | + |
| 386 | +def validate_version(expected_version: str) -> Callable[[F], F]: |
| 387 | + """Decorator that validates the A2A-Version header in the request context. |
| 388 | +
|
| 389 | + The header name is defined by `constants.VERSION_HEADER` ('A2A-Version'). |
| 390 | + If the header is missing or empty, it is interpreted as `constants.PROTOCOL_VERSION_0_3` ('0.3'). |
| 391 | + If the version in the header does not match the `expected_version` (major and minor parts), |
| 392 | + a `VersionNotSupportedError` is raised. Patch version is ignored. |
| 393 | +
|
| 394 | + This decorator supports both async methods and async generator methods. It |
| 395 | + expects a `ServerCallContext` to be present either in the arguments or |
| 396 | + keyword arguments of the decorated method. |
| 397 | +
|
| 398 | + Args: |
| 399 | + expected_version: The A2A protocol version string expected by the method. |
| 400 | +
|
| 401 | + Returns: |
| 402 | + The decorated function. |
| 403 | +
|
| 404 | + Raises: |
| 405 | + VersionNotSupportedError: If the version in the request does not match `expected_version`. |
| 406 | + """ |
| 407 | + try: |
| 408 | + expected_v = Version(expected_version) |
| 409 | + except InvalidVersion: |
| 410 | + # If the expected version is not a valid semver, we can't do major/minor comparison. |
| 411 | + # This shouldn't happen with our constants. |
| 412 | + expected_v = None |
| 413 | + |
| 414 | + def decorator(func: F) -> F: |
| 415 | + def _get_actual_version( |
| 416 | + args: tuple[Any, ...], kwargs: dict[str, Any] |
| 417 | + ) -> str: |
| 418 | + context = kwargs.get('context') |
| 419 | + if context is None: |
| 420 | + for arg in args: |
| 421 | + if isinstance(arg, ServerCallContext): |
| 422 | + context = arg |
| 423 | + break |
| 424 | + |
| 425 | + if context is None: |
| 426 | + # If no context is found, we can't validate the version. |
| 427 | + # In a real scenario, this shouldn't happen for properly routed requests. |
| 428 | + # We default to the expected version to allow test call to proceed. |
| 429 | + return expected_version |
| 430 | + |
| 431 | + headers = context.state.get('headers', {}) |
| 432 | + # Header names are usually case-insensitive in most frameworks, but dict lookup is case-sensitive. |
| 433 | + # We check both standard and lowercase versions. |
| 434 | + actual_version = headers.get( |
| 435 | + constants.VERSION_HEADER |
| 436 | + ) or headers.get(constants.VERSION_HEADER.lower()) |
| 437 | + |
| 438 | + if not actual_version: |
| 439 | + return constants.PROTOCOL_VERSION_0_3 |
| 440 | + |
| 441 | + return str(actual_version) |
| 442 | + |
| 443 | + def _is_version_compatible(actual: str) -> bool: |
| 444 | + if actual == expected_version: |
| 445 | + return True |
| 446 | + if not expected_v: |
| 447 | + return False |
| 448 | + try: |
| 449 | + actual_v = Version(actual) |
| 450 | + except InvalidVersion: |
| 451 | + return False |
| 452 | + else: |
| 453 | + return ( |
| 454 | + actual_v.major == expected_v.major |
| 455 | + and actual_v.minor == expected_v.minor |
| 456 | + ) |
| 457 | + |
| 458 | + if inspect.isasyncgenfunction(inspect.unwrap(func)): |
| 459 | + |
| 460 | + @functools.wraps(func) |
| 461 | + async def async_gen_wrapper( |
| 462 | + self: Any, *args: Any, **kwargs: Any |
| 463 | + ) -> AsyncIterator[Any]: |
| 464 | + actual_version = _get_actual_version(args, kwargs) |
| 465 | + if not _is_version_compatible(actual_version): |
| 466 | + logger.warning( |
| 467 | + "Version mismatch: actual='%s', expected='%s'", |
| 468 | + actual_version, |
| 469 | + expected_version, |
| 470 | + ) |
| 471 | + raise VersionNotSupportedError( |
| 472 | + message=f"A2A version '{actual_version}' is not supported by this handler. " |
| 473 | + f"Expected version '{expected_version}'." |
| 474 | + ) |
| 475 | + async for item in func(self, *args, **kwargs): |
| 476 | + yield item |
| 477 | + |
| 478 | + return cast('F', async_gen_wrapper) |
| 479 | + |
| 480 | + @functools.wraps(func) |
| 481 | + async def async_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: |
| 482 | + actual_version = _get_actual_version(args, kwargs) |
| 483 | + if not _is_version_compatible(actual_version): |
| 484 | + logger.warning( |
| 485 | + "Version mismatch: actual='%s', expected='%s'", |
| 486 | + actual_version, |
| 487 | + expected_version, |
| 488 | + ) |
| 489 | + raise VersionNotSupportedError( |
| 490 | + message=f"A2A version '{actual_version}' is not supported by this handler. " |
| 491 | + f"Expected version '{expected_version}'." |
| 492 | + ) |
| 493 | + return await func(self, *args, **kwargs) |
| 494 | + |
| 495 | + return cast('F', async_wrapper) |
| 496 | + |
| 497 | + return decorator |
0 commit comments