Skip to content

Commit 63e58aa

Browse files
AirswitchAsaspicadustpgrayy
authored
fix: provide unique toolUseId for gemini models (strands-agents#1201)
Co-authored-by: spicadust <spicastre@gmail.com> Co-authored-by: Patrick Gray <pgrayy@amazon.com>
1 parent 8b7f6cc commit 63e58aa

2 files changed

Lines changed: 86 additions & 9 deletions

File tree

src/strands/models/gemini.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import json
77
import logging
88
import mimetypes
9+
import secrets
910
from collections.abc import AsyncGenerator
1011
from typing import Any, TypedDict, TypeVar, cast
1112

@@ -86,6 +87,7 @@ def __init__(
8687

8788
self._custom_client = client
8889
self.client_args = client_args or {}
90+
self._tool_use_id_to_name: dict[str, str] = {}
8991

9092
# Validate gemini_tools if provided
9193
if "gemini_tools" in self.config:
@@ -173,10 +175,13 @@ def _format_request_content_part(self, content: ContentBlock) -> genai.types.Par
173175
return genai.types.Part(text=content["text"])
174176

175177
if "toolResult" in content:
178+
tool_use_id = content["toolResult"]["toolUseId"]
179+
function_name = self._tool_use_id_to_name.get(tool_use_id, tool_use_id)
180+
176181
return genai.types.Part(
177182
function_response=genai.types.FunctionResponse(
178-
id=content["toolResult"]["toolUseId"],
179-
name=content["toolResult"]["toolUseId"],
183+
id=tool_use_id,
184+
name=function_name,
180185
response={
181186
"output": [
182187
tool_result_content
@@ -191,6 +196,12 @@ def _format_request_content_part(self, content: ContentBlock) -> genai.types.Par
191196
)
192197

193198
if "toolUse" in content:
199+
# Store the mapping from toolUseId to name for later use in toolResult formatting.
200+
# This mapping is built as we format the request, ensuring that when we encounter
201+
# toolResult blocks (which come after toolUse blocks in the message history),
202+
# we can look up the function name.
203+
self._tool_use_id_to_name[content["toolUse"]["toolUseId"]] = content["toolUse"]["name"]
204+
194205
return genai.types.Part(
195206
function_call=genai.types.FunctionCall(
196207
args=content["toolUse"]["input"],
@@ -317,16 +328,16 @@ def _format_chunk(self, event: dict[str, Any]) -> StreamEvent:
317328
case "content_start":
318329
match event["data_type"]:
319330
case "tool":
320-
# Note: toolUseId is the only identifier available in a tool result. However, Gemini requires
321-
# that name be set in the equivalent FunctionResponse type. Consequently, we assign
322-
# function name to toolUseId in our tool use block. And another reason, function_call is
323-
# not guaranteed to have id populated.
331+
function_call = event["data"].function_call
332+
# Use Gemini's provided ID or generate one if missing
333+
tool_use_id = function_call.id or f"tooluse_{secrets.token_urlsafe(16)}"
334+
324335
return {
325336
"contentBlockStart": {
326337
"start": {
327338
"toolUse": {
328-
"name": event["data"].function_call.name,
329-
"toolUseId": event["data"].function_call.name,
339+
"name": function_call.name,
340+
"toolUseId": tool_use_id,
330341
},
331342
},
332343
},
@@ -417,6 +428,7 @@ async def stream(
417428
ModelThrottledException: If the request is throttled by Gemini.
418429
"""
419430
request = self._format_request(messages, tool_specs, system_prompt, self.config.get("params"))
431+
self._tool_use_id_to_name.clear()
420432

421433
client = self._get_client().aio
422434

tests/strands/models/test_gemini.py

Lines changed: 66 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -360,6 +360,71 @@ async def test_stream_request_with_tool_results(gemini_client, model, model_id):
360360
gemini_client.aio.models.generate_content_stream.assert_called_with(**exp_request)
361361

362362

363+
@pytest.mark.asyncio
364+
async def test_stream_request_with_tool_results_preserving_name(gemini_client, model, model_id):
365+
messages = [
366+
{
367+
"role": "assistant",
368+
"content": [
369+
{
370+
"toolUse": {
371+
"toolUseId": "t1",
372+
"name": "tool_1",
373+
"input": {},
374+
},
375+
},
376+
],
377+
},
378+
{
379+
"role": "user",
380+
"content": [
381+
{
382+
"toolResult": {
383+
"toolUseId": "t1",
384+
"status": "success",
385+
"content": [{"text": "done"}],
386+
},
387+
},
388+
],
389+
},
390+
]
391+
await anext(model.stream(messages))
392+
393+
exp_request = {
394+
"config": {
395+
"tools": [{"function_declarations": []}],
396+
},
397+
"contents": [
398+
{
399+
"parts": [
400+
{
401+
"function_call": {
402+
"args": {},
403+
"id": "t1",
404+
"name": "tool_1",
405+
},
406+
},
407+
],
408+
"role": "model",
409+
},
410+
{
411+
"parts": [
412+
{
413+
"function_response": {
414+
"id": "t1",
415+
"name": "tool_1",
416+
"response": {"output": [{"text": "done"}]},
417+
},
418+
},
419+
],
420+
"role": "user",
421+
},
422+
],
423+
"model": model_id,
424+
}
425+
gemini_client.aio.models.generate_content_stream.assert_called_with(**exp_request)
426+
427+
363428
@pytest.mark.asyncio
364429
async def test_stream_request_with_empty_content(gemini_client, model, model_id):
365430
messages = [
@@ -459,7 +524,7 @@ async def test_stream_response_tool_use(gemini_client, model, messages, agenerat
459524
exp_chunks = [
460525
{"messageStart": {"role": "assistant"}},
461526
{"contentBlockStart": {"start": {}}},
462-
{"contentBlockStart": {"start": {"toolUse": {"name": "calculator", "toolUseId": "calculator"}}}},
527+
{"contentBlockStart": {"start": {"toolUse": {"name": "calculator", "toolUseId": "c1"}}}},
463528
{"contentBlockDelta": {"delta": {"toolUse": {"input": '{"expression": "2+2"}'}}}},
464529
{"contentBlockStop": {}},
465530
{"contentBlockStop": {}},

0 commit comments

Comments
 (0)