Skip to content

Commit 0ebf78b

Browse files
google-genai-botcopybara-github
authored andcommitted
feat(mcp): gracefully handle tool execution errors and transport crashes
Previously, if an MCP tool returned a JSON-RPC error (e.g. 403 Forbidden) or if the underlying transport connection crashed, the resulting exceptions (McpError and ConnectionError) would bubble up and crash the entire ADK runner. This change introduces robust error boundaries for MCP tools: - `McpTool.run_async()` now catches `McpError` and general exceptions, returning them as structured error dictionaries `{"error": ...}` to the LLM agent so the conversation can continue gracefully. - `SessionContext` races tool calls against the background session task so transport crashes surface immediately instead of hanging. - Fixes an AnyIO cancellation scope bug ("Attempted to exit cancel scope in a different task") by removing redundant `asyncio.wait_for` wrappers around exit stack context entry. - Connection errors trigger automatic retries via `@retry_on_errors` before finally surfacing the failure to the agent. Fixes #4901, #4162 PiperOrigin-RevId: 902369269
1 parent 60b9073 commit 0ebf78b

9 files changed

Lines changed: 760 additions & 227 deletions

File tree

src/google/adk/tools/load_mcp_resource_tool.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,7 @@
2929
from .base_tool import BaseTool
3030

3131
if TYPE_CHECKING:
32-
from mcp_toolset import McpToolset
33-
32+
from .mcp_tool.mcp_toolset import McpToolset
3433
from .tool_context import ToolContext
3534

3635
logger = logging.getLogger("google_adk." + __name__)
@@ -39,7 +38,7 @@
3938
class LoadMcpResourceTool(BaseTool):
4039
"""A tool that loads the MCP resources and adds them to the session."""
4140

