Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 6 additions & 9 deletions src/google/adk/tools/base_toolset.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,7 @@ def __init__(
"""
self.tool_filter = tool_filter
self.tool_name_prefix = tool_name_prefix
self._cached_invocation_id: Optional[str] = None
self._cached_prefixed_tools: Optional[list[BaseTool]] = None
self._cached_prefixed_tools: dict[Optional[str], list[BaseTool]] = {}
self._use_invocation_cache = True

@abstractmethod
Expand Down Expand Up @@ -119,16 +118,14 @@ async def get_tools_with_prefix(

if (
self._use_invocation_cache
and self._cached_prefixed_tools is not None
and self._cached_invocation_id == invocation_id
and invocation_id in self._cached_prefixed_tools
):
return self._cached_prefixed_tools
return self._cached_prefixed_tools[invocation_id]

tools = await self.get_tools(readonly_context)

if not self.tool_name_prefix:
self._cached_invocation_id = invocation_id
self._cached_prefixed_tools = tools
self._cached_prefixed_tools[invocation_id] = tools
return tools

prefix = self.tool_name_prefix
Expand Down Expand Up @@ -161,8 +158,7 @@ def _get_prefixed_declaration():
tool_copy._get_declaration = _create_prefixed_declaration()
prefixed_tools.append(tool_copy)

self._cached_invocation_id = invocation_id
self._cached_prefixed_tools = prefixed_tools
self._cached_prefixed_tools[invocation_id] = prefixed_tools
return prefixed_tools

async def close(self) -> None:
Expand All @@ -174,6 +170,7 @@ async def close(self) -> None:
should ensure that any open connections, files, or other managed
resources are properly released to prevent leaks.
"""
self._cached_prefixed_tools.clear()

@classmethod
def from_config(
Expand Down
15 changes: 14 additions & 1 deletion tests/unittests/tools/test_base_toolset.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ async def get_tools(
return self._tools

async def close(self) -> None:
pass
await super().close()


@pytest.mark.asyncio
Expand Down Expand Up @@ -439,6 +439,19 @@ async def test_get_tools_with_prefix_caching():
assert tools3 is not tools1 # Should be a new list instance
assert tools3[0].name == 'test_tool1'

# Fourth call with first invocation_id again (should hit cache)
tools1_again = await toolset.get_tools_with_prefix(
readonly_context=readonly_context1
)
assert tools1_again is tools1

# Test close() clears the cache
await toolset.close()
tools1_after_close = await toolset.get_tools_with_prefix(
readonly_context=readonly_context1
)
assert tools1_after_close is not tools1

# Test disabling caching
toolset._use_invocation_cache = False
tools4 = await toolset.get_tools_with_prefix(
Expand Down
Loading