@@ -134,9 +134,8 @@ async def send_message(
134134 extensions : list [str ] | None = None ,
135135 ) -> SendMessageResponse :
136136 """Sends a non-streaming message request to the agent."""
137- return await self .stub .SendMessage (
138- request ,
139- metadata = self ._get_grpc_metadata (extensions ),
137+ return await self ._call_grpc (
138+ self .stub .SendMessage , request , context , extensions
140139 )
141140
142141 @_handle_grpc_stream_exception
@@ -148,14 +147,9 @@ async def send_message_streaming(
148147 extensions : list [str ] | None = None ,
149148 ) -> AsyncGenerator [StreamResponse ]:
150149 """Sends a streaming message request to the agent and yields responses as they arrive."""
151- stream = self .stub .SendStreamingMessage (
152- request ,
153- metadata = self ._get_grpc_metadata (extensions ),
154- )
155- while True :
156- response = await stream .read ()
157- if response == grpc .aio .EOF : # pyright: ignore[reportAttributeAccessIssue]
158- break
150+ async for response in self ._call_grpc_stream (
151+ self .stub .SendStreamingMessage , request , context , extensions
152+ ):
159153 yield response
160154
161155 @_handle_grpc_stream_exception
@@ -167,14 +161,9 @@ async def subscribe(
167161 extensions : list [str ] | None = None ,
168162 ) -> AsyncGenerator [StreamResponse ]:
169163 """Reconnects to get task updates."""
170- stream = self .stub .SubscribeToTask (
171- request ,
172- metadata = self ._get_grpc_metadata (extensions ),
173- )
174- while True :
175- response = await stream .read ()
176- if response == grpc .aio .EOF : # pyright: ignore[reportAttributeAccessIssue]
177- break
164+ async for response in self ._call_grpc_stream (
165+ self .stub .SubscribeToTask , request , context , extensions
166+ ):
178167 yield response
179168
180169 @_handle_grpc_exception
@@ -186,9 +175,8 @@ async def get_task(
186175 extensions : list [str ] | None = None ,
187176 ) -> Task :
188177 """Retrieves the current state and history of a specific task."""
189- return await self .stub .GetTask (
190- request ,
191- metadata = self ._get_grpc_metadata (extensions ),
178+ return await self ._call_grpc (
179+ self .stub .GetTask , request , context , extensions
192180 )
193181
194182 @_handle_grpc_exception
@@ -200,9 +188,8 @@ async def list_tasks(
200188 extensions : list [str ] | None = None ,
201189 ) -> ListTasksResponse :
202190 """Retrieves tasks for an agent."""
203- return await self .stub .ListTasks (
204- request ,
205- metadata = self ._get_grpc_metadata (extensions ),
191+ return await self ._call_grpc (
192+ self .stub .ListTasks , request , context , extensions
206193 )
207194
208195 @_handle_grpc_exception
@@ -214,9 +201,8 @@ async def cancel_task(
214201 extensions : list [str ] | None = None ,
215202 ) -> Task :
216203 """Requests the agent to cancel a specific task."""
217- return await self .stub .CancelTask (
218- request ,
219- metadata = self ._get_grpc_metadata (extensions ),
204+ return await self ._call_grpc (
205+ self .stub .CancelTask , request , context , extensions
220206 )
221207
222208 @_handle_grpc_exception
@@ -228,9 +214,11 @@ async def create_task_push_notification_config(
228214 extensions : list [str ] | None = None ,
229215 ) -> TaskPushNotificationConfig :
230216 """Sets or updates the push notification configuration for a specific task."""
231- return await self .stub .CreateTaskPushNotificationConfig (
217+ return await self ._call_grpc (
218+ self .stub .CreateTaskPushNotificationConfig ,
232219 request ,
233- metadata = self ._get_grpc_metadata (extensions ),
220+ context ,
221+ extensions ,
234222 )
235223
236224 @_handle_grpc_exception
@@ -242,9 +230,11 @@ async def get_task_push_notification_config(
242230 extensions : list [str ] | None = None ,
243231 ) -> TaskPushNotificationConfig :
244232 """Retrieves the push notification configuration for a specific task."""
245- return await self .stub .GetTaskPushNotificationConfig (
233+ return await self ._call_grpc (
234+ self .stub .GetTaskPushNotificationConfig ,
246235 request ,
247- metadata = self ._get_grpc_metadata (extensions ),
236+ context ,
237+ extensions ,
248238 )
249239
250240 @_handle_grpc_exception
@@ -256,9 +246,11 @@ async def list_task_push_notification_configs(
256246 extensions : list [str ] | None = None ,
257247 ) -> ListTaskPushNotificationConfigsResponse :
258248 """Lists push notification configurations for a specific task."""
259- return await self .stub .ListTaskPushNotificationConfigs (
249+ return await self ._call_grpc (
250+ self .stub .ListTaskPushNotificationConfigs ,
260251 request ,
261- metadata = self ._get_grpc_metadata (extensions ),
252+ context ,
253+ extensions ,
262254 )
263255
264256 @_handle_grpc_exception
@@ -270,9 +262,11 @@ async def delete_task_push_notification_config(
270262 extensions : list [str ] | None = None ,
271263 ) -> None :
272264 """Deletes the push notification configuration for a specific task."""
273- await self .stub .DeleteTaskPushNotificationConfig (
265+ await self ._call_grpc (
266+ self .stub .DeleteTaskPushNotificationConfig ,
274267 request ,
275- metadata = self ._get_grpc_metadata (extensions ),
268+ context ,
269+ extensions ,
276270 )
277271
278272 @_handle_grpc_exception
@@ -285,9 +279,8 @@ async def get_extended_agent_card(
285279 signature_verifier : Callable [[AgentCard ], None ] | None = None ,
286280 ) -> AgentCard :
287281 """Retrieves the agent's card."""
288- card = await self .stub .GetExtendedAgentCard (
289- request ,
290- metadata = self ._get_grpc_metadata (extensions ),
282+ card = await self ._call_grpc (
283+ self .stub .GetExtendedAgentCard , request , context , extensions
291284 )
292285
293286 if signature_verifier :
@@ -315,3 +308,43 @@ def _get_grpc_metadata(
315308 )
316309
317310 return metadata
311+
312+ def _get_grpc_timeout (
313+ self , context : ClientCallContext | None
314+ ) -> float | None :
315+ return context .timeout if context else None
316+
317+ async def _call_grpc (
318+ self ,
319+ method : Callable [..., Any ],
320+ request : Any ,
321+ context : ClientCallContext | None ,
322+ extensions : list [str ] | None ,
323+ ** kwargs : Any ,
324+ ) -> Any :
325+ return await method (
326+ request ,
327+ metadata = self ._get_grpc_metadata (extensions ),
328+ timeout = self ._get_grpc_timeout (context ),
329+ ** kwargs ,
330+ )
331+
332+ async def _call_grpc_stream (
333+ self ,
334+ method : Callable [..., Any ],
335+ request : Any ,
336+ context : ClientCallContext | None ,
337+ extensions : list [str ] | None ,
338+ ** kwargs : Any ,
339+ ) -> AsyncGenerator [StreamResponse ]:
340+ stream = method (
341+ request ,
342+ metadata = self ._get_grpc_metadata (extensions ),
343+ timeout = self ._get_grpc_timeout (context ),
344+ ** kwargs ,
345+ )
346+ while True :
347+ response = await stream .read ()
348+ if response == grpc .aio .EOF :
349+ break
350+ yield response
0 commit comments