|
5 | 5 |
|
6 | 6 | from datetime import datetime, timezone |
7 | 7 |
|
| 8 | +import grpc.aio |
8 | 9 | import uvicorn |
9 | 10 |
|
| 11 | +from starlette.applications import Starlette |
| 12 | + |
| 13 | +import a2a.compat.v0_3.a2a_v0_3_pb2_grpc as a2a_v0_3_grpc |
| 14 | +import a2a.types.a2a_pb2_grpc as a2a_grpc |
| 15 | + |
| 16 | +from a2a.compat.v0_3.grpc_handler import CompatGrpcHandler |
10 | 17 | from a2a.server.agent_execution.agent_executor import AgentExecutor |
11 | 18 | from a2a.server.agent_execution.context import RequestContext |
12 | | -from a2a.server.apps import A2AStarletteApplication |
| 19 | +from a2a.server.apps import ( |
| 20 | + A2ARESTFastAPIApplication, |
| 21 | + A2AStarletteApplication, |
| 22 | +) |
13 | 23 | from a2a.server.events.event_queue import EventQueue |
14 | 24 | from a2a.server.request_handlers.default_request_handler import ( |
15 | 25 | DefaultRequestHandler, |
16 | 26 | ) |
| 27 | +from a2a.server.request_handlers.grpc_handler import GrpcHandler |
17 | 28 | from a2a.server.tasks.inmemory_task_store import InMemoryTaskStore |
18 | 29 | from a2a.server.tasks.task_store import TaskStore |
19 | 30 | from a2a.types import ( |
|
32 | 43 |
|
33 | 44 |
|
34 | 45 | JSONRPC_URL = '/a2a/jsonrpc' |
| 46 | +REST_URL = '/a2a/rest' |
35 | 47 |
|
36 | 48 | logging.basicConfig(level=logging.INFO) |
37 | 49 | logger = logging.getLogger('SUTAgent') |
@@ -133,14 +145,26 @@ def serve(task_store: TaskStore) -> None: |
133 | 145 | """Sets up the A2A service and starts the HTTP server.""" |
134 | 146 | http_port = int(os.environ.get('HTTP_PORT', '41241')) |
135 | 147 |
|
| 148 | + grpc_port = int(os.environ.get('GRPC_PORT', '50051')) |
| 149 | + |
136 | 150 | agent_card = AgentCard( |
137 | 151 | name='SUT Agent', |
138 | 152 | description='An agent to be used as SUT against TCK tests.', |
139 | 153 | supported_interfaces=[ |
140 | 154 | AgentInterface( |
141 | 155 | url=f'http://localhost:{http_port}{JSONRPC_URL}', |
142 | 156 | protocol_binding='JSONRPC', |
143 | | - protocol_version='0.3.0', |
| 157 | + protocol_version='1.0.0', |
| 158 | + ), |
| 159 | + AgentInterface( |
| 160 | + url=f'http://localhost:{http_port}{REST_URL}', |
| 161 | + protocol_binding='REST', |
| 162 | + protocol_version='1.0.0', |
| 163 | + ), |
| 164 | + AgentInterface( |
| 165 | + url=f'http://localhost:{grpc_port}', |
| 166 | + protocol_binding='GRPC', |
| 167 | + protocol_version='1.0.0', |
144 | 168 | ), |
145 | 169 | ], |
146 | 170 | provider=AgentProvider( |
@@ -172,15 +196,49 @@ def serve(task_store: TaskStore) -> None: |
172 | 196 | task_store=task_store, |
173 | 197 | ) |
174 | 198 |
|
175 | | - server = A2AStarletteApplication( |
| 199 | + main_app = Starlette() |
| 200 | + |
| 201 | + # JSONRPC |
| 202 | + jsonrpc_server = A2AStarletteApplication( |
176 | 203 | agent_card=agent_card, |
177 | 204 | http_handler=request_handler, |
178 | 205 | ) |
| 206 | + jsonrpc_server.add_routes_to_app(main_app, rpc_url=JSONRPC_URL) |
179 | 207 |
|
180 | | - app = server.build(rpc_url=JSONRPC_URL) |
| 208 | + # REST |
| 209 | + rest_server = A2ARESTFastAPIApplication( |
| 210 | + agent_card=agent_card, |
| 211 | + http_handler=request_handler, |
| 212 | + ) |
| 213 | + rest_app = rest_server.build(rpc_url=REST_URL) |
| 214 | + main_app.mount('', rest_app) |
| 215 | + |
| 216 | + config = uvicorn.Config( |
| 217 | + main_app, host='127.0.0.1', port=http_port, log_level='info' |
| 218 | + ) |
| 219 | + uvicorn_server = uvicorn.Server(config) |
| 220 | + |
| 221 | + # GRPC |
| 222 | + grpc_server = grpc.aio.server() |
| 223 | + grpc_server.add_insecure_port(f'[::]:{grpc_port}') |
| 224 | + servicer = GrpcHandler(agent_card, request_handler) |
| 225 | + compat_servicer = CompatGrpcHandler(agent_card, request_handler) |
| 226 | + a2a_grpc.add_A2AServiceServicer_to_server(servicer, grpc_server) |
| 227 | + a2a_v0_3_grpc.add_A2AServiceServicer_to_server(compat_servicer, grpc_server) |
| 228 | + |
| 229 | + logger.info( |
| 230 | + 'Starting HTTP server on port %s and gRPC on port %s...', |
| 231 | + http_port, |
| 232 | + grpc_port, |
| 233 | + ) |
181 | 234 |
|
182 | | - logger.info('Starting HTTP server on port %s...', http_port) |
183 | | - uvicorn.run(app, host='127.0.0.1', port=http_port, log_level='info') |
| 235 | + loop = asyncio.get_event_loop() |
| 236 | + loop.run_until_complete(grpc_server.start()) |
| 237 | + loop.run_until_complete( |
| 238 | + asyncio.gather( |
| 239 | + uvicorn_server.serve(), grpc_server.wait_for_termination() |
| 240 | + ) |
| 241 | + ) |
184 | 242 |
|
185 | 243 |
|
186 | 244 | def main() -> None: |
|
0 commit comments