Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 31 additions & 3 deletions src/google/adk/runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -2095,10 +2095,12 @@ async def _cleanup_toolsets(self, toolsets_to_close: set[BaseToolset]):

# This maintains the same task context throughout cleanup
for toolset in toolsets_to_close:
cleanup_task = asyncio.create_task(
asyncio.wait_for(toolset.close(), timeout=10.0)
)
try:
logger.info('Closing toolset: %s', type(toolset).__name__)
# Use asyncio.wait_for to add timeout protection
await asyncio.wait_for(toolset.close(), timeout=10.0)
await asyncio.shield(cleanup_task)
logger.info('Successfully closed toolset: %s', type(toolset).__name__)
except asyncio.TimeoutError:
logger.warning('Toolset %s cleanup timed out', type(toolset).__name__)
Expand All @@ -2113,8 +2115,34 @@ async def _cleanup_toolsets(self, toolsets_to_close: set[BaseToolset]):
# improved context propagation across task boundaries, and better cancellation
# handling prevent the cross-task cancel scope violation.
logger.warning(
'Toolset %s cleanup cancelled: %s', type(toolset).__name__, e
'Toolset %s cleanup cancellation requested: %s',
type(toolset).__name__,
e,
)
try:
await cleanup_task
logger.info(
'Successfully closed toolset after cancellation request: %s',
type(toolset).__name__,
)
except asyncio.TimeoutError:
cleanup_task.cancel()
logger.warning(
'Toolset %s cleanup timed out after cancellation request',
type(toolset).__name__,
)
except asyncio.CancelledError as close_cancelled:
logger.warning(
'Toolset %s cleanup cancelled: %s',
type(toolset).__name__,
close_cancelled,
)
except Exception as close_error:
logger.error(
'Error closing toolset %s after cancellation request: %s',
type(toolset).__name__,
close_error,
)
except Exception as e:
logger.error('Error closing toolset %s: %s', type(toolset).__name__, e)

Expand Down
47 changes: 47 additions & 0 deletions tests/unittests/test_runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from google.adk.runners import Runner
from google.adk.sessions.in_memory_session_service import InMemorySessionService
from google.adk.sessions.session import Session
from google.adk.tools.base_toolset import BaseToolset
from google.genai import types
import pytest

Expand Down Expand Up @@ -1120,6 +1121,52 @@ async def test_runner_close_calls_plugin_close(self):

self.runner.plugin_manager.close.assert_awaited_once()

@pytest.mark.asyncio
async def test_runner_close_does_not_cancel_toolset_cleanup(self):
"""Caller cancellation should not cancel an in-flight toolset close."""
import asyncio

class SlowCloseToolset(BaseToolset):

def __init__(self):
super().__init__()
self.close_started = asyncio.Event()
self.close_finished = asyncio.Event()
self.close_cancelled = False

async def get_tools(self, readonly_context=None):
del readonly_context
return []

async def close(self) -> None:
self.close_started.set()
try:
await asyncio.sleep(0.05)
self.close_finished.set()
except asyncio.CancelledError:
self.close_cancelled = True
raise

toolset = SlowCloseToolset()
runner = Runner(
app_name="test_app",
agent=LlmAgent(
name="test_agent", model="gemini-1.5-pro", tools=[toolset]
),
session_service=self.session_service,
artifact_service=self.artifact_service,
)

close_task = asyncio.create_task(runner.close())
await toolset.close_started.wait()
close_task.cancel()

await close_task

assert close_task.cancelled() is False
assert toolset.close_cancelled is False
assert toolset.close_finished.is_set()

@pytest.mark.asyncio
async def test_runner_passes_plugin_close_timeout(self):
"""Test that runner passes plugin_close_timeout to PluginManager."""
Expand Down