Skip to content

Commit 99bb2c8

Browse files
committed
refactor: address PR review comments
- Fix client factory transport preference selection - Update proto handling to use MergeFrom/CopyFrom correctly - Remove preserving_proto_field_name from MessageToDict - Fix history length limiting logic - Improve SendMessageResponse construction in handler Signed-off-by: Luca Muscariello <muscariello@ieee.org>
1 parent d7fb690 commit 99bb2c8

6 files changed

Lines changed: 40 additions & 43 deletions

File tree

src/a2a/client/base_client.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -80,10 +80,6 @@ async def send_message(
8080

8181
if configuration:
8282
config.MergeFrom(configuration)
83-
# Proto3 doesn't support HasField for scalars, so MergeFrom won't
84-
# override with default values (e.g. False). We explicitly set it here
85-
# assuming configuration is authoritative.
86-
config.blocking = configuration.blocking
8783

8884
send_message_request = SendMessageRequest(
8985
message=request, configuration=config, metadata=request_metadata

src/a2a/client/client_factory.py

Lines changed: 28 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,11 @@
1919
AgentCard,
2020
AgentInterface,
2121
)
22-
23-
24-
TRANSPORT_PROTOCOLS_JSONRPC = 'JSONRPC'
25-
TRANSPORT_PROTOCOLS_GRPC = 'GRPC'
26-
TRANSPORT_PROTOCOLS_HTTP_JSON = 'HTTP+JSON'
22+
from a2a.utils.constants import (
23+
TRANSPORT_GRPC,
24+
TRANSPORT_HTTP_JSON,
25+
TRANSPORT_JSONRPC,
26+
)
2727

2828

2929
try:
@@ -74,9 +74,9 @@ def __init__(
7474

7575
def _register_defaults(self, supported: list[str]) -> None:
7676
# Empty support list implies JSON-RPC only.
77-
if TRANSPORT_PROTOCOLS_JSONRPC in supported or not supported:
77+
if TRANSPORT_JSONRPC in supported or not supported:
7878
self.register(
79-
TRANSPORT_PROTOCOLS_JSONRPC,
79+
TRANSPORT_JSONRPC,
8080
lambda card, url, config, interceptors: JsonRpcTransport(
8181
config.httpx_client or httpx.AsyncClient(),
8282
card,
@@ -85,9 +85,9 @@ def _register_defaults(self, supported: list[str]) -> None:
8585
config.extensions or None,
8686
),
8787
)
88-
if TRANSPORT_PROTOCOLS_HTTP_JSON in supported:
88+
if TRANSPORT_HTTP_JSON in supported:
8989
self.register(
90-
TRANSPORT_PROTOCOLS_HTTP_JSON,
90+
TRANSPORT_HTTP_JSON,
9191
lambda card, url, config, interceptors: RestTransport(
9292
config.httpx_client or httpx.AsyncClient(),
9393
card,
@@ -96,14 +96,14 @@ def _register_defaults(self, supported: list[str]) -> None:
9696
config.extensions or None,
9797
),
9898
)
99-
if TRANSPORT_PROTOCOLS_GRPC in supported:
99+
if TRANSPORT_GRPC in supported:
100100
if GrpcTransport is None:
101101
raise ImportError(
102102
'To use GrpcClient, its dependencies must be installed. '
103103
'You can install them with \'pip install "a2a-sdk[grpc]"\''
104104
)
105105
self.register(
106-
TRANSPORT_PROTOCOLS_GRPC,
106+
TRANSPORT_GRPC,
107107
GrpcTransport.create,
108108
)
109109

@@ -206,25 +206,30 @@ def create(
206206
If there is no valid matching of the client configuration with the
207207
server configuration, a `ValueError` is raised.
208208
"""
209-
server_set = {
210-
x.protocol_binding: x.url for x in card.supported_interfaces
211-
}
212209
client_set = self._config.supported_protocol_bindings or [
213-
TRANSPORT_PROTOCOLS_JSONRPC
210+
TRANSPORT_JSONRPC
214211
]
215212
transport_protocol = None
216213
transport_url = None
217214
if self._config.use_client_preference:
218-
for x in client_set:
219-
if x in server_set:
220-
transport_protocol = x
221-
transport_url = server_set[x]
215+
for protocol_binding in client_set:
216+
supported_interface = next(
217+
(
218+
si
219+
for si in card.supported_interfaces
220+
if si.protocol_binding == protocol_binding
221+
),
222+
None,
223+
)
224+
if supported_interface:
225+
transport_protocol = protocol_binding
226+
transport_url = supported_interface.url
222227
break
223228
else:
224-
for x, url in server_set.items():
225-
if x in client_set:
226-
transport_protocol = x
227-
transport_url = url
229+
for supported_interface in card.supported_interfaces:
230+
if supported_interface.protocol_binding in client_set:
231+
transport_protocol = supported_interface.protocol_binding
232+
transport_url = supported_interface.url
228233
break
229234
if not transport_protocol or not transport_url:
230235
raise ValueError('no compatible transports found.')

src/a2a/server/apps/rest/rest_adapter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ async def handle_get_agent_card(
158158
if self.card_modifier:
159159
card_to_serve = self.card_modifier(card_to_serve)
160160

161-
return MessageToDict(card_to_serve, preserving_proto_field_name=True)
161+
return MessageToDict(card_to_serve)
162162

163163
async def handle_authenticated_agent_card(
164164
self, request: Request, call_context: ServerCallContext | None = None

src/a2a/server/request_handlers/jsonrpc_handler.py

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
ListTaskPushNotificationConfigRequest,
2121
Message,
2222
SendMessageRequest,
23+
SendMessageResponse,
2324
SetTaskPushNotificationConfigRequest,
2425
SubscribeToTaskRequest,
2526
Task,
@@ -114,22 +115,16 @@ async def on_message_send(
114115
request, context
115116
)
116117
# Build result based on return type
118+
response = SendMessageResponse()
117119
if isinstance(task_or_message, Task):
118-
result = {
119-
'task': MessageToDict(
120-
task_or_message, preserving_proto_field_name=False
121-
)
122-
}
120+
response.task.CopyFrom(task_or_message)
123121
elif isinstance(task_or_message, Message):
124-
result = {
125-
'message': MessageToDict(
126-
task_or_message, preserving_proto_field_name=False
127-
)
128-
}
122+
response.message.CopyFrom(task_or_message)
129123
else:
130-
result = MessageToDict(
131-
task_or_message, preserving_proto_field_name=False
132-
)
124+
# Should we handle this fallthrough?
125+
pass
126+
127+
result = MessageToDict(response)
133128
return _build_success_response(request_id, result)
134129
except ServerError as e:
135130
return _build_error_response(

src/a2a/server/tasks/task_manager.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ async def save_task_event(
142142
if task.status.HasField('message'):
143143
task.history.append(task.status.message)
144144
if event.metadata:
145-
task.metadata.update(dict(event.metadata)) # type: ignore[arg-type]
145+
task.metadata.MergeFrom(event.metadata)
146146
task.status.CopyFrom(event.status)
147147
else:
148148
logger.debug('Appending artifact to task %s', task.id)

src/a2a/utils/task.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ def apply_history_length(task: Task, history_length: int | None) -> Task:
9696
task_copy = Task()
9797
task_copy.CopyFrom(task)
9898
# Clear and re-add history items
99-
task_copy.history[:] = limited_history
99+
del task_copy.history[:]
100+
task_copy.history.extend(limited_history)
100101
return task_copy
101102
return task

0 commit comments

Comments
 (0)