diff --git a/src/google/adk/tools/base_toolset.py b/src/google/adk/tools/base_toolset.py index 25f0d6a2a2..e3f16a145e 100644 --- a/src/google/adk/tools/base_toolset.py +++ b/src/google/adk/tools/base_toolset.py @@ -17,6 +17,7 @@ from abc import ABC from abc import abstractmethod import copy +from typing import Callable from typing import final from typing import List from typing import Optional @@ -27,6 +28,8 @@ from typing import TypeVar from typing import Union +from google.genai import types + from ..agents.readonly_context import ReadonlyContext from ..auth.auth_tool import AuthConfig from .base_tool import BaseTool @@ -80,8 +83,7 @@ def __init__( """ self.tool_filter = tool_filter self.tool_name_prefix = tool_name_prefix - self._cached_invocation_id: Optional[str] = None - self._cached_prefixed_tools: Optional[list[BaseTool]] = None + self._cached_prefixed_tools: dict[Optional[str], list[BaseTool]] = {} self._use_invocation_cache = True @abstractmethod @@ -119,16 +121,14 @@ async def get_tools_with_prefix( if ( self._use_invocation_cache - and self._cached_prefixed_tools is not None - and self._cached_invocation_id == invocation_id + and invocation_id in self._cached_prefixed_tools ): - return self._cached_prefixed_tools + return self._cached_prefixed_tools[invocation_id] tools = await self.get_tools(readonly_context) if not self.tool_name_prefix: - self._cached_invocation_id = invocation_id - self._cached_prefixed_tools = tools + self._cached_prefixed_tools[invocation_id] = tools return tools prefix = self.tool_name_prefix @@ -146,10 +146,12 @@ async def get_tools_with_prefix( # Also update the function declaration name if the tool has one # Use default parameters to capture the current values in the closure def _create_prefixed_declaration( - original_get_declaration=tool._get_declaration, - prefixed_name=prefixed_name, - ): - def _get_prefixed_declaration(): + original_get_declaration: Callable[ + [], Optional[types.FunctionDeclaration] + ] = tool._get_declaration, + prefixed_name: str = prefixed_name, + ) -> Callable[[], Optional[types.FunctionDeclaration]]: + def _get_prefixed_declaration() -> Optional[types.FunctionDeclaration]: declaration = original_get_declaration() if declaration is not None: declaration.name = prefixed_name @@ -161,8 +163,7 @@ def _get_prefixed_declaration(): tool_copy._get_declaration = _create_prefixed_declaration() prefixed_tools.append(tool_copy) - self._cached_invocation_id = invocation_id - self._cached_prefixed_tools = prefixed_tools + self._cached_prefixed_tools[invocation_id] = prefixed_tools return prefixed_tools async def close(self) -> None: @@ -174,6 +175,7 @@ async def close(self) -> None: should ensure that any open connections, files, or other managed resources are properly released to prevent leaks. """ + self._cached_prefixed_tools.clear() @classmethod def from_config( diff --git a/tests/unittests/tools/test_base_toolset.py b/tests/unittests/tools/test_base_toolset.py index cdb7808db7..a6a97cc33a 100644 --- a/tests/unittests/tools/test_base_toolset.py +++ b/tests/unittests/tools/test_base_toolset.py @@ -41,14 +41,16 @@ class _TestingToolset(BaseToolset): def __init__(self, *args, tools: Optional[list[BaseTool]] = None, **kwargs): super().__init__(*args, **kwargs) self._tools = tools or [] + self.get_tools_call_count = 0 async def get_tools( self, readonly_context: Optional[ReadonlyContext] = None ) -> list[BaseTool]: + self.get_tools_call_count += 1 return self._tools async def close(self) -> None: - pass + await super().close() @pytest.mark.asyncio @@ -439,6 +441,14 @@ async def test_get_tools_with_prefix_caching(): assert tools3 is not tools1 # Should be a new list instance assert tools3[0].name == 'test_tool1' + # The first invocation should still be cached after another invocation uses + # the same toolset. + tools1_again = await toolset.get_tools_with_prefix( + readonly_context=readonly_context1 + ) + assert tools1_again is tools1 + assert toolset.get_tools_call_count == 2 + # Test disabling caching toolset._use_invocation_cache = False tools4 = await toolset.get_tools_with_prefix( @@ -448,3 +458,16 @@ async def test_get_tools_with_prefix_caching(): readonly_context=readonly_context2 ) assert tools4 is not tools5 + + +@pytest.mark.asyncio +async def test_get_tools_with_prefix_close_clears_cache(): + tool1 = _TestingTool(name='tool1', description='Test tool 1') + toolset = _TestingToolset(tools=[tool1], tool_name_prefix='test') + + await toolset.get_tools_with_prefix() + assert toolset._cached_prefixed_tools + + await toolset.close() + + assert not toolset._cached_prefixed_tools