diff --git a/src/google/adk/tools/base_toolset.py b/src/google/adk/tools/base_toolset.py index 25f0d6a2a2..34c022e999 100644 --- a/src/google/adk/tools/base_toolset.py +++ b/src/google/adk/tools/base_toolset.py @@ -80,8 +80,7 @@ def __init__( """ self.tool_filter = tool_filter self.tool_name_prefix = tool_name_prefix - self._cached_invocation_id: Optional[str] = None - self._cached_prefixed_tools: Optional[list[BaseTool]] = None + self._cached_prefixed_tools: dict[Optional[str], list[BaseTool]] = {} self._use_invocation_cache = True @abstractmethod @@ -119,16 +118,14 @@ async def get_tools_with_prefix( if ( self._use_invocation_cache - and self._cached_prefixed_tools is not None - and self._cached_invocation_id == invocation_id + and invocation_id in self._cached_prefixed_tools ): - return self._cached_prefixed_tools + return self._cached_prefixed_tools[invocation_id] tools = await self.get_tools(readonly_context) if not self.tool_name_prefix: - self._cached_invocation_id = invocation_id - self._cached_prefixed_tools = tools + self._cached_prefixed_tools[invocation_id] = tools return tools prefix = self.tool_name_prefix @@ -161,8 +158,7 @@ def _get_prefixed_declaration(): tool_copy._get_declaration = _create_prefixed_declaration() prefixed_tools.append(tool_copy) - self._cached_invocation_id = invocation_id - self._cached_prefixed_tools = prefixed_tools + self._cached_prefixed_tools[invocation_id] = prefixed_tools return prefixed_tools async def close(self) -> None: @@ -174,6 +170,7 @@ async def close(self) -> None: should ensure that any open connections, files, or other managed resources are properly released to prevent leaks. """ + self._cached_prefixed_tools.clear() @classmethod def from_config( diff --git a/tests/unittests/tools/test_base_toolset.py b/tests/unittests/tools/test_base_toolset.py index cdb7808db7..9b89e570dd 100644 --- a/tests/unittests/tools/test_base_toolset.py +++ b/tests/unittests/tools/test_base_toolset.py @@ -48,7 +48,7 @@ async def get_tools( return self._tools async def close(self) -> None: - pass + await super().close() @pytest.mark.asyncio @@ -439,6 +439,19 @@ async def test_get_tools_with_prefix_caching(): assert tools3 is not tools1 # Should be a new list instance assert tools3[0].name == 'test_tool1' + # Fourth call with first invocation_id again (should hit cache) + tools1_again = await toolset.get_tools_with_prefix( + readonly_context=readonly_context1 + ) + assert tools1_again is tools1 + + # Test close() clears the cache + await toolset.close() + tools1_after_close = await toolset.get_tools_with_prefix( + readonly_context=readonly_context1 + ) + assert tools1_after_close is not tools1 + # Test disabling caching toolset._use_invocation_cache = False tools4 = await toolset.get_tools_with_prefix(