Skip to content

Commit 456b70a

Browse files
authored
gemini - tool_use_id_to_name - local (strands-agents#1521)
1 parent 63e58aa commit 456b70a

1 file changed

Lines changed: 18 additions & 11 deletions

File tree

src/strands/models/gemini.py

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,6 @@ def __init__(
8787

8888
self._custom_client = client
8989
self.client_args = client_args or {}
90-
self._tool_use_id_to_name: dict[str, str] = {}
9190

9291
# Validate gemini_tools if provided
9392
if "gemini_tools" in self.config:
@@ -135,13 +134,19 @@ def _get_client(self) -> genai.Client:
135134
# Create a new client from client_args
136135
return genai.Client(**self.client_args)
137136

138-
def _format_request_content_part(self, content: ContentBlock) -> genai.types.Part:
137+
def _format_request_content_part(
138+
self, content: ContentBlock, tool_use_id_to_name: dict[str, str]
139+
) -> genai.types.Part:
139140
"""Format content block into a Gemini part instance.
140141
141142
- Docs: https://googleapis.github.io/python-genai/genai.html#genai.types.Part
142143
143144
Args:
144145
content: Message content to format.
146+
tool_use_id_to_name: Mapping of tool use id to tool name.
147+
Store the mapping from toolUseId to name for later use in toolResult formatting. This mapping is built
148+
as we format the request, ensuring that when we encounter toolResult blocks (which come after toolUse
149+
blocks in the message history), we can look up the function name.
145150
146151
Returns:
147152
Gemini part.
@@ -176,7 +181,7 @@ def _format_request_content_part(self, content: ContentBlock) -> genai.types.Par
176181

177182
if "toolResult" in content:
178183
tool_use_id = content["toolResult"]["toolUseId"]
179-
function_name = self._tool_use_id_to_name.get(tool_use_id, tool_use_id)
184+
function_name = tool_use_id_to_name.get(tool_use_id, tool_use_id)
180185

181186
return genai.types.Part(
182187
function_response=genai.types.FunctionResponse(
@@ -187,7 +192,8 @@ def _format_request_content_part(self, content: ContentBlock) -> genai.types.Par
187192
tool_result_content
188193
if "json" in tool_result_content
189194
else self._format_request_content_part(
190-
cast(ContentBlock, tool_result_content)
195+
cast(ContentBlock, tool_result_content),
196+
tool_use_id_to_name,
191197
).to_json_dict()
192198
for tool_result_content in content["toolResult"]["content"]
193199
],
@@ -196,11 +202,7 @@ def _format_request_content_part(self, content: ContentBlock) -> genai.types.Par
196202
)
197203

198204
if "toolUse" in content:
199-
# Store the mapping from toolUseId to name for later use in toolResult formatting.
200-
# This mapping is built as we format the request, ensuring that when we encounter
201-
# toolResult blocks (which come after toolUse blocks in the message history),
202-
# we can look up the function name.
203-
self._tool_use_id_to_name[content["toolUse"]["toolUseId"]] = content["toolUse"]["name"]
205+
tool_use_id_to_name[content["toolUse"]["toolUseId"]] = content["toolUse"]["name"]
204206

205207
return genai.types.Part(
206208
function_call=genai.types.FunctionCall(
@@ -223,9 +225,15 @@ def _format_request_content(self, messages: Messages) -> list[genai.types.Conten
223225
Returns:
224226
Gemini content list.
225227
"""
228+
# Gemini FunctionResponses are constructed from tool result blocks. Function name is required but is not
229+
# available in tool result blocks, hence the mapping.
230+
tool_use_id_to_name: dict[str, str] = {}
231+
226232
return [
227233
genai.types.Content(
228-
parts=[self._format_request_content_part(content) for content in message["content"]],
234+
parts=[
235+
self._format_request_content_part(content, tool_use_id_to_name) for content in message["content"]
236+
],
229237
role="user" if message["role"] == "user" else "model",
230238
)
231239
for message in messages
@@ -428,7 +436,6 @@ async def stream(
428436
ModelThrottledException: If the request is throttled by Gemini.
429437
"""
430438
request = self._format_request(messages, tool_specs, system_prompt, self.config.get("params"))
431-
self._tool_use_id_to_name.clear()
432439

433440
client = self._get_client().aio
434441

0 commit comments

Comments
 (0)