Skip to content

Commit e1d0e7a

Browse files
authored
fix: update with_a2a_extensions to append instead of overwriting (#985)
Existing extensions are kept, enables better modularity of service parameters updates by (for instance) multiple interceptors.
1 parent 934b595 commit e1d0e7a

2 files changed

Lines changed: 66 additions & 9 deletions

File tree

src/a2a/client/service_parameters.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
from collections.abc import Callable
22
from typing import TypeAlias
33

4-
from a2a.extensions.common import HTTP_EXTENSION_HEADER
4+
from a2a.extensions.common import (
5+
HTTP_EXTENSION_HEADER,
6+
get_requested_extensions,
7+
)
58

69

710
ServiceParameters: TypeAlias = dict[str, str]
@@ -44,17 +47,18 @@ def create_from(
4447

4548

4649
def with_a2a_extensions(extensions: list[str]) -> ServiceParametersUpdate:
47-
"""Create a ServiceParametersUpdate that adds A2A extensions.
50+
"""Create a ServiceParametersUpdate that merges A2A extension URIs.
4851
49-
Args:
50-
extensions: List of extension strings.
51-
52-
Returns:
53-
A function that updates ServiceParameters with the extensions header.
52+
Unions the supplied URIs with any already present in the A2A-Extensions
53+
parameter, deduplicating and emitting them in sorted order. Repeated
54+
calls accumulate rather than overwrite.
5455
"""
5556

5657
def update(parameters: ServiceParameters) -> None:
57-
if extensions:
58-
parameters[HTTP_EXTENSION_HEADER] = ','.join(extensions)
58+
if not extensions:
59+
return
60+
existing = parameters.get(HTTP_EXTENSION_HEADER, '')
61+
merged = sorted(get_requested_extensions([existing, *extensions]))
62+
parameters[HTTP_EXTENSION_HEADER] = ','.join(merged)
5963

6064
return update
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
"""Tests for a2a.client.service_parameters module."""
2+
3+
from a2a.client.service_parameters import (
4+
ServiceParametersFactory,
5+
with_a2a_extensions,
6+
)
7+
from a2a.extensions.common import HTTP_EXTENSION_HEADER
8+
9+
10+
def test_with_a2a_extensions_merges_dedupes_and_sorts():
11+
"""Repeated calls accumulate; duplicates collapse; output is sorted."""
12+
parameters = ServiceParametersFactory.create(
13+
[
14+
with_a2a_extensions(['ext-c', 'ext-a']),
15+
with_a2a_extensions(['ext-b', 'ext-a']),
16+
]
17+
)
18+
19+
assert parameters[HTTP_EXTENSION_HEADER] == 'ext-a,ext-b,ext-c'
20+
21+
22+
def test_with_a2a_extensions_merges_existing_header_value():
23+
"""Pre-existing comma-separated header values are parsed and merged."""
24+
parameters = ServiceParametersFactory.create_from(
25+
{HTTP_EXTENSION_HEADER: 'ext-a, ext-b'},
26+
[with_a2a_extensions(['ext-c'])],
27+
)
28+
29+
assert parameters[HTTP_EXTENSION_HEADER] == 'ext-a,ext-b,ext-c'
30+
31+
32+
def test_with_a2a_extensions_empty_is_noop():
33+
"""An empty extensions list leaves the header untouched / absent."""
34+
parameters = ServiceParametersFactory.create(
35+
[
36+
with_a2a_extensions(['ext-a']),
37+
with_a2a_extensions([]),
38+
]
39+
)
40+
41+
assert parameters[HTTP_EXTENSION_HEADER] == 'ext-a'
42+
assert HTTP_EXTENSION_HEADER not in ServiceParametersFactory.create(
43+
[with_a2a_extensions([])]
44+
)
45+
46+
47+
def test_with_a2a_extensions_normalizes_input_strings():
48+
"""Input strings are split on commas and stripped, like header values."""
49+
parameters = ServiceParametersFactory.create(
50+
[with_a2a_extensions(['ext-a, ext-b', ' ext-c '])]
51+
)
52+
53+
assert parameters[HTTP_EXTENSION_HEADER] == 'ext-a,ext-b,ext-c'

0 commit comments

Comments
 (0)