11import argparse
22import asyncio
33import logging
4+ from typing import cast
45from uuid import uuid4
56
67import httpx
1516 Message ,
1617 Part ,
1718 Role ,
19+ TaskArtifactUpdateEvent ,
1820 TextPart ,
1921)
2022from a2a .utils .constants import (
3032
3133BASE_URL = "http://localhost:9999"
3234
35+ # Two echo agent instances to multicast to
36+ AGENT_NAMES = [
37+ "agntcy/demo/echo_agent1" ,
38+ "agntcy/demo/echo_agent2" ,
39+ ]
40+
3341logger = logging .getLogger (__name__ )
3442
3543
@@ -66,7 +74,7 @@ async def main() -> None:
6674
6775 client_config = ClientConfig (
6876 supported_transports = ["JSONRPC" , "slimrpc" ],
69- streaming = args . stream ,
77+ streaming = True ,
7078 httpx_client = httpx_client ,
7179 slimrpc_channel_factory = slimrpc_channel_factory (slim_local_app , conn_id ),
7280 )
@@ -78,7 +86,7 @@ async def main() -> None:
7886 agent_card : AgentCard
7987 match args .type :
8088 case "slimrpc" :
81- agent_card = minimal_agent_card ("agntcy/demo/echo_agent" , ["slimrpc" ])
89+ agent_card = minimal_agent_card ("," . join ( AGENT_NAMES ) , ["slimrpc" ])
8290 case "starlette" :
8391 agent_card = await fetch_agent_card (
8492 resolver = A2ACardResolver (
@@ -90,6 +98,19 @@ async def main() -> None:
9098 raise ValueError (f"Invalid client type: { args .type } " )
9199
92100 client = client_factory .create (card = agent_card )
101+
102+ if args .type == "slimrpc" :
103+ # Fetch agent cards from all servers in the group.
104+ transport = cast (SRPCTransport , client ._transport ) # type: ignore[attr-defined]
105+ async for card in transport .get_all_cards ():
106+ logger .info (
107+ f"agent card: { card .model_dump_json (indent = 2 , exclude_none = True )} "
108+ )
109+
110+ # Fetch the real card from the first server so that
111+ # client._card.capabilities.streaming=True before send_message.
112+ await client .get_card ()
113+
93114 logger .info ("A2AClient initialized." )
94115
95116 response_text = await send_message (client , args .text )
@@ -106,12 +127,6 @@ def parse_arguments() -> argparse.Namespace:
106127 required = False ,
107128 default = "ERROR" ,
108129 )
109- parser .add_argument (
110- "--stream" ,
111- action = "store_true" ,
112- required = False ,
113- default = False ,
114- )
115130 parser .add_argument (
116131 "--text" ,
117132 type = str ,
@@ -144,24 +159,25 @@ async def send_message(
144159 )
145160 logger .info (f"associated request ({ request_id } ) with text: { text } " )
146161
147- output = ""
162+ # Group output by task_id so responses from each server are kept separate.
163+ outputs : dict [str , str ] = {}
148164 try :
149165 async for event in client .send_message (request = request ):
150166 if isinstance (event , Message ):
167+ outputs .setdefault ("msg" , "" )
151168 for part in event .parts :
152169 if isinstance (part .root , TextPart ):
153- output += part .root .text
170+ outputs [ "msg" ] += part .root .text
154171 else :
155172 task , update = event
156173 logger .info (f"task ({ task .id } ) status: { task .status .state } " )
157174
158- if task .status .state == "completed" and task .artifacts :
159- for artifact in task .artifacts :
160- for part in artifact .parts :
161- if isinstance (part .root , TextPart ):
162- output += part .root .text
163-
164- if update :
175+ if isinstance (update , TaskArtifactUpdateEvent ):
176+ outputs .setdefault (update .task_id , "" )
177+ for part in update .artifact .parts :
178+ if isinstance (part .root , TextPart ):
179+ outputs [update .task_id ] += part .root .text
180+ elif update :
165181 logger .info (f"update: { update .model_dump (mode = 'json' )} " )
166182 except Exception as e :
167183 logger .error (
@@ -170,7 +186,7 @@ async def send_message(
170186 )
171187 raise RuntimeError ("failed sending message or processing response" ) from e
172188
173- return output
189+ return " \n --- \n " . join ( outputs . values ())
174190
175191
176192if __name__ == "__main__" :
0 commit comments