11import argparse
22import asyncio
3- import contextlib
3+ import os
4+ import signal
45import uuid
56
7+ from typing import Any
8+
9+ import grpc
610import httpx
711
812from a2a .client import A2ACardResolver , ClientConfig , ClientFactory
913from a2a .types import Message , Part , Role , SendMessageRequest , TaskState
1014
1115
16+ async def _handle_stream (
17+ stream : Any , current_task_id : str | None
18+ ) -> str | None :
19+ async for event , task in stream :
20+ if not task :
21+ continue
22+ if not current_task_id :
23+ current_task_id = task .id
24+
25+ if event :
26+ if event .HasField ('status_update' ):
27+ state_name = TaskState .Name (event .status_update .status .state )
28+ print (f'TaskStatusUpdate [state={ state_name } ]:' , end = ' ' )
29+ if event .status_update .status .HasField ('message' ):
30+ for part in event .status_update .status .message .parts :
31+ if part .text :
32+ print (part .text , end = ' ' )
33+ print ()
34+
35+ if (
36+ event .status_update .status .state
37+ == TaskState .TASK_STATE_COMPLETED
38+ ):
39+ current_task_id = None
40+ print ('--- Task Completed ---' )
41+
42+ elif event .HasField ('artifact_update' ):
43+ print (
44+ f'TaskArtifactUpdate [name={ event .artifact_update .artifact .name } ]:' ,
45+ end = ' ' ,
46+ )
47+ for part in event .artifact_update .artifact .parts :
48+ if part .text :
49+ print (part .text , end = ' ' )
50+ print ()
51+
52+ return current_task_id
53+
54+
1255async def main () -> None :
1356 """Run the A2A terminal client."""
1457 parser = argparse .ArgumentParser (description = 'A2A Terminal Client' )
@@ -48,7 +91,8 @@ async def main() -> None:
4891
4992 while True :
5093 try :
51- user_input = input ('You: ' )
94+ loop = asyncio .get_running_loop ()
95+ user_input = await loop .run_in_executor (None , input , 'You: ' )
5296 except KeyboardInterrupt :
5397 break
5498
@@ -69,49 +113,13 @@ async def main() -> None:
69113
70114 try :
71115 stream = client .send_message (request )
72- async for event , task in stream :
73- if not task :
74- continue
75- if not current_task_id :
76- current_task_id = task .id
77-
78- if event :
79- if event .HasField ('status_update' ):
80- state_name = TaskState .Name (
81- event .status_update .status .state
82- )
83- print (f'TaskStatusUpdate [{ state_name } ]:' , end = ' ' )
84- if event .status_update .status .HasField ('message' ):
85- for (
86- part
87- ) in event .status_update .status .message .parts :
88- if part .text :
89- print (part .text , end = ' ' )
90- print ()
91-
92- if (
93- event .status_update .status .state
94- == TaskState .TASK_STATE_COMPLETED
95- ):
96- current_task_id = None
97- print ('--- Task Completed ---' )
98-
99- elif event .HasField ('artifact_update' ):
100- print (
101- f'TaskArtifactUpdate [{ event .artifact_update .artifact .name } ]:' ,
102- end = ' ' ,
103- )
104- for part in event .artifact_update .artifact .parts :
105- if part .text :
106- print (part .text , end = ' ' )
107- print ()
108-
109- except Exception as e :
116+ current_task_id = await _handle_stream (stream , current_task_id )
117+ except (httpx .RequestError , grpc .RpcError ) as e :
110118 print (f'Error communicating with agent: { e } ' )
111119
112120 await client .close ()
113121
114122
115123if __name__ == '__main__' :
116- with contextlib . suppress ( KeyboardInterrupt , asyncio . CancelledError ):
117- asyncio .run (main ())
124+ signal . signal ( signal . SIGINT , lambda sig , frame : os . _exit ( 0 ))
125+ asyncio .run (main ())
0 commit comments