Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 15 additions & 13 deletions src/google/adk/tools/base_toolset.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from abc import ABC
from abc import abstractmethod
import copy
from typing import Callable
from typing import final
from typing import List
from typing import Optional
Expand All @@ -27,6 +28,8 @@
from typing import TypeVar
from typing import Union

from google.genai import types

from ..agents.readonly_context import ReadonlyContext
from ..auth.auth_tool import AuthConfig
from .base_tool import BaseTool
Expand Down Expand Up @@ -80,8 +83,7 @@ def __init__(
"""
self.tool_filter = tool_filter
self.tool_name_prefix = tool_name_prefix
self._cached_invocation_id: Optional[str] = None
self._cached_prefixed_tools: Optional[list[BaseTool]] = None
self._cached_prefixed_tools: dict[Optional[str], list[BaseTool]] = {}
self._use_invocation_cache = True

@abstractmethod
Expand Down Expand Up @@ -119,16 +121,14 @@ async def get_tools_with_prefix(

if (
self._use_invocation_cache
and self._cached_prefixed_tools is not None
and self._cached_invocation_id == invocation_id
and invocation_id in self._cached_prefixed_tools
):
return self._cached_prefixed_tools
return self._cached_prefixed_tools[invocation_id]

tools = await self.get_tools(readonly_context)

if not self.tool_name_prefix:
self._cached_invocation_id = invocation_id
self._cached_prefixed_tools = tools
self._cached_prefixed_tools[invocation_id] = tools
return tools

prefix = self.tool_name_prefix
Expand All @@ -146,10 +146,12 @@ async def get_tools_with_prefix(
# Also update the function declaration name if the tool has one
# Use default parameters to capture the current values in the closure
def _create_prefixed_declaration(
original_get_declaration=tool._get_declaration,
prefixed_name=prefixed_name,
):
def _get_prefixed_declaration():
original_get_declaration: Callable[
[], Optional[types.FunctionDeclaration]
] = tool._get_declaration,
prefixed_name: str = prefixed_name,
) -> Callable[[], Optional[types.FunctionDeclaration]]:
def _get_prefixed_declaration() -> Optional[types.FunctionDeclaration]:
declaration = original_get_declaration()
if declaration is not None:
declaration.name = prefixed_name
Expand All @@ -161,8 +163,7 @@ def _get_prefixed_declaration():
tool_copy._get_declaration = _create_prefixed_declaration()
prefixed_tools.append(tool_copy)

self._cached_invocation_id = invocation_id
self._cached_prefixed_tools = prefixed_tools
self._cached_prefixed_tools[invocation_id] = prefixed_tools
return prefixed_tools

async def close(self) -> None:
Expand All @@ -174,6 +175,7 @@ async def close(self) -> None:
should ensure that any open connections, files, or other managed
resources are properly released to prevent leaks.
"""
self._cached_prefixed_tools.clear()

@classmethod
def from_config(
Expand Down
25 changes: 24 additions & 1 deletion tests/unittests/tools/test_base_toolset.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,16 @@ class _TestingToolset(BaseToolset):
def __init__(self, *args, tools: Optional[list[BaseTool]] = None, **kwargs):
super().__init__(*args, **kwargs)
self._tools = tools or []
self.get_tools_call_count = 0

async def get_tools(
self, readonly_context: Optional[ReadonlyContext] = None
) -> list[BaseTool]:
self.get_tools_call_count += 1
return self._tools

async def close(self) -> None:
pass
await super().close()


@pytest.mark.asyncio
Expand Down Expand Up @@ -439,6 +441,14 @@ async def test_get_tools_with_prefix_caching():
assert tools3 is not tools1 # Should be a new list instance
assert tools3[0].name == 'test_tool1'

# The first invocation should still be cached after another invocation uses
# the same toolset.
tools1_again = await toolset.get_tools_with_prefix(
readonly_context=readonly_context1
)
assert tools1_again is tools1
assert toolset.get_tools_call_count == 2

# Test disabling caching
toolset._use_invocation_cache = False
tools4 = await toolset.get_tools_with_prefix(
Expand All @@ -448,3 +458,16 @@ async def test_get_tools_with_prefix_caching():
readonly_context=readonly_context2
)
assert tools4 is not tools5


@pytest.mark.asyncio
async def test_get_tools_with_prefix_close_clears_cache():
tool1 = _TestingTool(name='tool1', description='Test tool 1')
toolset = _TestingToolset(tools=[tool1], tool_name_prefix='test')

await toolset.get_tools_with_prefix()
assert toolset._cached_prefixed_tools

await toolset.close()

assert not toolset._cached_prefixed_tools
Loading