42-
def __init__(self, mcp_toolset: McpToolset):
41+
def __init__(self, mcp_toolset: McpToolset) -> None:
4342
super().__init__(
4443
name="load_mcp_resource",
4544
description="""Loads resources from the MCP server.

src/google/adk/tools/mcp_tool/__init__.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,17 +17,19 @@
1717
try:
1818
from .conversion_utils import adk_to_mcp_tool_type
1919
from .conversion_utils import gemini_to_json_schema
20-
from .mcp_session_manager import SseConnectionParams
21-
from .mcp_session_manager import StdioConnectionParams
22-
from .mcp_session_manager import StreamableHTTPConnectionParams
23-
from .mcp_tool import MCPTool
24-
from .mcp_tool import McpTool
25-
from .mcp_toolset import MCPToolset
26-
from .mcp_toolset import McpToolset
20+
from .mcp_session_manager import MCPSessionManager as MCPSessionManager
21+
from .mcp_session_manager import SseConnectionParams as SseConnectionParams
22+
from .mcp_session_manager import StdioConnectionParams as StdioConnectionParams
23+
from .mcp_session_manager import StreamableHTTPConnectionParams as StreamableHTTPConnectionParams
24+
from .mcp_tool import MCPTool as MCPTool
25+
from .mcp_tool import McpTool as McpTool
26+
from .mcp_toolset import MCPToolset as MCPToolset
27+
from .mcp_toolset import McpToolset as McpToolset
2728

2829
__all__.extend([
2930
'adk_to_mcp_tool_type',
3031
'gemini_to_json_schema',
32+
'MCPSessionManager',
3133
'McpTool',
3234
'MCPTool',
3335
'McpToolset',

src/google/adk/tools/mcp_tool/mcp_session_manager.py

Lines changed: 102 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@
1616

1717
import asyncio
1818
from collections import deque
19+
import concurrent.futures
1920
from contextlib import AsyncExitStack
21+
from dataclasses import dataclass
2022
from datetime import timedelta
2123
import functools
2224
import hashlib
@@ -25,13 +27,20 @@
2527
import sys
2628
import threading
2729
from typing import Any
30+
from typing import Callable
31+
from typing import cast
2832
from typing import Dict
2933
from typing import Optional
3034
from typing import Protocol
3135
from typing import runtime_checkable
3236
from typing import TextIO
37+
from typing import TYPE_CHECKING
38+
from typing import TypeVar
3339
from typing import Union
3440

41+
if TYPE_CHECKING:
42+
from .session_context import SessionContext
43+
3544
from mcp import ClientSession
3645
from mcp import SamplingCapability
3746
from mcp import StdioServerParameters
@@ -44,8 +53,6 @@
4453
from pydantic import BaseModel
4554
from pydantic import ConfigDict
4655

47-
from .session_context import SessionContext
48-
4956
logger = logging.getLogger('google_adk.' + __name__)
5057

5158

@@ -146,7 +153,10 @@ class StreamableHTTPConnectionParams(BaseModel):
146153
httpx_client_factory: CheckableMcpHttpClientFactory = create_mcp_http_client
147154

148155

149-
def retry_on_errors(func):
156+
_F = TypeVar('_F', bound=Callable[..., Any])
157+
158+
159+
def retry_on_errors(func: _F) -> _F:
150160
"""Decorator to automatically retry action when MCP session errors occur.
151161
152162
When MCP session errors occur, the decorator will automatically retry the
@@ -165,7 +175,7 @@ def retry_on_errors(func):
165175
"""
166176

167177
@functools.wraps(func) # Preserves original function metadata
168-
async def wrapper(self, *args, **kwargs):
178+
async def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
169179
try:
170180
return await func(self, *args, **kwargs)
171181
except Exception as e:
@@ -182,7 +192,17 @@ async def wrapper(self, *args, **kwargs):
182192
logger.info('Retrying %s due to error: %s', func.__name__, e)
183193
return await func(self, *args, **kwargs)
184194

185-
return wrapper
195+
return cast(_F, wrapper)
196+
197+
198+
@dataclass
199+
class _SessionEntry:
200+
"""A dataclass to hold session information."""
201+
202+
session: ClientSession
203+
exit_stack: AsyncExitStack
204+
loop: asyncio.AbstractEventLoop
205+
context: SessionContext
186206

187207

188208
class MCPSessionManager:
@@ -205,7 +225,7 @@ def __init__(
205225
*,
206226
sampling_callback: Optional[SamplingFnT] = None,
207227
sampling_capabilities: Optional[SamplingCapability] = None,
208-
):
228+
) -> None:
209229
"""Initializes the MCP session manager.
210230
211231
Args:
@@ -237,10 +257,8 @@ def __init__(
237257
self._connection_params = connection_params
238258
self._errlog = errlog
239259

240-
# Session pool: maps session keys to (session, exit_stack, loop) tuples
241-
self._sessions: Dict[
242-
str, tuple[ClientSession, AsyncExitStack, asyncio.AbstractEventLoop]
243-
] = {}
260+
# Session pool: maps session keys to _SessionEntry objects
261+
self._sessions: Dict[str, _SessionEntry] = {}
244262

245263
# Map of event loops to their respective locks to prevent race conditions
246264
# across different event loops in session creation.
@@ -312,35 +330,66 @@ def _merge_headers(
312330

313331
return base_headers
314332

315-
def _is_session_disconnected(self, session: ClientSession) -> bool:
333+
def _is_session_disconnected(
334+
self,
335+
entry: _SessionEntry,
336+
) -> bool:
316337
"""Checks if a session is disconnected or closed.
317338
318339
Args:
319-
session: The ClientSession to check.
340+
entry: The _SessionEntry to check.
320341
321342
Returns:
322343
True if the session is disconnected, False otherwise.
323344
"""
324-
return session._read_stream._closed or session._write_stream._closed
345+
if (
346+
entry.session._read_stream._closed
347+
or entry.session._write_stream._closed
348+
):
349+
return True
350+
if entry.context is not None and not entry.context._is_task_alive: # pylint: disable=protected-access
351+
return True
352+
return False
353+
354+
def _get_session_context(
355+
self, headers: Optional[Dict[str, str]] = None
356+
) -> Optional['SessionContext']:
357+
"""Returns the SessionContext for the session matching the given headers.
358+
359+
Note: This method reads from the session pool without acquiring
360+
``_session_lock``. This is safe because it is called immediately after
361+
``create_session()`` (which populates the entry under the lock) within
362+
the same task, and dict reads are atomic in CPython.
363+
364+
Args:
365+
headers: Optional headers used to identify the session.
366+
367+
Returns:
368+
The SessionContext if a matching session exists, None otherwise.
369+
"""
370+
merged_headers = self._merge_headers(headers)
371+
session_key = self._generate_session_key(merged_headers)
372+
entry = self._sessions.get(session_key)
373+
if entry is not None:
374+
return entry.context
375+
return None
325376

326377
async def _cleanup_session(
327378
self,
328379
session_key: str,
329-
exit_stack: AsyncExitStack,
330-
stored_loop: asyncio.AbstractEventLoop,
331-
):
380+
entry: _SessionEntry,
381+
) -> None:
332382
"""Cleans up a session, handling different event loops safely.
333383
334384
Args:
335385
session_key: The session key to clean up.
336-
exit_stack: The AsyncExitStack managing the session resources.
337-
stored_loop: The event loop on which the session was created.
386+
entry: The _SessionEntry managing the session resources.
338387
"""
339388
current_loop = asyncio.get_running_loop()
340389
try:
341-
if stored_loop is current_loop:
342-
await exit_stack.aclose()
343-
elif stored_loop.is_closed():
390+
if entry.loop is current_loop:
391+
await entry.exit_stack.aclose()
392+
elif entry.loop.is_closed():
344393
logger.warning(
345394
f'Error cleaning up session {session_key}: original event loop'
346395
' is closed, resources may be leaked.'
@@ -353,11 +402,11 @@ async def _cleanup_session(
353402
' event loop.'
354403
)
355404
future = asyncio.run_coroutine_threadsafe(
356-
exit_stack.aclose(), stored_loop
405+
entry.exit_stack.aclose(), entry.loop
357406
)
358407

359408
# Attach a callback so errors don't go unnoticed
360-
def cleanup_done(f: asyncio.Future):
409+
def cleanup_done(f: 'concurrent.futures.Future[Any]') -> None:
361410
try:
362411
if f.exception():
363412
logger.warning(
@@ -379,7 +428,9 @@ def cleanup_done(f: asyncio.Future):
379428
if session_key in self._sessions:
380429
del self._sessions[session_key]
381430

382-
def _create_client(self, merged_headers: Optional[Dict[str, str]] = None):
431+
def _create_client(
432+
self, merged_headers: Optional[Dict[str, str]] = None
433+
) -> Any:
383434
"""Creates an MCP client based on the connection parameters.
384435
385436
Args:
@@ -451,22 +502,22 @@ async def create_session(
451502
async with self._session_lock:
452503
# Check if we have an existing session
453504
if session_key in self._sessions:
454-
session, exit_stack, stored_loop = self._sessions[session_key]
505+
entry = self._sessions[session_key]
455506

456507
# Check if the existing session is still connected and bound to the current loop
457508
current_loop = asyncio.get_running_loop()
458-
if stored_loop is current_loop and not self._is_session_disconnected(
459-
session
509+
if entry.loop is current_loop and not self._is_session_disconnected(
510+
entry
460511
):
461512
# Session is still good, return it
462-
return session
513+
return entry.session
463514
else:
464515
# Session is disconnected or from a different loop, clean it up
465516
logger.info(
466517
'Cleaning up session (disconnected or different loop): %s',
467518
session_key,
468519
)
469-
await self._cleanup_session(session_key, exit_stack, stored_loop)
520+
await self._cleanup_session(session_key, entry)
470521

471522
# Create a new session (either first time or replacing disconnected one)
472523
exit_stack = AsyncExitStack()
@@ -482,28 +533,30 @@ async def create_session(
482533
)
483534

484535
try:
536+
from .session_context import SessionContext
537+
485538
client = self._create_client(merged_headers)
486539
is_stdio = isinstance(self._connection_params, StdioConnectionParams)
487540

541+
session_context = SessionContext(
542+
client=client,
543+
timeout=timeout_in_seconds,
544+
sse_read_timeout=sse_read_timeout_in_seconds,
545+
is_stdio=is_stdio,
546+
sampling_callback=self._sampling_callback,
547+
sampling_capabilities=self._sampling_capabilities,
548+
)
488549
session = await asyncio.wait_for(
489-
exit_stack.enter_async_context(
490-
SessionContext(
491-
client=client,
492-
timeout=timeout_in_seconds,
493-
sse_read_timeout=sse_read_timeout_in_seconds,
494-
is_stdio=is_stdio,
495-
sampling_callback=self._sampling_callback,
496-
sampling_capabilities=self._sampling_capabilities,
497-
)
498-
),
550+
exit_stack.enter_async_context(session_context),
499551
timeout=timeout_in_seconds,
500552
)
501553

502-
# Store session, exit stack, and loop in the pool
503-
self._sessions[session_key] = (
504-
session,
505-
exit_stack,
506-
asyncio.get_running_loop(),
554+
# Store session, exit stack, loop, and context in the pool
555+
self._sessions[session_key] = _SessionEntry(
556+
session=session,
557+
exit_stack=exit_stack,
558+
loop=asyncio.get_running_loop(),
559+
context=session_context,
507560
)
508561
logger.debug('Created new session: %s', session_key)
509562
return session
@@ -519,7 +572,7 @@ async def create_session(
519572
)
520573
raise ConnectionError(f'Failed to create MCP session: {e}') from e
521574

522-
def __getstate__(self):
575+
def __getstate__(self) -> Dict[str, Any]:
523576
"""Custom pickling to exclude non-picklable runtime objects."""
524577
state = self.__dict__.copy()
525578
# Remove unpicklable entries or those that shouldn't persist across pickle
@@ -532,7 +585,7 @@ def __getstate__(self):
532585

533586
return state
534587

535-
def __setstate__(self, state):
588+
def __setstate__(self, state: Dict[str, Any]) -> None:
536589
"""Custom unpickling to restore state."""
537590
self.__dict__.update(state)
538591
# Re-initialize members that were not pickled
@@ -543,12 +596,12 @@ def __setstate__(self, state):
543596
if not hasattr(self, '_errlog') or self._errlog is None:
544597
self._errlog = sys.stderr
545598

546-
async def close(self):
599+
async def close(self) -> None:
547600
"""Closes all sessions and cleans up resources."""
548601
async with self._session_lock:
549602
for session_key in list(self._sessions.keys()):
550-
_, exit_stack, stored_loop = self._sessions[session_key]
551-
await self._cleanup_session(session_key, exit_stack, stored_loop)
603+
entry = self._sessions[session_key]
604+
await self._cleanup_session(session_key, entry)
552605

553606

554607
SseServerParams = SseConnectionParams

0 commit comments

Comments
 (0)