From a4f4199f64059656e63d30371338b85bf0547262 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Roland=20K=C3=A1konyi?= Date: Wed, 11 Mar 2026 09:16:53 +0100 Subject: [PATCH 01/15] Fix Vertex AI listSessions null handling --- .../adk/sessions/VertexAiSessionService.java | 10 +++--- .../sessions/VertexAiSessionServiceTest.java | 34 +++++++++++++++++++ 2 files changed, 40 insertions(+), 4 deletions(-) diff --git a/core/src/main/java/com/google/adk/sessions/VertexAiSessionService.java b/core/src/main/java/com/google/adk/sessions/VertexAiSessionService.java index 4336f96c9..b62add27a 100644 --- a/core/src/main/java/com/google/adk/sessions/VertexAiSessionService.java +++ b/core/src/main/java/com/google/adk/sessions/VertexAiSessionService.java @@ -128,15 +128,17 @@ public Single listSessions(String appName, String userId) .map( listSessionsResponseMap -> parseListSessionsResponse(listSessionsResponseMap, appName, userId)) - .defaultIfEmpty(ListSessionsResponse.builder().build()); + .defaultIfEmpty(ListSessionsResponse.builder().sessions(new ArrayList<>()).build()); } private ListSessionsResponse parseListSessionsResponse( JsonNode listSessionsResponseMap, String appName, String userId) { + JsonNode sessionsNode = listSessionsResponseMap.get("sessions"); + if (sessionsNode == null || sessionsNode.isNull() || sessionsNode.isEmpty()) { + return ListSessionsResponse.builder().sessions(new ArrayList<>()).build(); + } List> apiSessions = - objectMapper.convertValue( - listSessionsResponseMap.get("sessions"), - new TypeReference>>() {}); + objectMapper.convertValue(sessionsNode, new TypeReference>>() {}); List sessions = new ArrayList<>(); for (Map apiSession : apiSessions) { diff --git a/core/src/test/java/com/google/adk/sessions/VertexAiSessionServiceTest.java b/core/src/test/java/com/google/adk/sessions/VertexAiSessionServiceTest.java index def4faf4c..3dab94b46 100644 --- a/core/src/test/java/com/google/adk/sessions/VertexAiSessionServiceTest.java +++ b/core/src/test/java/com/google/adk/sessions/VertexAiSessionServiceTest.java @@ -25,6 +25,8 @@ import java.util.Optional; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; +import okhttp3.MediaType; +import okhttp3.ResponseBody; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @@ -37,6 +39,20 @@ public class VertexAiSessionServiceTest { private static final ObjectMapper mapper = JsonBaseModel.getMapper(); + private static final MediaType JSON_MEDIA_TYPE = + MediaType.parse("application/json; charset=utf-8"); + + private static ApiResponse apiResponseJson(String json) { + return new ApiResponse() { + @Override + public ResponseBody getResponseBody() { + return ResponseBody.create(JSON_MEDIA_TYPE, json); + } + + @Override + public void close() {} + }; + } private static final String MOCK_SESSION_STRING_1 = """ @@ -319,6 +335,24 @@ public void listSessions_empty() { .isEmpty(); } + @Test + public void listSessions_missingSessionsField_returnsEmpty() { + when(mockApiClient.request("GET", "reasoningEngines/123/sessions?filter=user_id=userX", "")) + .thenReturn(apiResponseJson("{}")); + + assertThat(vertexAiSessionService.listSessions("123", "userX").blockingGet().sessions()) + .isEmpty(); + } + + @Test + public void listSessions_nullSessionsField_returnsEmpty() { + when(mockApiClient.request("GET", "reasoningEngines/123/sessions?filter=user_id=userY", "")) + .thenReturn(apiResponseJson("{\"sessions\": null}")); + + assertThat(vertexAiSessionService.listSessions("123", "userY").blockingGet().sessions()) + .isEmpty(); + } + @Test public void listEvents_empty() { assertThat(vertexAiSessionService.listEvents("789", "user1", "3").blockingGet().events()) From 70056707f42281772bd737e2c7fd5878181c7c37 Mon Sep 17 00:00:00 2001 From: Guillaume Laforge Date: Fri, 20 Mar 2026 16:47:40 +0100 Subject: [PATCH 02/15] refactor: migrate LangChain4j to builder pattern, enhance token usage, and use JSpecify Nullable - Migrate LangChain4j to a builder pattern - Enhance token usage handling with TokenCountEstimator (from PR #623) - Upgrade to latest version of LangChain4j - Replace javax.annotation.Nullable with org.jspecify.annotations.Nullable --- .../adk/models/langchain4j/LangChain4j.java | 230 ++++++++++++------ .../LangChain4jIntegrationTest.java | 24 +- .../models/langchain4j/LangChain4jTest.java | 162 +++++++++++- pom.xml | 2 +- 4 files changed, 327 insertions(+), 91 deletions(-) diff --git a/contrib/langchain4j/src/main/java/com/google/adk/models/langchain4j/LangChain4j.java b/contrib/langchain4j/src/main/java/com/google/adk/models/langchain4j/LangChain4j.java index 3ccb1e029..8279dc21a 100644 --- a/contrib/langchain4j/src/main/java/com/google/adk/models/langchain4j/LangChain4j.java +++ b/contrib/langchain4j/src/main/java/com/google/adk/models/langchain4j/LangChain4j.java @@ -23,6 +23,7 @@ import com.google.adk.models.BaseLlmConnection; import com.google.adk.models.LlmRequest; import com.google.adk.models.LlmResponse; +import com.google.auto.value.AutoValue; import com.google.genai.types.Blob; import com.google.genai.types.Content; import com.google.genai.types.FunctionCall; @@ -30,11 +31,11 @@ import com.google.genai.types.FunctionDeclaration; import com.google.genai.types.FunctionResponse; import com.google.genai.types.GenerateContentConfig; +import com.google.genai.types.GenerateContentResponseUsageMetadata; import com.google.genai.types.Part; import com.google.genai.types.Schema; import com.google.genai.types.ToolConfig; import com.google.genai.types.Type; -import dev.langchain4j.Experimental; import dev.langchain4j.agent.tool.ToolExecutionRequest; import dev.langchain4j.agent.tool.ToolSpecification; import dev.langchain4j.data.audio.Audio; @@ -52,6 +53,7 @@ import dev.langchain4j.data.pdf.PdfFile; import dev.langchain4j.data.video.Video; import dev.langchain4j.exception.UnsupportedFeatureException; +import dev.langchain4j.model.TokenCountEstimator; import dev.langchain4j.model.chat.ChatModel; import dev.langchain4j.model.chat.StreamingChatModel; import dev.langchain4j.model.chat.request.ChatRequest; @@ -65,6 +67,7 @@ import dev.langchain4j.model.chat.request.json.JsonStringSchema; import dev.langchain4j.model.chat.response.ChatResponse; import dev.langchain4j.model.chat.response.StreamingChatResponseHandler; +import dev.langchain4j.model.output.TokenUsage; import io.reactivex.rxjava3.core.BackpressureStrategy; import io.reactivex.rxjava3.core.Flowable; import java.util.ArrayList; @@ -72,66 +75,101 @@ import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.Objects; import java.util.UUID; +import org.jspecify.annotations.Nullable; -@Experimental -public class LangChain4j extends BaseLlm { +@AutoValue +public abstract class LangChain4j extends BaseLlm { private static final TypeReference> MAP_TYPE_REFERENCE = new TypeReference<>() {}; - private final ChatModel chatModel; - private final StreamingChatModel streamingChatModel; - private final ObjectMapper objectMapper; + LangChain4j() { + super(""); + } + + @Nullable + public abstract ChatModel chatModel(); + + @Nullable + public abstract StreamingChatModel streamingChatModel(); + + public abstract ObjectMapper objectMapper(); + + public abstract String modelName(); + + @Nullable + public abstract TokenCountEstimator tokenCountEstimator(); + + @Override + public String model() { + return modelName(); + } + + public static Builder builder() { + return new AutoValue_LangChain4j.Builder().objectMapper(new ObjectMapper()); + } + + @AutoValue.Builder + public abstract static class Builder { + public abstract Builder chatModel(ChatModel chatModel); + + public abstract Builder streamingChatModel(StreamingChatModel streamingChatModel); + + public abstract Builder tokenCountEstimator(TokenCountEstimator tokenCountEstimator); + + public abstract Builder objectMapper(ObjectMapper objectMapper); + + public abstract Builder modelName(String modelName); + + public abstract LangChain4j build(); + } public LangChain4j(ChatModel chatModel) { - super( - Objects.requireNonNull( - chatModel.defaultRequestParameters().modelName(), "chat model name cannot be null")); - this.chatModel = Objects.requireNonNull(chatModel, "chatModel cannot be null"); - this.streamingChatModel = null; - this.objectMapper = new ObjectMapper(); + this(chatModel, null, null, chatModel.defaultRequestParameters().modelName(), null); } public LangChain4j(ChatModel chatModel, String modelName) { - super(Objects.requireNonNull(modelName, "chat model name cannot be null")); - this.chatModel = Objects.requireNonNull(chatModel, "chatModel cannot be null"); - this.streamingChatModel = null; - this.objectMapper = new ObjectMapper(); + this(chatModel, null, null, modelName, null); } public LangChain4j(StreamingChatModel streamingChatModel) { - super( - Objects.requireNonNull( - streamingChatModel.defaultRequestParameters().modelName(), - "streaming chat model name cannot be null")); - this.chatModel = null; - this.streamingChatModel = - Objects.requireNonNull(streamingChatModel, "streamingChatModel cannot be null"); - this.objectMapper = new ObjectMapper(); + this( + null, + streamingChatModel, + null, + streamingChatModel.defaultRequestParameters().modelName(), + null); } public LangChain4j(StreamingChatModel streamingChatModel, String modelName) { - super(Objects.requireNonNull(modelName, "streaming chat model name cannot be null")); - this.chatModel = null; - this.streamingChatModel = - Objects.requireNonNull(streamingChatModel, "streamingChatModel cannot be null"); - this.objectMapper = new ObjectMapper(); + this(null, streamingChatModel, null, modelName, null); } public LangChain4j(ChatModel chatModel, StreamingChatModel streamingChatModel, String modelName) { - super(Objects.requireNonNull(modelName, "model name cannot be null")); - this.chatModel = Objects.requireNonNull(chatModel, "chatModel cannot be null"); - this.streamingChatModel = - Objects.requireNonNull(streamingChatModel, "streamingChatModel cannot be null"); - this.objectMapper = new ObjectMapper(); + this(chatModel, streamingChatModel, null, modelName, null); + } + + private LangChain4j( + ChatModel chatModel, + StreamingChatModel streamingChatModel, + ObjectMapper objectMapper, + String modelName, + TokenCountEstimator tokenCountEstimator) { + this(); + LangChain4j.builder() + .chatModel(chatModel) + .streamingChatModel(streamingChatModel) + .objectMapper(objectMapper) + .modelName(modelName) + .tokenCountEstimator(tokenCountEstimator) + .build(); } @Override public Flowable generateContent(LlmRequest llmRequest, boolean stream) { if (stream) { - if (this.streamingChatModel == null) { + if (this.streamingChatModel() == null) { return Flowable.error(new IllegalStateException("StreamingChatModel is not configured")); } @@ -139,54 +177,57 @@ public Flowable generateContent(LlmRequest llmRequest, boolean stre return Flowable.create( emitter -> { - streamingChatModel.chat( - chatRequest, - new StreamingChatResponseHandler() { - @Override - public void onPartialResponse(String s) { - emitter.onNext( - LlmResponse.builder().content(Content.fromParts(Part.fromText(s))).build()); - } - - @Override - public void onCompleteResponse(ChatResponse chatResponse) { - if (chatResponse.aiMessage().hasToolExecutionRequests()) { - AiMessage aiMessage = chatResponse.aiMessage(); - toParts(aiMessage).stream() - .map(Part::functionCall) - .forEach( - functionCall -> { - functionCall.ifPresent( - function -> { - emitter.onNext( - LlmResponse.builder() - .content( - Content.fromParts( - Part.fromFunctionCall( - function.name().orElse(""), - function.args().orElse(Map.of())))) - .build()); - }); - }); - } - emitter.onComplete(); - } - - @Override - public void onError(Throwable throwable) { - emitter.onError(throwable); - } - }); + streamingChatModel() + .chat( + chatRequest, + new StreamingChatResponseHandler() { + @Override + public void onPartialResponse(String s) { + emitter.onNext( + LlmResponse.builder() + .content(Content.fromParts(Part.fromText(s))) + .build()); + } + + @Override + public void onCompleteResponse(ChatResponse chatResponse) { + if (chatResponse.aiMessage().hasToolExecutionRequests()) { + AiMessage aiMessage = chatResponse.aiMessage(); + toParts(aiMessage).stream() + .map(Part::functionCall) + .forEach( + functionCall -> { + functionCall.ifPresent( + function -> { + emitter.onNext( + LlmResponse.builder() + .content( + Content.fromParts( + Part.fromFunctionCall( + function.name().orElse(""), + function.args().orElse(Map.of())))) + .build()); + }); + }); + } + emitter.onComplete(); + } + + @Override + public void onError(Throwable throwable) { + emitter.onError(throwable); + } + }); }, BackpressureStrategy.BUFFER); } else { - if (this.chatModel == null) { + if (this.chatModel() == null) { return Flowable.error(new IllegalStateException("ChatModel is not configured")); } ChatRequest chatRequest = toChatRequest(llmRequest); - ChatResponse chatResponse = chatModel.chat(chatRequest); - LlmResponse llmResponse = toLlmResponse(chatResponse); + ChatResponse chatResponse = chatModel().chat(chatRequest); + LlmResponse llmResponse = toLlmResponse(chatResponse, chatRequest); return Flowable.just(llmResponse); } @@ -413,7 +454,7 @@ private AiMessage toAiMessage(Content content) { private String toJson(Object object) { try { - return objectMapper.writeValueAsString(object); + return objectMapper().writeValueAsString(object); } catch (JsonProcessingException e) { throw new RuntimeException(e); } @@ -511,11 +552,38 @@ private JsonSchemaElement toJsonSchemaElement(Schema schema) { } } - private LlmResponse toLlmResponse(ChatResponse chatResponse) { + private LlmResponse toLlmResponse(ChatResponse chatResponse, ChatRequest chatRequest) { Content content = Content.builder().role("model").parts(toParts(chatResponse.aiMessage())).build(); - return LlmResponse.builder().content(content).build(); + LlmResponse.Builder builder = LlmResponse.builder().content(content); + TokenUsage tokenUsage = chatResponse.tokenUsage(); + if (tokenCountEstimator() != null) { + try { + int estimatedInput = + tokenCountEstimator().estimateTokenCountInMessages(chatRequest.messages()); + int estimatedOutput = + tokenCountEstimator().estimateTokenCountInText(chatResponse.aiMessage().text()); + int estimatedTotal = estimatedInput + estimatedOutput; + builder.usageMetadata( + GenerateContentResponseUsageMetadata.builder() + .promptTokenCount(estimatedInput) + .candidatesTokenCount(estimatedOutput) + .totalTokenCount(estimatedTotal) + .build()); + } catch (Exception e) { + e.printStackTrace(); + } + } else if (tokenUsage != null) { + builder.usageMetadata( + GenerateContentResponseUsageMetadata.builder() + .promptTokenCount(tokenUsage.inputTokenCount()) + .candidatesTokenCount(tokenUsage.outputTokenCount()) + .totalTokenCount(tokenUsage.totalTokenCount()) + .build()); + } + + return builder.build(); } private List toParts(AiMessage aiMessage) { @@ -546,7 +614,7 @@ private List toParts(AiMessage aiMessage) { private Map toArgs(ToolExecutionRequest toolExecutionRequest) { try { - return objectMapper.readValue(toolExecutionRequest.arguments(), MAP_TYPE_REFERENCE); + return objectMapper().readValue(toolExecutionRequest.arguments(), MAP_TYPE_REFERENCE); } catch (JsonProcessingException e) { throw new RuntimeException(e); } diff --git a/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jIntegrationTest.java b/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jIntegrationTest.java index 191e48017..5b6d3f3ad 100644 --- a/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jIntegrationTest.java +++ b/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jIntegrationTest.java @@ -62,7 +62,8 @@ void testSimpleAgent() { LlmAgent.builder() .name("science-app") .description("Science teacher agent") - .model(new LangChain4j(claudeModel, CLAUDE_4_6_SONNET)) + .model( + LangChain4j.builder().chatModel(claudeModel).modelName(CLAUDE_4_6_SONNET).build()) .instruction( """ You are a helpful science teacher that explains science concepts @@ -98,7 +99,8 @@ void testSingleAgentWithTools() { LlmAgent.builder() .name("friendly-weather-app") .description("Friend agent that knows about the weather") - .model(new LangChain4j(claudeModel, CLAUDE_4_6_SONNET)) + .model( + LangChain4j.builder().chatModel(claudeModel).modelName(CLAUDE_4_6_SONNET).build()) .instruction( """ You are a friendly assistant. @@ -183,7 +185,7 @@ void testAgentTool() { LlmAgent.builder() .name("friendly-weather-app") .description("Friend agent that knows about the weather") - .model(new LangChain4j(gptModel)) + .model(LangChain4j.builder().chatModel(gptModel).modelName(GPT_4_O_MINI).build()) .instruction( """ You are a friendly assistant. @@ -246,7 +248,7 @@ void testSubAgent() { LlmAgent.builder() .name("greeterAgent") .description("Friendly agent that greets users") - .model(new LangChain4j(gptModel)) + .model(LangChain4j.builder().chatModel(gptModel).modelName(GPT_4_O_MINI).build()) .instruction( """ You are a friendly that greets users. @@ -257,7 +259,7 @@ void testSubAgent() { LlmAgent.builder() .name("farewellAgent") .description("Friendly agent that says goodbye to users") - .model(new LangChain4j(gptModel)) + .model(LangChain4j.builder().chatModel(gptModel).modelName(GPT_4_O_MINI).build()) .instruction( """ You are a friendly that says goodbye to users. @@ -355,7 +357,11 @@ void testSimpleStreamingResponse() { .modelName(CLAUDE_4_6_SONNET) .build(); - LangChain4j lc4jClaude = new LangChain4j(claudeStreamingModel, CLAUDE_4_6_SONNET); + LangChain4j lc4jClaude = + LangChain4j.builder() + .streamingChatModel(claudeStreamingModel) + .modelName(CLAUDE_4_6_SONNET) + .build(); // when Flowable responses = @@ -413,7 +419,11 @@ void testStreamingRunConfig() { When someone greets you, respond with "Hello". If someone asks about the weather, call the `getWeather` function. """) - .model(new LangChain4j(streamingModel, "GPT_4_O_MINI")) + .model( + LangChain4j.builder() + .streamingChatModel(streamingModel) + .modelName("GPT_4_O_MINI") + .build()) // .model(new LangChain4j(streamingModel, // CLAUDE_3_7_SONNET_20250219)) .tools(FunctionTool.create(ToolExample.class, "getWeather")) diff --git a/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jTest.java b/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jTest.java index 076bb79a3..f88237ff1 100644 --- a/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jTest.java +++ b/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jTest.java @@ -19,6 +19,7 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.*; +import com.fasterxml.jackson.databind.ObjectMapper; import com.google.adk.models.LlmRequest; import com.google.adk.models.LlmResponse; import com.google.adk.tools.FunctionTool; @@ -26,6 +27,7 @@ import dev.langchain4j.agent.tool.ToolExecutionRequest; import dev.langchain4j.data.message.AiMessage; import dev.langchain4j.data.message.UserMessage; +import dev.langchain4j.model.TokenCountEstimator; import dev.langchain4j.model.chat.ChatModel; import dev.langchain4j.model.chat.StreamingChatModel; import dev.langchain4j.model.chat.request.ChatRequest; @@ -33,6 +35,7 @@ import dev.langchain4j.model.chat.request.json.JsonStringSchema; import dev.langchain4j.model.chat.response.ChatResponse; import dev.langchain4j.model.chat.response.StreamingChatResponseHandler; +import dev.langchain4j.model.output.TokenUsage; import io.reactivex.rxjava3.core.Flowable; import java.util.ArrayList; import java.util.List; @@ -57,8 +60,26 @@ void setUp() { chatModel = mock(ChatModel.class); streamingChatModel = mock(StreamingChatModel.class); - langChain4j = new LangChain4j(chatModel, MODEL_NAME); - streamingLangChain4j = new LangChain4j(streamingChatModel, MODEL_NAME); + langChain4j = LangChain4j.builder().chatModel(chatModel).modelName(MODEL_NAME).build(); + streamingLangChain4j = + LangChain4j.builder().streamingChatModel(streamingChatModel).modelName(MODEL_NAME).build(); + } + + @Test + void testBuilder() { + ObjectMapper customMapper = new ObjectMapper(); + LangChain4j customLc4j = + LangChain4j.builder() + .chatModel(chatModel) + .streamingChatModel(streamingChatModel) + .objectMapper(customMapper) + .modelName("custom-model") + .build(); + + assertThat(customLc4j.chatModel()).isEqualTo(chatModel); + assertThat(customLc4j.streamingChatModel()).isEqualTo(streamingChatModel); + assertThat(customLc4j.objectMapper()).isEqualTo(customMapper); + assertThat(customLc4j.modelName()).isEqualTo("custom-model"); } @Test @@ -812,4 +833,141 @@ void testGenerateContentWithMcpToolParametersJsonSchemaAsSchema() { assertThat(capturedRequest.toolSpecifications().get(0).name()).isEqualTo("mcpTool"); assertThat(capturedRequest.toolSpecifications().get(0).description()).isEqualTo("An MCP tool"); } + + @Test + @DisplayName( + "Should use TokenCountEstimator to estimate token usage when TokenUsage is not available") + void testTokenCountEstimatorFallback() { + // Given + // Create a mock TokenCountEstimator + final TokenCountEstimator tokenCountEstimator = mock(TokenCountEstimator.class); + when(tokenCountEstimator.estimateTokenCountInMessages(any())).thenReturn(50); // Input tokens + when(tokenCountEstimator.estimateTokenCountInText(any())).thenReturn(20); // Output tokens + + // Create LangChain4j with the TokenCountEstimator using Builder + final LangChain4j langChain4jWithEstimator = + LangChain4j.builder() + .chatModel(chatModel) + .modelName(MODEL_NAME) + .tokenCountEstimator(tokenCountEstimator) + .build(); + + // Create a LlmRequest + final LlmRequest llmRequest = + LlmRequest.builder() + .contents(List.of(Content.fromParts(Part.fromText("What is the weather today?")))) + .build(); + + // Mock ChatResponse WITHOUT TokenUsage (simulating when LLM doesn't provide token counts) + final ChatResponse chatResponse = mock(ChatResponse.class); + final AiMessage aiMessage = AiMessage.from("The weather is sunny today."); + when(chatResponse.aiMessage()).thenReturn(aiMessage); + when(chatResponse.tokenUsage()).thenReturn(null); // No token usage from LLM + when(chatModel.chat(any(ChatRequest.class))).thenReturn(chatResponse); + + // When + final LlmResponse response = + langChain4jWithEstimator.generateContent(llmRequest, false).blockingFirst(); + + // Then + // Verify the response has usage metadata estimated by TokenCountEstimator + assertThat(response).isNotNull(); + assertThat(response.content()).isPresent(); + assertThat(response.content().get().text()).isEqualTo("The weather is sunny today."); + + // IMPORTANT: Verify that token usage was estimated via the TokenCountEstimator + assertThat(response.usageMetadata()).isPresent(); + final GenerateContentResponseUsageMetadata usageMetadata = response.usageMetadata().get(); + assertThat(usageMetadata.promptTokenCount()).isEqualTo(Optional.of(50)); // From estimator + assertThat(usageMetadata.candidatesTokenCount()).isEqualTo(Optional.of(20)); // From estimator + assertThat(usageMetadata.totalTokenCount()).isEqualTo(Optional.of(70)); // 50 + 20 + + // Verify the estimator was actually called + verify(tokenCountEstimator).estimateTokenCountInMessages(any()); + verify(tokenCountEstimator).estimateTokenCountInText("The weather is sunny today."); + } + + @Test + @DisplayName("Should prioritize TokenCountEstimator over TokenUsage when estimator is provided") + void testTokenCountEstimatorPriority() { + // Given + // Create a mock TokenCountEstimator + final TokenCountEstimator tokenCountEstimator = mock(TokenCountEstimator.class); + when(tokenCountEstimator.estimateTokenCountInMessages(any())).thenReturn(100); // From estimator + when(tokenCountEstimator.estimateTokenCountInText(any())).thenReturn(50); // From estimator + + // Create LangChain4j with the TokenCountEstimator using Builder + final LangChain4j langChain4jWithEstimator = + LangChain4j.builder() + .chatModel(chatModel) + .modelName(MODEL_NAME) + .tokenCountEstimator(tokenCountEstimator) + .build(); + + // Create a LlmRequest + final LlmRequest llmRequest = + LlmRequest.builder() + .contents(List.of(Content.fromParts(Part.fromText("What is the weather today?")))) + .build(); + + // Mock ChatResponse WITH actual TokenUsage from the LLM + final ChatResponse chatResponse = mock(ChatResponse.class); + final AiMessage aiMessage = AiMessage.from("The weather is sunny today."); + final TokenUsage actualTokenUsage = new TokenUsage(30, 15, 45); // Actual token counts from LLM + when(chatResponse.aiMessage()).thenReturn(aiMessage); + when(chatResponse.tokenUsage()).thenReturn(actualTokenUsage); // LLM provides token usage + when(chatModel.chat(any(ChatRequest.class))).thenReturn(chatResponse); + + // When + final LlmResponse response = + langChain4jWithEstimator.generateContent(llmRequest, false).blockingFirst(); + + // Then + // IMPORTANT: When TokenCountEstimator is present, it takes priority over TokenUsage + assertThat(response).isNotNull(); + assertThat(response.usageMetadata()).isPresent(); + final GenerateContentResponseUsageMetadata usageMetadata = response.usageMetadata().get(); + assertThat(usageMetadata.promptTokenCount()).isEqualTo(Optional.of(100)); // From estimator + assertThat(usageMetadata.candidatesTokenCount()).isEqualTo(Optional.of(50)); // From estimator + assertThat(usageMetadata.totalTokenCount()).isEqualTo(Optional.of(150)); // 100 + 50 + + // Verify the estimator was called (it takes priority) + verify(tokenCountEstimator).estimateTokenCountInMessages(any()); + verify(tokenCountEstimator).estimateTokenCountInText("The weather is sunny today."); + } + + @Test + @DisplayName("Should not include usageMetadata when TokenUsage is null and no estimator provided") + void testNoUsageMetadataWithoutEstimator() { + // Given + // Create LangChain4j WITHOUT TokenCountEstimator (default behavior) + final LangChain4j langChain4jNoEstimator = + LangChain4j.builder().chatModel(chatModel).modelName(MODEL_NAME).build(); + + // Create a LlmRequest + final LlmRequest llmRequest = + LlmRequest.builder() + .contents(List.of(Content.fromParts(Part.fromText("Hello, world!")))) + .build(); + + // Mock ChatResponse WITHOUT TokenUsage + final ChatResponse chatResponse = mock(ChatResponse.class); + final AiMessage aiMessage = AiMessage.from("Hello! How can I help you?"); + when(chatResponse.aiMessage()).thenReturn(aiMessage); + when(chatResponse.tokenUsage()).thenReturn(null); // No token usage from LLM + when(chatModel.chat(any(ChatRequest.class))).thenReturn(chatResponse); + + // When + final LlmResponse response = + langChain4jNoEstimator.generateContent(llmRequest, false).blockingFirst(); + + // Then + // Verify the response does NOT have usage metadata + assertThat(response).isNotNull(); + assertThat(response.content()).isPresent(); + assertThat(response.content().get().text()).isEqualTo("Hello! How can I help you?"); + + // IMPORTANT: usageMetadata should be empty when no TokenUsage and no estimator + assertThat(response.usageMetadata()).isEmpty(); + } } diff --git a/pom.xml b/pom.xml index cbeca1b72..40332472f 100644 --- a/pom.xml +++ b/pom.xml @@ -62,7 +62,7 @@ 0.18.1 3.41.0 3.9.0 - 1.11.0 + 1.12.2 2.0.17 1.4.5 1.0.0 From 3633a7dd071265087ea2ff148d419969b0c888ef Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Fri, 20 Mar 2026 13:15:38 -0700 Subject: [PATCH 03/15] fix: Removing deprecated methods from Runner PiperOrigin-RevId: 886942637 --- .../java/com/google/adk/runner/Runner.java | 42 ------------------- 1 file changed, 42 deletions(-) diff --git a/core/src/main/java/com/google/adk/runner/Runner.java b/core/src/main/java/com/google/adk/runner/Runner.java index 1f7d924ab..849a3cd04 100644 --- a/core/src/main/java/com/google/adk/runner/Runner.java +++ b/core/src/main/java/com/google/adk/runner/Runner.java @@ -425,36 +425,6 @@ public Flowable runAsync(String userId, String sessionId, Content newMess return runAsync(userId, sessionId, newMessage, RunConfig.builder().build()); } - /** - * See {@link #runAsync(Session, Content, RunConfig, Map)}. - * - * @deprecated Use runAsync with sessionId. - */ - @Deprecated(since = "0.4.0", forRemoval = true) - public Flowable runAsync(Session session, Content newMessage, RunConfig runConfig) { - return runAsync(session, newMessage, runConfig, /* stateDelta= */ null); - } - - /** - * Runs the agent asynchronously using a provided Session object. - * - * @param session The session to run the agent in. - * @param newMessage The new message from the user to process. - * @param runConfig Configuration for the agent run. - * @param stateDelta Optional map of state updates to merge into the session for this run. - * @return A Flowable stream of {@link Event} objects generated by the agent during execution. - * @deprecated Use runAsync with sessionId. - */ - @Deprecated(since = "0.4.0", forRemoval = true) - public Flowable runAsync( - Session session, - Content newMessage, - RunConfig runConfig, - @Nullable Map stateDelta) { - return runAsyncImpl(session, newMessage, runConfig, stateDelta) - .compose(Tracing.trace("invocation")); - } - /** * Runs the agent asynchronously using a provided Session object. * @@ -735,18 +705,6 @@ protected Flowable runLiveImpl( }); } - /** - * Runs the agent asynchronously with a default user ID. - * - * @return stream of generated events. - */ - @Deprecated(since = "0.5.0", forRemoval = true) - public Flowable runWithSessionId( - String sessionId, Content newMessage, RunConfig runConfig) { - // TODO(b/410859954): Add user_id to getter or method signature. Assuming "tmp-user" for now. - return this.runAsync("tmp-user", sessionId, newMessage, runConfig); - } - /** * Checks if the agent and its parent chain allow transfer up the tree. * From 8e9fb085354f8148e00cbd236e8f29e82de56d6e Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Fri, 20 Mar 2026 13:56:07 -0700 Subject: [PATCH 04/15] refactor: Use concatMap for sequential event persistence in Runner Ensure sequential event processing and persistence in ADK Runner. This ensures that events are appended in order and returned from runAsync in order. This aligns better with the Python implementation. PiperOrigin-RevId: 886961696 --- .../java/com/google/adk/runner/Runner.java | 2 +- .../com/google/adk/runner/RunnerTest.java | 42 +++++++++++++++++++ 2 files changed, 43 insertions(+), 1 deletion(-) diff --git a/core/src/main/java/com/google/adk/runner/Runner.java b/core/src/main/java/com/google/adk/runner/Runner.java index 849a3cd04..2bfbca881 100644 --- a/core/src/main/java/com/google/adk/runner/Runner.java +++ b/core/src/main/java/com/google/adk/runner/Runner.java @@ -529,7 +529,7 @@ private Flowable runAgentWithFreshSession( contextWithUpdatedSession .agent() .runAsync(contextWithUpdatedSession) - .flatMap( + .concatMap( agentEvent -> this.sessionService .appendEvent(updatedSession, agentEvent) diff --git a/core/src/test/java/com/google/adk/runner/RunnerTest.java b/core/src/test/java/com/google/adk/runner/RunnerTest.java index a3e21cb73..efd565c16 100644 --- a/core/src/test/java/com/google/adk/runner/RunnerTest.java +++ b/core/src/test/java/com/google/adk/runner/RunnerTest.java @@ -26,6 +26,7 @@ import static com.google.common.truth.Truth.assertThat; import static java.nio.charset.StandardCharsets.UTF_8; import static java.util.Arrays.stream; +import static java.util.concurrent.TimeUnit.MILLISECONDS; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.CALLS_REAL_METHODS; import static org.mockito.Mockito.mock; @@ -33,6 +34,7 @@ import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import com.google.adk.agents.BaseAgent; import com.google.adk.agents.InvocationContext; import com.google.adk.agents.LiveRequestQueue; import com.google.adk.agents.LlmAgent; @@ -43,6 +45,7 @@ import com.google.adk.flows.llmflows.Functions; import com.google.adk.models.LlmResponse; import com.google.adk.plugins.BasePlugin; +import com.google.adk.sessions.BaseSessionService; import com.google.adk.sessions.Session; import com.google.adk.sessions.SessionKey; import com.google.adk.summarizer.EventsCompactionConfig; @@ -851,6 +854,45 @@ public void beforeRunCallback_withStateDelta_seesMergedState() { assertThat(sessionInCallback.state()).containsEntry("number", 123); } + @Test + public void runAsync_ensureEventsAreAppendedInOrder() throws Exception { + Event event1 = TestUtils.createEvent("1"); + Event event2 = TestUtils.createEvent("2"); + BaseAgent mockAgent = TestUtils.createSubAgent("test agent", event1, event2); + + BaseSessionService mockSessionService = mock(BaseSessionService.class); + + when(mockSessionService.getSession(any(), any(), any(), any())).thenReturn(Maybe.just(session)); + when(mockSessionService.appendEvent(any(), any())) + .thenAnswer( + invocation -> { + Event eventArg = invocation.getArgument(1); + Single result = Single.just(eventArg); + if (eventArg.id().equals("1")) { + // Artificially delay the first event to ensure it is appended first. + return result.delay(100, MILLISECONDS); + } + return result; + }); + + Runner mockRunner = + Runner.builder() + .agent(mockAgent) + .appName("test") + .sessionService(mockSessionService) + .build(); + + List results = + mockRunner + .runAsync("user", session.id(), createContent("user message")) + .toList() + .blockingGet(); + + assertThat(simplifyEvents(results)) + .containsExactly("author: content for event 1", "author: content for event 2") + .inOrder(); + } + private Content createContent(String text) { return Content.builder().parts(Part.builder().text(text).build()).build(); } From 3e21e7ac46b634341819b3543388a38caef85516 Mon Sep 17 00:00:00 2001 From: Guillaume Laforge Date: Sat, 21 Mar 2026 20:11:12 +0100 Subject: [PATCH 05/15] fix: handle null `AiMessage.text()` to prevent NPE and add unit test (PR #1035) --- .../adk/models/langchain4j/LangChain4j.java | 7 ++++-- .../models/langchain4j/LangChain4jTest.java | 23 +++++++++++++++++++ 2 files changed, 28 insertions(+), 2 deletions(-) diff --git a/contrib/langchain4j/src/main/java/com/google/adk/models/langchain4j/LangChain4j.java b/contrib/langchain4j/src/main/java/com/google/adk/models/langchain4j/LangChain4j.java index 8279dc21a..97331e7b4 100644 --- a/contrib/langchain4j/src/main/java/com/google/adk/models/langchain4j/LangChain4j.java +++ b/contrib/langchain4j/src/main/java/com/google/adk/models/langchain4j/LangChain4j.java @@ -607,8 +607,11 @@ private List toParts(AiMessage aiMessage) { }); return parts; } else { - Part part = Part.builder().text(aiMessage.text()).build(); - return List.of(part); + String text = aiMessage.text(); + if (text == null) { + return List.of(); + } + return List.of(Part.builder().text(text).build()); } } diff --git a/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jTest.java b/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jTest.java index f88237ff1..a1ec7a3c2 100644 --- a/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jTest.java +++ b/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jTest.java @@ -970,4 +970,27 @@ void testNoUsageMetadataWithoutEstimator() { // IMPORTANT: usageMetadata should be empty when no TokenUsage and no estimator assertThat(response.usageMetadata()).isEmpty(); } + + @Test + @DisplayName("Should handle null AiMessage text without throwing NPE") + void testGenerateContentWithNullAiMessageText() { + // Given + final LlmRequest llmRequest = + LlmRequest.builder().contents(List.of(Content.fromParts(Part.fromText("Hello")))).build(); + + final ChatResponse chatResponse = mock(ChatResponse.class); + final AiMessage aiMessage = mock(AiMessage.class); + when(aiMessage.text()).thenReturn(null); + when(aiMessage.hasToolExecutionRequests()).thenReturn(false); + when(chatResponse.aiMessage()).thenReturn(aiMessage); + when(chatModel.chat(any(ChatRequest.class))).thenReturn(chatResponse); + + // When + final Flowable responseFlowable = langChain4j.generateContent(llmRequest, false); + final LlmResponse response = responseFlowable.blockingFirst(); + // Then - no NPE thrown, and content has no text parts + assertThat(response).isNotNull(); + assertThat(response.content()).isPresent(); + assertThat(response.content().get().parts().orElse(List.of())).isEmpty(); + } } From cdc5199eb0f92cb95db2ee7ff139d67317968457 Mon Sep 17 00:00:00 2001 From: Guillaume Laforge Date: Mon, 23 Mar 2026 13:43:34 +0100 Subject: [PATCH 06/15] fix: add schema validation to SetModelResponseTool (issue #587 already implemented, but adding tests from PR #603) --- .../adk/tools/SetModelResponseTool.java | 7 +- .../adk/tools/SetModelResponseToolTest.java | 123 ++++++++++++++++++ 2 files changed, 129 insertions(+), 1 deletion(-) create mode 100644 core/src/test/java/com/google/adk/tools/SetModelResponseToolTest.java diff --git a/core/src/main/java/com/google/adk/tools/SetModelResponseTool.java b/core/src/main/java/com/google/adk/tools/SetModelResponseTool.java index e23d6414a..3b0e411b4 100644 --- a/core/src/main/java/com/google/adk/tools/SetModelResponseTool.java +++ b/core/src/main/java/com/google/adk/tools/SetModelResponseTool.java @@ -16,6 +16,7 @@ package com.google.adk.tools; +import com.google.adk.SchemaUtils; import com.google.genai.types.FunctionDeclaration; import com.google.genai.types.Schema; import io.reactivex.rxjava3.core.Single; @@ -58,6 +59,10 @@ public Optional declaration() { public Single> runAsync(Map args, ToolContext toolContext) { // This tool is a marker for the final response, it doesn't do anything but return its arguments // which will be captured as the final result. - return Single.just(args); + return Single.fromCallable( + () -> { + SchemaUtils.validateMapOnSchema(args, outputSchema, /* isInput= */ false); + return args; + }); } } diff --git a/core/src/test/java/com/google/adk/tools/SetModelResponseToolTest.java b/core/src/test/java/com/google/adk/tools/SetModelResponseToolTest.java new file mode 100644 index 000000000..64b600af9 --- /dev/null +++ b/core/src/test/java/com/google/adk/tools/SetModelResponseToolTest.java @@ -0,0 +1,123 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.tools; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.genai.types.FunctionDeclaration; +import com.google.genai.types.Schema; +import java.util.Map; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public final class SetModelResponseToolTest { + + @Test + public void declaration_returnsCorrectFunctionDeclaration() { + Schema outputSchema = + Schema.builder() + .type("OBJECT") + .properties(ImmutableMap.of("field1", Schema.builder().type("STRING").build())) + .required(ImmutableList.of("field1")) + .build(); + + SetModelResponseTool tool = new SetModelResponseTool(outputSchema); + FunctionDeclaration declaration = tool.declaration().get(); + + assertThat(declaration.name()).hasValue("set_model_response"); + assertThat(declaration.description()).isPresent(); + assertThat(declaration.description().get()).contains("Set your final response"); + assertThat(declaration.parameters()).hasValue(outputSchema); + } + + @Test + public void runAsync_returnsArgs() { + Schema outputSchema = + Schema.builder() + .type("OBJECT") + .properties(ImmutableMap.of("field1", Schema.builder().type("STRING").build())) + .build(); + + SetModelResponseTool tool = new SetModelResponseTool(outputSchema); + Map args = ImmutableMap.of("field1", "value1"); + + Map result = tool.runAsync(args, null).blockingGet(); + + assertThat(result).isEqualTo(args); + } + + @Test + public void runAsync_validatesArgs() { + Schema outputSchema = + Schema.builder() + .type("OBJECT") + .properties(ImmutableMap.of("field1", Schema.builder().type("STRING").build())) + .required(ImmutableList.of("field1")) + .build(); + + SetModelResponseTool tool = new SetModelResponseTool(outputSchema); + Map invalidArgs = ImmutableMap.of("field2", "value2"); + + // Should throw validation error + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, () -> tool.runAsync(invalidArgs, null).blockingGet()); + + assertThat(exception).hasMessageThat().contains("does not match agent output schema"); + } + + @Test + public void runAsync_validatesComplexArgs() { + Schema complexSchema = + Schema.builder() + .type("OBJECT") + .properties( + ImmutableMap.of( + "id", + Schema.builder().type("INTEGER").build(), + "tags", + Schema.builder() + .type("ARRAY") + .items(Schema.builder().type("STRING").build()) + .build(), + "metadata", + Schema.builder() + .type("OBJECT") + .properties(ImmutableMap.of("key", Schema.builder().type("STRING").build())) + .build())) + .required(ImmutableList.of("id", "tags", "metadata")) + .build(); + + SetModelResponseTool tool = new SetModelResponseTool(complexSchema); + Map complexArgs = + ImmutableMap.of( + "id", 123, + "tags", ImmutableList.of("tag1", "tag2"), + "metadata", ImmutableMap.of("key", "value")); + + Map result = tool.runAsync(complexArgs, null).blockingGet(); + + assertThat(result).containsEntry("id", 123); + assertThat(result).containsEntry("tags", ImmutableList.of("tag1", "tag2")); + assertThat(result).containsEntry("metadata", ImmutableMap.of("key", "value")); + } +} From e9df447f1445044552e8710713ab5a76c2ae5093 Mon Sep 17 00:00:00 2001 From: "Michael Vorburger.ch" Date: Mon, 23 Mar 2026 08:42:56 -0700 Subject: [PATCH 07/15] Remove explicit SLF4J binding from city-time-weather ADK tutorial. The `slf4j-simple` dependency and the exclusion of `logback-classic` are removed, allowing the default logging implementation provided by `google-adk-dev` to be used. PiperOrigin-RevId: 888114465 --- tutorials/city-time-weather/pom.xml | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/tutorials/city-time-weather/pom.xml b/tutorials/city-time-weather/pom.xml index aeb110cf6..19ef08a2d 100644 --- a/tutorials/city-time-weather/pom.xml +++ b/tutorials/city-time-weather/pom.xml @@ -36,16 +36,6 @@ com.google.adk google-adk-dev ${project.version} - - - ch.qos.logback - logback-classic - - - - - org.slf4j - slf4j-simple From ce4b642220c785f48711d92657faccaa4eded4f1 Mon Sep 17 00:00:00 2001 From: ddobrin Date: Mon, 23 Mar 2026 13:33:26 -0400 Subject: [PATCH 08/15] Fixes #490 and #1064 ToolConverter issues in the spring-ai module --- contrib/spring-ai/DOCUMENT-GEMINI.md | 86 ------------------ contrib/spring-ai/README.md | 26 +++--- .../adk/models/springai/ToolConverter.java | 88 +++++++++++++------ .../ToolConverterArgumentProcessingTest.java | 84 ++++++++++++++++++ .../models/springai/ToolConverterTest.java | 34 +++++++ 5 files changed, 190 insertions(+), 128 deletions(-) delete mode 100644 contrib/spring-ai/DOCUMENT-GEMINI.md diff --git a/contrib/spring-ai/DOCUMENT-GEMINI.md b/contrib/spring-ai/DOCUMENT-GEMINI.md deleted file mode 100644 index 393562528..000000000 --- a/contrib/spring-ai/DOCUMENT-GEMINI.md +++ /dev/null @@ -1,86 +0,0 @@ -# Documentation for the ADK Spring AI Library - -## 📖 Overview -The `google-adk-spring-ai` library provides an integration layer between the Google Agent Development Kit (ADK) and the Spring AI project. It allows developers to use Spring AI's `ChatModel`, `StreamingChatModel`, and `EmbeddingModel` as `BaseLlm` and `Embedding` implementations within the ADK framework. - -The library handles the conversion between ADK's request/response formats and Spring AI's prompt/chat response formats. It also includes auto-configuration to automatically expose Spring AI models as ADK `SpringAI` and `SpringAIEmbedding` beans in a Spring Boot application. - -## 🛠️ Building -To include this library in your project, use the following Maven coordinates: - -```xml - - com.google.adk - google-adk-spring-ai - 0.3.1-SNAPSHOT - -``` - -You will also need to include a dependency for the specific Spring AI model you want to use, for example: -```xml - - org.springframework.ai - spring-ai-openai - -``` - -## 🚀 Usage -The primary way to use this library is through Spring Boot auto-configuration. By including the `google-adk-spring-ai` dependency and a Spring AI model dependency (e.g., `spring-ai-openai`), the library will automatically create a `SpringAI` bean. This bean can then be injected and used as a `BaseLlm` in the ADK. - -**Example `application.properties`:** -```properties -# OpenAI configuration -spring.ai.openai.api-key=${OPENAI_API_KEY} -spring.ai.openai.chat.options.model=gpt-4o-mini -spring.ai.openai.chat.options.temperature=0.7 - -# ADK Spring AI configuration -adk.spring-ai.model=gpt-4o-mini -adk.spring-ai.validation.enabled=true -``` - -**Example usage in a Spring service:** -```java -import com.google.adk.models.BaseLlm; -import com.google.adk.models.LlmRequest; -import com.google.adk.models.LlmResponse; -import org.springframework.beans.factory.annotation.Autowired; -import org.springframework.stereotype.Service; -import reactor.core.publisher.Mono; - -@Service -public class MyAgentService { - - private final BaseLlm llm; - - @Autowired - public MyAgentService(BaseLlm llm) { - this.llm = llm; - } - - public Mono generateResponse(String prompt) { - LlmRequest request = LlmRequest.builder() - .addText(prompt) - .build(); - return Mono.from(llm.generateContent(request)) - .map(llmResponse -> llmResponse.content().get().parts().get(0).text().get()); - } -} -``` - -## 📚 API Reference -### Key Classes -- **`SpringAI`**: The main class that wraps a Spring AI `ChatModel` and/or `StreamingChatModel` and implements the ADK `BaseLlm` interface. - - **Methods**: - - `generateContent(LlmRequest llmRequest, boolean stream)`: Generates content, either streaming or non-streaming, by calling the underlying Spring AI model. It converts the ADK `LlmRequest` to a Spring AI `Prompt` and the `ChatResponse` back to an ADK `LlmResponse`. - -- **`SpringAIEmbedding`**: Wraps a Spring AI `EmbeddingModel` to be used for generating embeddings within the ADK framework. - -- **`SpringAIAutoConfiguration`**: The Spring Boot auto-configuration class that automatically discovers and configures `SpringAI` and `SpringAIEmbedding` beans based on the `ChatModel`, `StreamingChatModel`, and `EmbeddingModel` beans present in the application context. - -- **`SpringAIProperties`**: A configuration properties class (`@ConfigurationProperties("adk.spring-ai")`) that allows for customization of the Spring AI integration. - - **Properties**: - - `model`: The model name to use. - - `validation.enabled`: Whether to enable configuration validation. - - `validation.fail-fast`: Whether to fail fast on validation errors. - - `observability.enabled`: Whether to enable observability features. diff --git a/contrib/spring-ai/README.md b/contrib/spring-ai/README.md index c45f0e033..0ce7de4fe 100644 --- a/contrib/spring-ai/README.md +++ b/contrib/spring-ai/README.md @@ -18,21 +18,21 @@ To use ADK Java with the Spring AI integration in your application, add the foll com.google.adk google-adk - 0.3.1-SNAPSHOT + 1.0.1-rc.1-SNAPSHOT com.google.adk google-adk-spring-ai - 0.3.1-SNAPSHOT + 1.0.1-rc.1-SNAPSHOT org.springframework.ai spring-ai-bom - 1.1.0-M3 + 2.0.0-M3 pom import @@ -109,14 +109,14 @@ Add the Spring AI provider dependencies for the AI services you want to use: org.springframework.boot spring-boot-starter-parent - 3.2.0 + 4.0.2 17 - 1.1.0-M3 - 0.3.1-SNAPSHOT + 2.0.0-M3 + 1.0.1-rc.1-SNAPSHOT @@ -271,7 +271,7 @@ public class MyAdkSpringAiApplication { .anthropicApi(anthropicApi) .build(); - return new SpringAI(chatModel, "claude-3-5-sonnet-20241022"); + return new SpringAI(chatModel, "claude-sonnet-4-6"); } @Bean @@ -312,7 +312,7 @@ spring: api-key: ${ANTHROPIC_API_KEY} chat: options: - model: claude-3-5-sonnet-20241022 + model: claude-sonnet-4-6 temperature: 0.7 # ADK Spring AI Configuration @@ -365,13 +365,13 @@ The main adapter class that implements `BaseLlm` and wraps Spring AI `ChatModel` **Usage:** ```java // With ChatModel only -SpringAI springAI = new SpringAI(chatModel, "claude-sonnet-4-20250514"); +SpringAI springAI = new SpringAI(chatModel, "claude-sonnet-4-6"); // With both ChatModel and StreamingChatModel -SpringAI springAI = new SpringAI(chatModel, streamingChatModel, "claude-sonnet-4-20250514"); +SpringAI springAI = new SpringAI(chatModel, streamingChatModel, "claude-sonnet-4-6"); // With observability configuration -SpringAI springAI = new SpringAI(chatModel, "claude-sonnet-4-20250514", observabilityConfig); +SpringAI springAI = new SpringAI(chatModel, "claude-sonnet-4-6", observabilityConfig); ``` #### 2. MessageConverter (MessageConverter.java) @@ -533,7 +533,7 @@ The library works with any Spring AI provider: - Features: Chat, streaming, function calling, embeddings 2. **Anthropic** (`spring-ai-anthropic`) - - Models: Claude 3.5 Sonnet, Claude 3 Haiku + - Models: Claude 4.x Sonnet, Claude 4.x Haiku - Features: Chat, streaming, function calling - **Note:** Requires proper function schema registration @@ -563,7 +563,7 @@ The library works with any Spring AI provider: #### Anthropic - **Function Calling:** Requires explicit schema registration using `inputSchema()` method -- **Model Names:** Use full model names like `claude-3-5-sonnet-20241022` +- **Model Names:** Use full model names like `claude-sonnet-4-6` - **API Key:** Requires `ANTHROPIC_API_KEY` environment variable #### OpenAI diff --git a/contrib/spring-ai/src/main/java/com/google/adk/models/springai/ToolConverter.java b/contrib/spring-ai/src/main/java/com/google/adk/models/springai/ToolConverter.java index 95dafadb4..4012ee5d6 100644 --- a/contrib/spring-ai/src/main/java/com/google/adk/models/springai/ToolConverter.java +++ b/contrib/spring-ai/src/main/java/com/google/adk/models/springai/ToolConverter.java @@ -23,6 +23,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Set; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.ai.tool.ToolCallback; @@ -172,6 +173,17 @@ public List convertToSpringAiTools(Map tools) { } catch (Exception e) { logger.error("Error serializing schema to JSON: {}", e.getMessage(), e); } + } else if (declaration.parametersJsonSchema().isPresent()) { + callbackBuilder.inputType(Map.class); + try { + String schemaJson = + new com.fasterxml.jackson.databind.ObjectMapper() + .writeValueAsString(declaration.parametersJsonSchema().get()); + callbackBuilder.inputSchema(schemaJson); + logger.debug("Set input schema JSON from parametersJsonSchema: {}", schemaJson); + } catch (Exception e) { + logger.error("Error serializing parametersJsonSchema to JSON: {}", e.getMessage(), e); + } } toolCallbacks.add(callbackBuilder.build()); @@ -187,45 +199,63 @@ public List convertToSpringAiTools(Map tools) { */ private Map processArguments( Map args, FunctionDeclaration declaration) { - // If the arguments already match the expected format, return as-is if (declaration.parameters().isPresent()) { var schema = declaration.parameters().get(); if (schema.properties().isPresent()) { - var expectedParams = schema.properties().get().keySet(); - - // Check if all expected parameters are present at the top level - boolean allParamsPresent = expectedParams.stream().allMatch(args::containsKey); - if (allParamsPresent) { - return args; + return normalizeArguments(args, schema.properties().get().keySet()); + } + } else if (declaration.parametersJsonSchema().isPresent()) { + try { + @SuppressWarnings("unchecked") + Map schemaMap = + new com.fasterxml.jackson.databind.ObjectMapper() + .convertValue(declaration.parametersJsonSchema().get(), Map.class); + Object propertiesObj = schemaMap.get("properties"); + if (propertiesObj instanceof Map) { + @SuppressWarnings("unchecked") + Set expectedParams = ((Map) propertiesObj).keySet(); + return normalizeArguments(args, expectedParams); } + } catch (Exception e) { + logger.warn( + "Error processing parametersJsonSchema for argument mapping: {}", e.getMessage()); + } + } - // Check if arguments are nested under a single key (common pattern) - if (args.size() == 1) { - var singleValue = args.values().iterator().next(); - if (singleValue instanceof Map) { - @SuppressWarnings("unchecked") - Map nestedArgs = (Map) singleValue; - boolean allNestedParamsPresent = - expectedParams.stream().allMatch(nestedArgs::containsKey); - if (allNestedParamsPresent) { - return nestedArgs; - } - } - } + // If no processing worked, return original args and let ADK handle the error + return args; + } - // Check if we have a single parameter function and got a direct value - if (expectedParams.size() == 1) { - String expectedParam = expectedParams.iterator().next(); - if (args.size() == 1 && !args.containsKey(expectedParam)) { - // Try to map the single value to the expected parameter name - Object singleValue = args.values().iterator().next(); - return Map.of(expectedParam, singleValue); - } + private Map normalizeArguments( + Map args, Set expectedParams) { + // Check if all expected parameters are present at the top level + boolean allParamsPresent = expectedParams.stream().allMatch(args::containsKey); + if (allParamsPresent) { + return args; + } + + // Check if arguments are nested under a single key (common pattern) + if (args.size() == 1) { + var singleValue = args.values().iterator().next(); + if (singleValue instanceof Map) { + @SuppressWarnings("unchecked") + Map nestedArgs = (Map) singleValue; + boolean allNestedParamsPresent = expectedParams.stream().allMatch(nestedArgs::containsKey); + if (allNestedParamsPresent) { + return nestedArgs; } } } - // If no processing worked, return original args and let ADK handle the error + // Check if we have a single parameter function and got a direct value + if (expectedParams.size() == 1) { + String expectedParam = expectedParams.iterator().next(); + if (args.size() == 1 && !args.containsKey(expectedParam)) { + Object singleValue = args.values().iterator().next(); + return Map.of(expectedParam, singleValue); + } + } + return args; } diff --git a/contrib/spring-ai/src/test/java/com/google/adk/models/springai/ToolConverterArgumentProcessingTest.java b/contrib/spring-ai/src/test/java/com/google/adk/models/springai/ToolConverterArgumentProcessingTest.java index 301a145e0..77b988837 100644 --- a/contrib/spring-ai/src/test/java/com/google/adk/models/springai/ToolConverterArgumentProcessingTest.java +++ b/contrib/spring-ai/src/test/java/com/google/adk/models/springai/ToolConverterArgumentProcessingTest.java @@ -115,6 +115,90 @@ private Map invokeProcessArguments( return (Map) method.invoke(converter, args, declaration); } + @Test + void testArgumentProcessingWithParametersJsonSchema_correctFormat() throws Exception { + ToolConverter converter = new ToolConverter(); + Method processArguments = getProcessArgumentsMethod(converter); + + com.google.genai.types.FunctionDeclaration declaration = + com.google.genai.types.FunctionDeclaration.builder() + .name("getWeatherInfo") + .description("Get weather information") + .parametersJsonSchema( + Map.of( + "type", "object", "properties", Map.of("location", Map.of("type", "string")))) + .build(); + + Map correctArgs = Map.of("location", "San Francisco"); + Map processedArgs = + invokeProcessArguments(processArguments, converter, correctArgs, declaration); + + assertThat(processedArgs).isEqualTo(correctArgs); + } + + @Test + void testArgumentProcessingWithParametersJsonSchema_nestedFormat() throws Exception { + ToolConverter converter = new ToolConverter(); + Method processArguments = getProcessArgumentsMethod(converter); + + com.google.genai.types.FunctionDeclaration declaration = + com.google.genai.types.FunctionDeclaration.builder() + .name("getWeatherInfo") + .description("Get weather information") + .parametersJsonSchema( + Map.of( + "type", "object", "properties", Map.of("location", Map.of("type", "string")))) + .build(); + + Map nestedArgs = Map.of("args", Map.of("location", "San Francisco")); + Map processedArgs = + invokeProcessArguments(processArguments, converter, nestedArgs, declaration); + + assertThat(processedArgs).containsEntry("location", "San Francisco"); + } + + @Test + void testArgumentProcessingWithParametersJsonSchema_directValue() throws Exception { + ToolConverter converter = new ToolConverter(); + Method processArguments = getProcessArgumentsMethod(converter); + + com.google.genai.types.FunctionDeclaration declaration = + com.google.genai.types.FunctionDeclaration.builder() + .name("getWeatherInfo") + .description("Get weather information") + .parametersJsonSchema( + Map.of( + "type", "object", "properties", Map.of("location", Map.of("type", "string")))) + .build(); + + Map directValueArgs = Map.of("value", "San Francisco"); + Map processedArgs = + invokeProcessArguments(processArguments, converter, directValueArgs, declaration); + + assertThat(processedArgs).containsEntry("location", "San Francisco"); + } + + @Test + void testArgumentProcessingWithParametersJsonSchema_noMatch() throws Exception { + ToolConverter converter = new ToolConverter(); + Method processArguments = getProcessArgumentsMethod(converter); + + com.google.genai.types.FunctionDeclaration declaration = + com.google.genai.types.FunctionDeclaration.builder() + .name("getWeatherInfo") + .description("Get weather information") + .parametersJsonSchema( + Map.of( + "type", "object", "properties", Map.of("location", Map.of("type", "string")))) + .build(); + + Map wrongArgs = Map.of("city", "San Francisco", "country", "USA"); + Map processedArgs = + invokeProcessArguments(processArguments, converter, wrongArgs, declaration); + + assertThat(processedArgs).isEqualTo(wrongArgs); + } + public static class WeatherTools { public static Map getWeatherInfo(String location) { return Map.of( diff --git a/contrib/spring-ai/src/test/java/com/google/adk/models/springai/ToolConverterTest.java b/contrib/spring-ai/src/test/java/com/google/adk/models/springai/ToolConverterTest.java index 231c8e1fe..1f3044159 100644 --- a/contrib/spring-ai/src/test/java/com/google/adk/models/springai/ToolConverterTest.java +++ b/contrib/spring-ai/src/test/java/com/google/adk/models/springai/ToolConverterTest.java @@ -26,6 +26,7 @@ import java.util.Optional; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.springframework.ai.tool.ToolCallback; class ToolConverterTest { @@ -178,4 +179,37 @@ void testToolMetadata() { assertThat(metadata.getDescription()).isEqualTo("Test description"); assertThat(metadata.getDeclaration()).isEqualTo(function); } + + @Test + void testConvertToSpringAiToolsWithParametersJsonSchema() { + Map jsonSchema = + Map.of( + "type", + "object", + "properties", + Map.of("location", Map.of("type", "string", "description", "City name")), + "required", + List.of("location")); + + FunctionDeclaration function = + FunctionDeclaration.builder() + .name("get_weather") + .description("Get weather for a location") + .parametersJsonSchema(jsonSchema) + .build(); + + BaseTool testTool = + new BaseTool("get_weather", "Get weather for a location") { + @Override + public Optional declaration() { + return Optional.of(function); + } + }; + + Map tools = Map.of("get_weather", testTool); + List toolCallbacks = toolConverter.convertToSpringAiTools(tools); + + assertThat(toolCallbacks).hasSize(1); + assertThat(toolCallbacks.get(0).getToolDefinition().name()).isEqualTo("get_weather"); + } } From 8a7f816ffeb80d58b7e8e2a32d7c70ba8ad89d73 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Tue, 24 Mar 2026 08:00:58 -0700 Subject: [PATCH 09/15] refactor: use mock api answers for tests PiperOrigin-RevId: 888667558 --- .../adk/sessions/VertexAiSessionService.java | 4 ++-- .../google/adk/sessions/MockApiAnswer.java | 11 ++++++++++ .../sessions/VertexAiSessionServiceTest.java | 21 ++----------------- 3 files changed, 15 insertions(+), 21 deletions(-) diff --git a/core/src/main/java/com/google/adk/sessions/VertexAiSessionService.java b/core/src/main/java/com/google/adk/sessions/VertexAiSessionService.java index b62add27a..99e7e3479 100644 --- a/core/src/main/java/com/google/adk/sessions/VertexAiSessionService.java +++ b/core/src/main/java/com/google/adk/sessions/VertexAiSessionService.java @@ -135,7 +135,7 @@ private ListSessionsResponse parseListSessionsResponse( JsonNode listSessionsResponseMap, String appName, String userId) { JsonNode sessionsNode = listSessionsResponseMap.get("sessions"); if (sessionsNode == null || sessionsNode.isNull() || sessionsNode.isEmpty()) { - return ListSessionsResponse.builder().sessions(new ArrayList<>()).build(); + return ListSessionsResponse.builder().build(); } List> apiSessions = objectMapper.convertValue(sessionsNode, new TypeReference>>() {}); @@ -174,7 +174,7 @@ public Single listEvents(String appName, String userId, Stri private ListEventsResponse parseListEventsResponse(JsonNode listEventsResponse) { JsonNode sessionEventsNode = listEventsResponse.get("sessionEvents"); if (sessionEventsNode == null || sessionEventsNode.isEmpty()) { - return ListEventsResponse.builder().events(new ArrayList<>()).build(); + return ListEventsResponse.builder().build(); } return ListEventsResponse.builder() .events( diff --git a/core/src/test/java/com/google/adk/sessions/MockApiAnswer.java b/core/src/test/java/com/google/adk/sessions/MockApiAnswer.java index 111b1dce3..743bcee8d 100644 --- a/core/src/test/java/com/google/adk/sessions/MockApiAnswer.java +++ b/core/src/test/java/com/google/adk/sessions/MockApiAnswer.java @@ -36,14 +36,25 @@ class MockApiAnswer implements Answer { private final Map sessionMap; private final Map eventMap; + private final String rawApiResponse; MockApiAnswer(Map sessionMap, Map eventMap) { this.sessionMap = sessionMap; this.eventMap = eventMap; + this.rawApiResponse = null; + } + + MockApiAnswer(String rawApiResponse) { + this.sessionMap = null; + this.eventMap = null; + this.rawApiResponse = rawApiResponse; } @Override public ApiResponse answer(InvocationOnMock invocation) throws Throwable { + if (rawApiResponse != null) { + return responseWithBody(rawApiResponse); + } String httpMethod = invocation.getArgument(0); String path = invocation.getArgument(1); if (httpMethod.equals("POST") && SESSIONS_REGEX.matcher(path).matches()) { diff --git a/core/src/test/java/com/google/adk/sessions/VertexAiSessionServiceTest.java b/core/src/test/java/com/google/adk/sessions/VertexAiSessionServiceTest.java index 3dab94b46..dd62263d7 100644 --- a/core/src/test/java/com/google/adk/sessions/VertexAiSessionServiceTest.java +++ b/core/src/test/java/com/google/adk/sessions/VertexAiSessionServiceTest.java @@ -25,8 +25,6 @@ import java.util.Optional; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; -import okhttp3.MediaType; -import okhttp3.ResponseBody; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @@ -39,21 +37,6 @@ public class VertexAiSessionServiceTest { private static final ObjectMapper mapper = JsonBaseModel.getMapper(); - private static final MediaType JSON_MEDIA_TYPE = - MediaType.parse("application/json; charset=utf-8"); - - private static ApiResponse apiResponseJson(String json) { - return new ApiResponse() { - @Override - public ResponseBody getResponseBody() { - return ResponseBody.create(JSON_MEDIA_TYPE, json); - } - - @Override - public void close() {} - }; - } - private static final String MOCK_SESSION_STRING_1 = """ { @@ -338,7 +321,7 @@ public void listSessions_empty() { @Test public void listSessions_missingSessionsField_returnsEmpty() { when(mockApiClient.request("GET", "reasoningEngines/123/sessions?filter=user_id=userX", "")) - .thenReturn(apiResponseJson("{}")); + .thenAnswer(new MockApiAnswer("{}")); assertThat(vertexAiSessionService.listSessions("123", "userX").blockingGet().sessions()) .isEmpty(); @@ -347,7 +330,7 @@ public void listSessions_missingSessionsField_returnsEmpty() { @Test public void listSessions_nullSessionsField_returnsEmpty() { when(mockApiClient.request("GET", "reasoningEngines/123/sessions?filter=user_id=userY", "")) - .thenReturn(apiResponseJson("{\"sessions\": null}")); + .thenAnswer(new MockApiAnswer("{\"sessions\": null}")); assertThat(vertexAiSessionService.listSessions("123", "userY").blockingGet().sessions()) .isEmpty(); From 677b6d7452aa28fab42d554d18c150d59ca88eec Mon Sep 17 00:00:00 2001 From: Mateusz Krawiec Date: Wed, 25 Mar 2026 03:24:35 -0700 Subject: [PATCH 10/15] fix: parallel agent execution PiperOrigin-RevId: 889140710 --- .../com/google/adk/agents/ParallelAgent.java | 30 +++++-- .../google/adk/agents/ParallelAgentTest.java | 86 ++++++++++++++++++- 2 files changed, 108 insertions(+), 8 deletions(-) diff --git a/core/src/main/java/com/google/adk/agents/ParallelAgent.java b/core/src/main/java/com/google/adk/agents/ParallelAgent.java index f30d951aa..2593ec13a 100644 --- a/core/src/main/java/com/google/adk/agents/ParallelAgent.java +++ b/core/src/main/java/com/google/adk/agents/ParallelAgent.java @@ -16,11 +16,14 @@ package com.google.adk.agents; import static com.google.common.base.Strings.isNullOrEmpty; -import static com.google.common.collect.ImmutableList.toImmutableList; import com.google.adk.agents.ConfigAgentUtils.ConfigurationException; import com.google.adk.events.Event; +import com.google.errorprone.annotations.CanIgnoreReturnValue; import io.reactivex.rxjava3.core.Flowable; +import io.reactivex.rxjava3.core.Scheduler; +import io.reactivex.rxjava3.schedulers.Schedulers; +import java.util.ArrayList; import java.util.List; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -35,6 +38,7 @@ public class ParallelAgent extends BaseAgent { private static final Logger logger = LoggerFactory.getLogger(ParallelAgent.class); + private final Scheduler scheduler; /** * Constructor for ParallelAgent. @@ -44,24 +48,35 @@ public class ParallelAgent extends BaseAgent { * @param subAgents The list of sub-agents to run in parallel. * @param beforeAgentCallback Optional callback before the agent runs. * @param afterAgentCallback Optional callback after the agent runs. + * @param scheduler The scheduler to use for parallel execution. */ private ParallelAgent( String name, String description, List subAgents, List beforeAgentCallback, - List afterAgentCallback) { + List afterAgentCallback, + Scheduler scheduler) { super(name, description, subAgents, beforeAgentCallback, afterAgentCallback); + this.scheduler = scheduler; } /** Builder for {@link ParallelAgent}. */ public static class Builder extends BaseAgent.Builder { + private Scheduler scheduler = Schedulers.io(); + + @CanIgnoreReturnValue + public Builder scheduler(Scheduler scheduler) { + this.scheduler = scheduler; + return this; + } + @Override public ParallelAgent build() { return new ParallelAgent( - name, description, subAgents, beforeAgentCallback, afterAgentCallback); + name, description, subAgents, beforeAgentCallback, afterAgentCallback, scheduler); } } @@ -129,10 +144,11 @@ protected Flowable runAsyncImpl(InvocationContext invocationContext) { } var updatedInvocationContext = setBranchForCurrentAgent(this, invocationContext); - return Flowable.merge( - currentSubAgents.stream() - .map(subAgent -> subAgent.runAsync(updatedInvocationContext)) - .collect(toImmutableList())); + List> agentFlowables = new ArrayList<>(); + for (BaseAgent subAgent : currentSubAgents) { + agentFlowables.add(subAgent.runAsync(updatedInvocationContext).subscribeOn(scheduler)); + } + return Flowable.merge(agentFlowables); } /** diff --git a/core/src/test/java/com/google/adk/agents/ParallelAgentTest.java b/core/src/test/java/com/google/adk/agents/ParallelAgentTest.java index a6afb5793..e51240c45 100644 --- a/core/src/test/java/com/google/adk/agents/ParallelAgentTest.java +++ b/core/src/test/java/com/google/adk/agents/ParallelAgentTest.java @@ -25,7 +25,10 @@ import com.google.genai.types.Content; import com.google.genai.types.Part; import io.reactivex.rxjava3.core.Flowable; +import io.reactivex.rxjava3.core.Scheduler; import io.reactivex.rxjava3.schedulers.Schedulers; +import io.reactivex.rxjava3.schedulers.TestScheduler; +import io.reactivex.rxjava3.subscribers.TestSubscriber; import java.util.List; import org.junit.Test; import org.junit.runner.RunWith; @@ -36,10 +39,16 @@ public final class ParallelAgentTest { static class TestingAgent extends BaseAgent { private final long delayMillis; + private final Scheduler scheduler; private TestingAgent(String name, String description, long delayMillis) { + this(name, description, delayMillis, Schedulers.computation()); + } + + private TestingAgent(String name, String description, long delayMillis, Scheduler scheduler) { super(name, description, ImmutableList.of(), null, null); this.delayMillis = delayMillis; + this.scheduler = scheduler; } @Override @@ -55,7 +64,7 @@ protected Flowable runAsyncImpl(InvocationContext invocationContext) { .build()); if (delayMillis > 0) { - return event.delay(delayMillis, MILLISECONDS, Schedulers.computation()); + return event.delay(delayMillis, MILLISECONDS, scheduler); } return event; } @@ -110,4 +119,79 @@ public void runAsync_noSubAgents_returnsEmptyFlowable() { assertThat(events).isEmpty(); } + + static class BlockingAgent extends BaseAgent { + private final long sleepMillis; + + private BlockingAgent(String name, long sleepMillis) { + super(name, "Blocking Agent", ImmutableList.of(), null, null); + this.sleepMillis = sleepMillis; + } + + @Override + protected Flowable runAsyncImpl(InvocationContext invocationContext) { + return Flowable.fromCallable( + () -> { + Thread.sleep(sleepMillis); + return Event.builder() + .author(name()) + .branch(invocationContext.branch().orElse(null)) + .invocationId(invocationContext.invocationId()) + .content(Content.fromParts(Part.fromText("Done"))) + .build(); + }); + } + + @Override + protected Flowable runLiveImpl(InvocationContext invocationContext) { + throw new UnsupportedOperationException("Not implemented"); + } + } + + @Test + public void runAsync_blockingSubAgents_shouldExecuteInParallel() { + long sleepTime = 1000; + BlockingAgent agent1 = new BlockingAgent("agent1", sleepTime); + BlockingAgent agent2 = new BlockingAgent("agent2", sleepTime); + + ParallelAgent parallelAgent = + ParallelAgent.builder().name("parallel_agent").subAgents(agent1, agent2).build(); + + InvocationContext invocationContext = createInvocationContext(parallelAgent); + + long startTime = System.currentTimeMillis(); + List events = parallelAgent.runAsync(invocationContext).toList().blockingGet(); + long duration = System.currentTimeMillis() - startTime; + + assertThat(events).hasSize(2); + // If parallel, duration should be less than 1.5 * sleepTime (1500ms). + assertThat(duration).isAtLeast(sleepTime); + assertThat(duration).isLessThan((long) (1.5 * sleepTime)); + } + + @Test + public void runAsync_withTestScheduler_usesVirtualTime() { + TestScheduler testScheduler = new TestScheduler(); + long delayMillis = 1000; + TestingAgent agent = + new TestingAgent("delayed_agent", "Delayed Agent", delayMillis, testScheduler); + + ParallelAgent parallelAgent = + ParallelAgent.builder() + .name("parallel_agent") + .subAgents(agent) + .scheduler(testScheduler) + .build(); + + InvocationContext invocationContext = createInvocationContext(parallelAgent); + + TestSubscriber testSubscriber = parallelAgent.runAsync(invocationContext).test(); + + testScheduler.advanceTimeBy(delayMillis - 100, MILLISECONDS); + testSubscriber.assertNoValues(); + testSubscriber.assertNotComplete(); + testScheduler.advanceTimeBy(200, MILLISECONDS); + testSubscriber.assertValueCount(1); + testSubscriber.assertComplete(); + } } From 5a2abbfe6f9e4e1ebdd5b918e34fcdb144603b5a Mon Sep 17 00:00:00 2001 From: Mateusz Krawiec Date: Wed, 25 Mar 2026 03:25:54 -0700 Subject: [PATCH 11/15] fix: resolve MCP tool parsing errors in Claude integration The Claude model integration parsing logic failed when processing MCP tool responses because it only extracted output from the legacy `result` field. Extended extraction logic to: - Support native MCP `content` arrays. - Support legacy `result` structures natively. - Fallback to generic JSON serialization of the entire map. Additionally, updated AbstractMcpTool.wrapCallResult() format to match Python ADK. PiperOrigin-RevId: 889141233 --- .../java/com/google/adk/models/Claude.java | 51 ++++++++-- .../google/adk/tools/mcp/AbstractMcpTool.java | 52 +--------- .../com/google/adk/models/ClaudeTest.java | 97 +++++++++++++++++++ .../adk/tools/mcp/AbstractMcpToolTest.java | 62 ++++++++++++ 4 files changed, 203 insertions(+), 59 deletions(-) create mode 100644 core/src/test/java/com/google/adk/models/ClaudeTest.java create mode 100644 core/src/test/java/com/google/adk/tools/mcp/AbstractMcpToolTest.java diff --git a/core/src/main/java/com/google/adk/models/Claude.java b/core/src/main/java/com/google/adk/models/Claude.java index ebb786e35..01feda1d4 100644 --- a/core/src/main/java/com/google/adk/models/Claude.java +++ b/core/src/main/java/com/google/adk/models/Claude.java @@ -31,8 +31,7 @@ import com.anthropic.models.messages.ToolUnion; import com.anthropic.models.messages.ToolUseBlockParam; import com.fasterxml.jackson.core.type.TypeReference; -import com.fasterxml.jackson.databind.ObjectMapper; -import com.fasterxml.jackson.datatype.jdk8.Jdk8Module; +import com.google.adk.JsonBaseModel; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.genai.types.Content; @@ -170,9 +169,22 @@ private ContentBlockParam partToAnthropicMessageBlock(Part part) { .build()); } else if (part.functionResponse().isPresent()) { String content = ""; - if (part.functionResponse().get().response().isPresent() - && part.functionResponse().get().response().get().getOrDefault("result", null) != null) { - content = part.functionResponse().get().response().get().get("result").toString(); + if (part.functionResponse().get().response().isPresent()) { + Map responseData = part.functionResponse().get().response().get(); + + Object contentObj = responseData.get("content"); + Object resultObj = responseData.get("result"); + + if (contentObj instanceof List list && !list.isEmpty()) { + // Native MCP format: list of content blocks + content = extractMcpContentBlocks(list); + } else if (resultObj != null) { + // ADK tool result object + content = resultObj instanceof String s ? s : serializeToJson(resultObj); + } else if (!responseData.isEmpty()) { + // Fallback: arbitrary JSON structure + content = serializeToJson(responseData); + } } return ContentBlockParam.ofToolResult( ToolResultBlockParam.builder() @@ -184,6 +196,30 @@ private ContentBlockParam partToAnthropicMessageBlock(Part part) { throw new UnsupportedOperationException("Not supported yet."); } + private String extractMcpContentBlocks(List list) { + List textBlocks = new ArrayList<>(); + for (Object item : list) { + if (item instanceof Map m && "text".equals(m.get("type"))) { + Object textObj = m.get("text"); + textBlocks.add(textObj != null ? String.valueOf(textObj) : ""); + } else if (item instanceof String s) { + textBlocks.add(s); + } else { + textBlocks.add(serializeToJson(item)); + } + } + return String.join("\n", textBlocks); + } + + private String serializeToJson(Object obj) { + try { + return JsonBaseModel.getMapper().writeValueAsString(obj); + } catch (Exception e) { + logger.warn("Failed to serialize object to JSON", e); + return String.valueOf(obj); + } + } + private void updateTypeString(Map valueDict) { if (valueDict == null) { return; @@ -221,10 +257,9 @@ private Tool functionDeclarationToAnthropicTool(FunctionDeclaration functionDecl .get() .forEach( (key, schema) -> { - ObjectMapper objectMapper = new ObjectMapper(); - objectMapper.registerModule(new Jdk8Module()); Map schemaMap = - objectMapper.convertValue(schema, new TypeReference>() {}); + JsonBaseModel.getMapper() + .convertValue(schema, new TypeReference>() {}); updateTypeString(schemaMap); properties.put(key, schemaMap); }); diff --git a/core/src/main/java/com/google/adk/tools/mcp/AbstractMcpTool.java b/core/src/main/java/com/google/adk/tools/mcp/AbstractMcpTool.java index d9c28e501..3b0c3d70a 100644 --- a/core/src/main/java/com/google/adk/tools/mcp/AbstractMcpTool.java +++ b/core/src/main/java/com/google/adk/tools/mcp/AbstractMcpTool.java @@ -16,7 +16,6 @@ package com.google.adk.tools.mcp; -import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; import com.google.adk.tools.BaseTool; @@ -24,13 +23,9 @@ import com.google.common.collect.ImmutableMap; import com.google.genai.types.FunctionDeclaration; import io.modelcontextprotocol.spec.McpSchema.CallToolResult; -import io.modelcontextprotocol.spec.McpSchema.Content; import io.modelcontextprotocol.spec.McpSchema.JsonSchema; -import io.modelcontextprotocol.spec.McpSchema.TextContent; import io.modelcontextprotocol.spec.McpSchema.Tool; import io.modelcontextprotocol.spec.McpSchema.ToolAnnotations; -import java.util.ArrayList; -import java.util.List; import java.util.Map; import java.util.Optional; @@ -116,51 +111,6 @@ protected static Map wrapCallResult( return ImmutableMap.of("error", "MCP framework error: CallToolResult was null"); } - List contents = callResult.content(); - Boolean isToolError = callResult.isError(); - - if (isToolError != null && isToolError) { - String errorMessage = "Tool execution failed."; - if (contents != null - && !contents.isEmpty() - && contents.get(0) instanceof TextContent textContent) { - if (textContent.text() != null && !textContent.text().isEmpty()) { - errorMessage += " Details: " + textContent.text(); - } - } - return ImmutableMap.of("error", errorMessage); - } - - if (contents == null || contents.isEmpty()) { - return ImmutableMap.of(); - } - - List textOutputs = new ArrayList<>(); - for (Content content : contents) { - if (content instanceof TextContent textContent) { - if (textContent.text() != null) { - textOutputs.add(textContent.text()); - } - } - } - - if (textOutputs.isEmpty()) { - return ImmutableMap.of( - "error", - "Tool '" + mcpToolName + "' returned content that is not TextContent.", - "content_details", - contents.toString()); - } - - List> resultMaps = new ArrayList<>(); - for (String textOutput : textOutputs) { - try { - resultMaps.add( - objectMapper.readValue(textOutput, new TypeReference>() {})); - } catch (JsonProcessingException e) { - resultMaps.add(ImmutableMap.of("text", textOutput)); - } - } - return ImmutableMap.of("text_output", resultMaps); + return objectMapper.convertValue(callResult, new TypeReference>() {}); } } diff --git a/core/src/test/java/com/google/adk/models/ClaudeTest.java b/core/src/test/java/com/google/adk/models/ClaudeTest.java new file mode 100644 index 000000000..677d40627 --- /dev/null +++ b/core/src/test/java/com/google/adk/models/ClaudeTest.java @@ -0,0 +1,97 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.models; + +import static com.google.common.truth.Truth.assertThat; + +import com.anthropic.client.AnthropicClient; +import com.anthropic.models.messages.ContentBlockParam; +import com.anthropic.models.messages.ToolResultBlockParam; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.genai.types.FunctionResponse; +import com.google.genai.types.Part; +import java.lang.reflect.Method; +import java.util.Map; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.Mockito; + +@RunWith(JUnit4.class) +public final class ClaudeTest { + + private Claude claude; + private Method partToAnthropicMessageBlockMethod; + + @Before + public void setUp() throws Exception { + AnthropicClient mockClient = Mockito.mock(AnthropicClient.class); + claude = new Claude("claude-3-opus", mockClient); + + // Access private method for testing the extraction logic + partToAnthropicMessageBlockMethod = + Claude.class.getDeclaredMethod("partToAnthropicMessageBlock", Part.class); + partToAnthropicMessageBlockMethod.setAccessible(true); + } + + @Test + public void testPartToAnthropicMessageBlock_mcpNativeFormat() throws Exception { + Map responseData = + ImmutableMap.of( + "content", + ImmutableList.of(ImmutableMap.of("type", "text", "text", "Extracted native MCP text"))); + FunctionResponse funcParam = + FunctionResponse.builder().name("test_tool").response(responseData).id("call_123").build(); + Part part = Part.builder().functionResponse(funcParam).build(); + + ContentBlockParam result = + (ContentBlockParam) partToAnthropicMessageBlockMethod.invoke(claude, part); + + ToolResultBlockParam toolResult = result.asToolResult(); + assertThat(toolResult.content().get().asString()).isEqualTo("Extracted native MCP text"); + } + + @Test + public void testPartToAnthropicMessageBlock_legacyResultKey() throws Exception { + Map responseData = ImmutableMap.of("result", "Legacy result text"); + FunctionResponse funcParam = + FunctionResponse.builder().name("test_tool").response(responseData).id("call_123").build(); + Part part = Part.builder().functionResponse(funcParam).build(); + + ContentBlockParam result = + (ContentBlockParam) partToAnthropicMessageBlockMethod.invoke(claude, part); + + ToolResultBlockParam toolResult = result.asToolResult(); + assertThat(toolResult.content().get().asString()).isEqualTo("Legacy result text"); + } + + @Test + public void testPartToAnthropicMessageBlock_jsonFallback() throws Exception { + Map responseData = ImmutableMap.of("custom_key", "custom_value"); + FunctionResponse funcParam = + FunctionResponse.builder().name("test_tool").response(responseData).id("call_123").build(); + Part part = Part.builder().functionResponse(funcParam).build(); + + ContentBlockParam result = + (ContentBlockParam) partToAnthropicMessageBlockMethod.invoke(claude, part); + + ToolResultBlockParam toolResult = result.asToolResult(); + assertThat(toolResult.content().get().asString()).contains("\"custom_key\":\"custom_value\""); + } +} diff --git a/core/src/test/java/com/google/adk/tools/mcp/AbstractMcpToolTest.java b/core/src/test/java/com/google/adk/tools/mcp/AbstractMcpToolTest.java new file mode 100644 index 000000000..e8d9ea631 --- /dev/null +++ b/core/src/test/java/com/google/adk/tools/mcp/AbstractMcpToolTest.java @@ -0,0 +1,62 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.tools.mcp; + +import static com.google.common.truth.Truth.assertThat; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.common.collect.ImmutableList; +import io.modelcontextprotocol.spec.McpSchema.CallToolResult; +import io.modelcontextprotocol.spec.McpSchema.TextContent; +import java.util.List; +import java.util.Map; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public final class AbstractMcpToolTest { + + private ObjectMapper objectMapper; + + @Before + public void setUp() { + objectMapper = new ObjectMapper(); + } + + @Test + public void testWrapCallResult_success() { + CallToolResult result = + CallToolResult.builder() + .content(ImmutableList.of(new TextContent("success"))) + .isError(false) + .build(); + + Map map = AbstractMcpTool.wrapCallResult(objectMapper, "my_tool", result); + + assertThat(map).containsKey("content"); + List content = (List) map.get("content"); + assertThat(content).hasSize(1); + + Map contentItem = (Map) content.get(0); + assertThat(contentItem).containsEntry("type", "text"); + assertThat(contentItem).containsEntry("text", "success"); + + assertThat(map).containsEntry("isError", false); + } +} From 6a5a55eb3e531c6f8a7083712308c4800f680ca5 Mon Sep 17 00:00:00 2001 From: "Ganesh, Mohan" Date: Thu, 26 Feb 2026 18:15:32 -0500 Subject: [PATCH 12/15] fix(firestore): Remove hardcoded dependency version --- contrib/firestore-session-service/pom.xml | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/contrib/firestore-session-service/pom.xml b/contrib/firestore-session-service/pom.xml index ed1ecd09b..34b577984 100644 --- a/contrib/firestore-session-service/pom.xml +++ b/contrib/firestore-session-service/pom.xml @@ -14,7 +14,9 @@ See the License for the specific language governing permissions and limitations under the License. --> - + 4.0.0 @@ -49,7 +51,6 @@ com.google.cloud google-cloud-firestore - 3.30.3 com.google.truth From 8ab7f072cdaa363e07b7a786044376c021c4c009 Mon Sep 17 00:00:00 2001 From: pkarmarkar Date: Tue, 6 Jan 2026 13:24:22 -0800 Subject: [PATCH 13/15] fix: add media/image support in Spring AI MessageConverter Previously, MessageConverter only transferred text content from ADK to Spring AI, ignoring image and media attachments. This caused vision model requests to fail even though Spring AI's underlying models (like GPT-4o) support image inputs. Updated MessageConverter to properly handle image/media parts by constructing UserMessage with Media attachments. Fixes #705 --- .../adk/models/springai/MessageConverter.java | 12 +- .../models/springai/MessageConverterTest.java | 181 +++++++++++++++++- 2 files changed, 183 insertions(+), 10 deletions(-) diff --git a/contrib/spring-ai/src/main/java/com/google/adk/models/springai/MessageConverter.java b/contrib/spring-ai/src/main/java/com/google/adk/models/springai/MessageConverter.java index 036a898bb..3983b08a5 100644 --- a/contrib/spring-ai/src/main/java/com/google/adk/models/springai/MessageConverter.java +++ b/contrib/spring-ai/src/main/java/com/google/adk/models/springai/MessageConverter.java @@ -221,8 +221,7 @@ private List handleUserContent(Content content) { } catch (Exception e) { // Log warning but continue processing other parts // In production, consider proper logging framework - System.err.println( - "Warning: Failed to parse media mime type: " + blob.mimeType().get()); + System.err.println("Warning: Failed to process media part: " + e.getMessage()); } } } else if (part.fileData().isPresent()) { @@ -235,19 +234,14 @@ private List handleUserContent(Content content) { URI uri = URI.create(fileData.fileUri().get()); mediaList.add(new Media(mimeType, uri)); } catch (Exception e) { - System.err.println( - "Warning: Failed to parse media mime type: " + fileData.mimeType().get()); + System.err.println("Warning: Failed to process media part: " + e.getMessage()); } } } } List messages = new ArrayList<>(); - // Create UserMessage with text - // TODO: Media attachments support - UserMessage constructors with media are private in Spring - // AI 1.1.0 - // For now, only text content is supported - messages.add(new UserMessage(textBuilder.toString())); + messages.add(UserMessage.builder().text(textBuilder.toString()).media(mediaList).build()); messages.addAll(toolResponseMessages); return messages; diff --git a/contrib/spring-ai/src/test/java/com/google/adk/models/springai/MessageConverterTest.java b/contrib/spring-ai/src/test/java/com/google/adk/models/springai/MessageConverterTest.java index a57644b5d..b861a71f2 100644 --- a/contrib/spring-ai/src/test/java/com/google/adk/models/springai/MessageConverterTest.java +++ b/contrib/spring-ai/src/test/java/com/google/adk/models/springai/MessageConverterTest.java @@ -60,7 +60,9 @@ void testToLlmPromptWithUserMessage() { assertThat(prompt.getInstructions()).hasSize(1); Message message = prompt.getInstructions().get(0); assertThat(message).isInstanceOf(UserMessage.class); - assertThat(((UserMessage) message).getText()).isEqualTo("Hello, how are you?"); + UserMessage userMessage = (UserMessage) message; + assertThat(userMessage.getText()).isEqualTo("Hello, how are you?"); + assertThat(userMessage.getMedia()).isEmpty(); } @Test @@ -444,4 +446,181 @@ void testCombineMultipleSystemMessagesForGeminiCompatibility() { assertThat(secondMessage).isInstanceOf(UserMessage.class); assertThat(((UserMessage) secondMessage).getText()).isEqualTo("Hello world"); } + + @Test + void testUserMessageWithInlineMediaData() { + // Test conversion of ADK Content with inline media (image bytes) to Spring AI UserMessage + byte[] imageData = "fake-image-data".getBytes(); + String mimeType = "image/png"; + + Content userContent = + Content.builder() + .role("user") + .parts( + List.of( + Part.fromText("What's in this image?"), + Part.builder() + .inlineData( + com.google.genai.types.Blob.builder() + .mimeType(mimeType) + .data(imageData) + .build()) + .build())) + .build(); + + LlmRequest request = LlmRequest.builder().contents(List.of(userContent)).build(); + + Prompt prompt = messageConverter.toLlmPrompt(request); + + assertThat(prompt.getInstructions()).hasSize(1); + Message message = prompt.getInstructions().get(0); + assertThat(message).isInstanceOf(UserMessage.class); + + UserMessage userMessage = (UserMessage) message; + assertThat(userMessage.getText()).isEqualTo("What's in this image?"); + assertThat(userMessage.getMedia()).hasSize(1); + org.springframework.ai.content.Media media = userMessage.getMedia().get(0); + assertThat(media.getMimeType().toString()).isEqualTo(mimeType); + assertThat(media.getData()).isInstanceOf(byte[].class); + byte[] actualData = (byte[]) media.getData(); + assertThat(actualData).isEqualTo(imageData); + } + + @Test + void testUserMessageWithFileMediaData() { + // Test conversion of ADK Content with file-based media (URI) to Spring AI UserMessage + String fileUri = "gs://bucket/image.jpg"; + String mimeType = "image/jpeg"; + + Content userContent = + Content.builder() + .role("user") + .parts( + List.of( + Part.fromText("Analyze this image"), + Part.builder() + .fileData( + com.google.genai.types.FileData.builder() + .mimeType(mimeType) + .fileUri(fileUri) + .build()) + .build())) + .build(); + + LlmRequest request = LlmRequest.builder().contents(List.of(userContent)).build(); + + Prompt prompt = messageConverter.toLlmPrompt(request); + + assertThat(prompt.getInstructions()).hasSize(1); + Message message = prompt.getInstructions().get(0); + assertThat(message).isInstanceOf(UserMessage.class); + + UserMessage userMessage = (UserMessage) message; + assertThat(userMessage.getText()).isEqualTo("Analyze this image"); + assertThat(userMessage.getMedia()).hasSize(1); + org.springframework.ai.content.Media media = userMessage.getMedia().get(0); + assertThat(media.getMimeType().toString()).isEqualTo(mimeType); + assertThat(media.getData()).isInstanceOf(String.class); + String actualUri = (String) media.getData(); + assertThat(actualUri).isEqualTo(fileUri); + } + + @Test + void testUserMessageWithMultipleMediaAttachments() { + // Test conversion with multiple media attachments + byte[] image1 = "image1-data".getBytes(); + byte[] image2 = "image2-data".getBytes(); + + Content userContent = + Content.builder() + .role("user") + .parts( + List.of( + Part.fromText("Compare these images"), + Part.builder() + .inlineData( + com.google.genai.types.Blob.builder() + .mimeType("image/png") + .data(image1) + .build()) + .build(), + Part.builder() + .inlineData( + com.google.genai.types.Blob.builder() + .mimeType("image/jpeg") + .data(image2) + .build()) + .build())) + .build(); + + LlmRequest request = LlmRequest.builder().contents(List.of(userContent)).build(); + + Prompt prompt = messageConverter.toLlmPrompt(request); + + assertThat(prompt.getInstructions()).hasSize(1); + UserMessage userMessage = (UserMessage) prompt.getInstructions().get(0); + assertThat(userMessage.getText()).isEqualTo("Compare these images"); + assertThat(userMessage.getMedia()).hasSize(2); + } + + @Test + void testUserMessageWithInvalidMimeTypeGracefullySkipsMediaPart() { + // Test that an invalid MIME type string causes the media part to be skipped gracefully + byte[] imageData = "fake-image-data".getBytes(); + + Content userContent = + Content.builder() + .role("user") + .parts( + List.of( + Part.fromText("What's in this image?"), + Part.builder() + .inlineData( + com.google.genai.types.Blob.builder() + .mimeType("invalid/mime/type!!!") // invalid MIME type + .data(imageData) + .build()) + .build())) + .build(); + + LlmRequest request = LlmRequest.builder().contents(List.of(userContent)).build(); + + // Should not throw — invalid MIME type is silently skipped + Prompt prompt = messageConverter.toLlmPrompt(request); + + assertThat(prompt.getInstructions()).hasSize(1); + UserMessage userMessage = (UserMessage) prompt.getInstructions().get(0); + assertThat(userMessage.getText()).isEqualTo("What's in this image?"); + // Media part is skipped due to invalid MIME type + assertThat(userMessage.getMedia()).isEmpty(); + } + + @Test + void testUserMessageWithMediaOnly() { + // Test conversion with media but no text + byte[] imageData = "image-only".getBytes(); + + Content userContent = + Content.builder() + .role("user") + .parts( + List.of( + Part.builder() + .inlineData( + com.google.genai.types.Blob.builder() + .mimeType("image/png") + .data(imageData) + .build()) + .build())) + .build(); + + LlmRequest request = LlmRequest.builder().contents(List.of(userContent)).build(); + + Prompt prompt = messageConverter.toLlmPrompt(request); + + assertThat(prompt.getInstructions()).hasSize(1); + UserMessage userMessage = (UserMessage) prompt.getInstructions().get(0); + assertThat(userMessage.getText()).isEmpty(); + assertThat(userMessage.getMedia()).hasSize(1); + } } From 3650c7f547bd372681f41deba3c78a94a5c3cf94 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Wed, 25 Mar 2026 07:48:14 -0700 Subject: [PATCH 14/15] chore: update google-genai version to 1.44.0 PiperOrigin-RevId: 889240232 --- pom.xml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pom.xml b/pom.xml index 40332472f..bdd2f33b1 100644 --- a/pom.xml +++ b/pom.xml @@ -51,7 +51,7 @@ 1.51.0 0.17.2 2.47.0 - 1.43.0 + 1.44.0 4.33.5 5.11.4 5.20.0 From 84dff10a3ee7f47e30a40409e56b5e9365c69815 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Wed, 25 Mar 2026 08:02:43 -0700 Subject: [PATCH 15/15] fix: Fixing tracing for function calls Fixing when the execute tool happens in the graph. Refactor and simplify the tracing logic for function and tool calls within the ADK. The primary goal is to consolidate multiple tracing events into more cohesive operations, specifically merging "tool_call" and "tool_response" into a single "execute_tool" operation. ### Key Changes: * **Consolidated Tool Tracing:** Replaced the separate `traceToolCall` and `traceToolResponse` methods with a unified `traceToolExecution` in `Tracing.java`. This reduces span noise by representing a tool's lifecycle as a single "execute_tool" operation containing both arguments and results (or errors). * **Standardized Operation Names:** Introduced constants for core Gen AI operations: `invoke_agent`, `execute_tool`, `send_data`, and `call_llm`. * **Improved Error Tracing:** `traceToolExecution` and `traceCallLlm` now explicitly accept an optional `Exception`, allowing them to automatically set the span status to error and record the exception. * **Refactored Tracing API:** * `traceSendData` and other methods now require an explicit `Span` argument, moving away from implicit context lookups where appropriate. * Added `traceMergedToolCalls` to specifically handle the telemetry for parallel tool executions. * **Flow Logic Cleanup:** Simplified `Functions.java` and `BaseLlmFlow.java` by removing redundant context passing and aligning with the new consolidated tracing methods. * **Test Suite Updates:** Significantly updated `ContextPropagationTest.java` to reflect the new tracing model. Several manual hierarchy tests were removed in favor of testing the consolidated `execute_tool` logic and updated attributes. PiperOrigin-RevId: 889246953 --- .../adk/flows/llmflows/BaseLlmFlow.java | 17 +- .../google/adk/flows/llmflows/Functions.java | 42 +-- .../com/google/adk/telemetry/Tracing.java | 184 ++++++----- .../com/google/adk/agents/LlmAgentTest.java | 10 +- .../com/google/adk/runner/RunnerTest.java | 30 +- .../adk/telemetry/ContextPropagationTest.java | 303 +++--------------- 6 files changed, 208 insertions(+), 378 deletions(-) diff --git a/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java b/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java index d4fe1b838..aa62b9f31 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java @@ -233,7 +233,7 @@ private Flowable callLlm( callLlmContext) .doOnSubscribe( s -> - Tracing.traceCallLlm( + traceCallLlm( span, context, eventForCallbackUsage.id(), @@ -520,6 +520,7 @@ public Flowable runLive(InvocationContext invocationContext) { .doOnComplete( () -> Tracing.traceSendData( + Span.current(), invocationContext, eventIdForSendData, llmRequestAfterPreprocess.contents())) @@ -529,6 +530,7 @@ public Flowable runLive(InvocationContext invocationContext) { span.setStatus(StatusCode.ERROR, error.getMessage()); span.recordException(error); Tracing.traceSendData( + Span.current(), invocationContext, eventIdForSendData, llmRequestAfterPreprocess.contents()); @@ -706,6 +708,19 @@ private Flowable buildPostprocessingEvents( return processorEvents.concatWith(Flowable.just(modelResponseEvent)).concatWith(functionEvents); } + /** + * Traces an LLM call without an associated exception. This is an overload for {@link + * Tracing#traceCallLlm} for successful calls. + */ + private void traceCallLlm( + Span span, + InvocationContext context, + String eventId, + LlmRequest llmRequest, + LlmResponse llmResponse) { + Tracing.traceCallLlm(span, context, eventId, llmRequest, llmResponse, null); + } + private Event buildModelResponseEvent( Event baseEventForLlmResponse, LlmRequest llmRequest, LlmResponse llmResponse) { Event.Builder eventBuilder = diff --git a/core/src/main/java/com/google/adk/flows/llmflows/Functions.java b/core/src/main/java/com/google/adk/flows/llmflows/Functions.java index 84a8141ea..0b0e5b4d5 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/Functions.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/Functions.java @@ -178,8 +178,12 @@ public static Maybe handleFunctionCalls( if (events.size() > 1) { return Maybe.just(mergedEvent) - .doOnSuccess(event -> Tracing.traceToolResponse(event.id(), event)) - .compose(Tracing.trace("tool_response").setParent(parentContext)); + .compose( + Tracing.trace("execute_tool (merged)") + .setParent(parentContext) + .onSuccess( + (span, event) -> + Tracing.traceMergedToolCalls(span, event.id(), event))); } return Maybe.just(mergedEvent); }); @@ -269,10 +273,8 @@ private static Function> getFunctionCallMapper( tool, toolContext, functionCall, - functionArgs, - parentContext) - : callTool( - tool, functionArgs, toolContext, parentContext)) + functionArgs) + : callTool(tool, functionArgs, toolContext)) .compose(Tracing.withContext(parentContext))); return postProcessFunctionResult( @@ -296,8 +298,7 @@ private static Maybe> processFunctionLive( BaseTool tool, ToolContext toolContext, FunctionCall functionCall, - Map args, - Context parentContext) { + Map args) { // Case 1: Handle a call to stopStreaming if (functionCall.name().get().equals("stopStreaming") && args.containsKey("functionName")) { String functionNameToStop = (String) args.get("functionName"); @@ -365,7 +366,7 @@ private static Maybe> processFunctionLive( } // Case 3: Fallback for regular, non-streaming tools - return callTool(tool, args, toolContext, parentContext); + return callTool(tool, args, toolContext); } public static Set getLongRunningFunctionCalls( @@ -426,12 +427,22 @@ private static Maybe postProcessFunctionResult( Event event = buildResponseEvent( tool, finalFunctionResult, toolContext, invocationContext); - Tracing.traceToolResponse(event.id(), event); return Maybe.just(event); }); }) .compose( - Tracing.trace("tool_response [" + tool.name() + "]").setParent(parentContext)); + Tracing.trace("execute_tool [" + tool.name() + "]") + .setParent(parentContext) + .onSuccess( + (span, event) -> + Tracing.traceToolExecution( + span, + tool.name(), + tool.description(), + tool.getClass().getSimpleName(), + functionArgs, + event, + null))); } private static Optional mergeParallelFunctionResponseEvents( @@ -579,17 +590,10 @@ private static Maybe> maybeInvokeAfterToolCall( } private static Maybe> callTool( - BaseTool tool, Map args, ToolContext toolContext, Context parentContext) { + BaseTool tool, Map args, ToolContext toolContext) { return tool.runAsync(args, toolContext) .toMaybe() - .doOnSubscribe( - d -> - Tracing.traceToolCall( - tool.name(), tool.description(), tool.getClass().getSimpleName(), args)) .doOnError(t -> Span.current().recordException(t)) - .compose( - Tracing.>trace("tool_call [" + tool.name() + "]") - .setParent(parentContext)) .onErrorResumeNext( e -> Maybe.error( diff --git a/core/src/main/java/com/google/adk/telemetry/Tracing.java b/core/src/main/java/com/google/adk/telemetry/Tracing.java index 215e317e1..589215073 100644 --- a/core/src/main/java/com/google/adk/telemetry/Tracing.java +++ b/core/src/main/java/com/google/adk/telemetry/Tracing.java @@ -33,6 +33,7 @@ import io.opentelemetry.api.GlobalOpenTelemetry; import io.opentelemetry.api.common.AttributeKey; import io.opentelemetry.api.trace.Span; +import io.opentelemetry.api.trace.StatusCode; import io.opentelemetry.api.trace.Tracer; import io.opentelemetry.context.Context; import io.opentelemetry.context.Scope; @@ -61,6 +62,7 @@ import java.util.function.BiConsumer; import java.util.function.Consumer; import java.util.function.Supplier; +import org.jspecify.annotations.Nullable; import org.reactivestreams.Publisher; import org.reactivestreams.Subscriber; import org.reactivestreams.Subscription; @@ -77,6 +79,11 @@ public class Tracing { private static final Logger log = LoggerFactory.getLogger(Tracing.class); + private static final String INVOKE_AGENT_OPERATION = "invoke_agent"; + private static final String EXECUTE_TOOL_OPERATION = "execute_tool"; + private static final String SEND_DATA_OPERATION = "send_data"; + private static final String CALL_LLM_OPERATION = "call_llm"; + private static final AttributeKey> GEN_AI_RESPONSE_FINISH_REASONS = AttributeKey.stringArrayKey("gen_ai.response.finish_reasons"); @@ -134,15 +141,6 @@ public class Tracing { private Tracing() {} - private static void traceWithSpan(String methodName, Consumer traceAction) { - Span span = Span.current(); - if (!span.getSpanContext().isValid()) { - log.trace("{}: No valid span in current context.", methodName); - return; - } - traceAction.accept(span); - } - private static void setInvocationAttributes( Span span, InvocationContext invocationContext, String eventId) { span.setAttribute(ADK_INVOCATION_ID, invocationContext.invocationId()); @@ -159,12 +157,6 @@ private static void setInvocationAttributes( } } - private static void setToolExecutionAttributes(Span span) { - span.setAttribute(GEN_AI_OPERATION_NAME, "execute_tool"); - span.setAttribute(ADK_LLM_REQUEST, "{}"); - span.setAttribute(ADK_LLM_RESPONSE, "{}"); - } - private static void setJsonAttribute(Span span, AttributeKey key, Object value) { if (!CAPTURE_MESSAGE_CONTENT_IN_SPANS) { span.setAttribute(key, "{}"); @@ -198,7 +190,7 @@ public static void setTracerForTesting(Tracer tracer) { */ public static void traceAgentInvocation( Span span, String agentName, String agentDescription, InvocationContext invocationContext) { - span.setAttribute(GEN_AI_OPERATION_NAME, "invoke_agent"); + span.setAttribute(GEN_AI_OPERATION_NAME, INVOKE_AGENT_OPERATION); span.setAttribute(GEN_AI_AGENT_DESCRIPTION, agentDescription); span.setAttribute(GEN_AI_AGENT_NAME, agentName); if (invocationContext.session() != null && invocationContext.session().id() != null) { @@ -207,58 +199,62 @@ public static void traceAgentInvocation( } /** - * Traces tool call arguments. - * - * @param args The arguments to the tool call. - */ - public static void traceToolCall( - String toolName, String toolDescription, String toolType, Map args) { - traceWithSpan( - "traceToolCall", - span -> { - setToolExecutionAttributes(span); - span.setAttribute(GEN_AI_TOOL_NAME, toolName); - span.setAttribute(GEN_AI_TOOL_DESCRIPTION, toolDescription); - span.setAttribute(GEN_AI_TOOL_TYPE, toolType); - - setJsonAttribute(span, ADK_TOOL_CALL_ARGS, args); - }); - } - - /** - * Traces tool response event. + * Traces a tool execution, including its arguments, response, and any potential error. * - * @param eventId The ID of the event. - * @param functionResponseEvent The function response event. + * @param span The span representing the tool execution. + * @param toolName The name of the tool. + * @param toolDescription The tool's description. + * @param toolType The tool's type (e.g., "FunctionTool"). + * @param args The arguments passed to the tool. + * @param functionResponseEvent The event containing the tool's response, if successful. + * @param error The exception thrown during execution, if any. */ - public static void traceToolResponse(String eventId, Event functionResponseEvent) { - traceWithSpan( - "traceToolResponse", - span -> { - setToolExecutionAttributes(span); - span.setAttribute(ADK_EVENT_ID, eventId); - - FunctionResponse functionResponse = - functionResponseEvent.functionResponses().stream().findFirst().orElse(null); - - String toolCallId = ""; - Object toolResponse = ""; - if (functionResponse != null) { - toolCallId = functionResponse.id().orElse(toolCallId); - if (functionResponse.response().isPresent()) { - toolResponse = functionResponse.response().get(); - } - } - - span.setAttribute(GEN_AI_TOOL_CALL_ID, toolCallId); + public static void traceToolExecution( + Span span, + String toolName, + String toolDescription, + String toolType, + Map args, + @Nullable Event functionResponseEvent, + @Nullable Exception error) { + span.setAttribute(GEN_AI_OPERATION_NAME, EXECUTE_TOOL_OPERATION); + span.setAttribute(GEN_AI_TOOL_NAME, toolName); + span.setAttribute(GEN_AI_TOOL_DESCRIPTION, toolDescription); + span.setAttribute(GEN_AI_TOOL_TYPE, toolType); + + setJsonAttribute(span, ADK_TOOL_CALL_ARGS, args); + + if (functionResponseEvent != null) { + span.setAttribute(ADK_EVENT_ID, functionResponseEvent.id()); + FunctionResponse functionResponse = + functionResponseEvent.functionResponses().stream().findFirst().orElse(null); + + String toolCallId = ""; + Object toolResponse = ""; + if (functionResponse != null) { + toolCallId = functionResponse.id().orElse(toolCallId); + if (functionResponse.response().isPresent()) { + toolResponse = functionResponse.response().get(); + } + } + span.setAttribute(GEN_AI_TOOL_CALL_ID, toolCallId); + Object finalToolResponse = + (toolResponse instanceof Map) ? toolResponse : ImmutableMap.of("result", toolResponse); + setJsonAttribute(span, ADK_TOOL_RESPONSE, finalToolResponse); + } else { + // Set placeholder if no response event is available (e.g., due to an error) + span.setAttribute(GEN_AI_TOOL_CALL_ID, ""); + setJsonAttribute(span, ADK_TOOL_RESPONSE, "{}"); + } - Object finalToolResponse = - (toolResponse instanceof Map) - ? toolResponse - : ImmutableMap.of("result", toolResponse); + // Also set empty LLM attributes for UI compatibility, like in traceToolResponse + span.setAttribute(ADK_LLM_REQUEST, "{}"); + span.setAttribute(ADK_LLM_RESPONSE, "{}"); - setJsonAttribute(span, ADK_TOOL_RESPONSE, finalToolResponse); - }); + if (error != null) { + span.setStatus(StatusCode.ERROR, error.getMessage()); + span.recordException(error); + } } /** @@ -303,8 +299,10 @@ public static void traceCallLlm( InvocationContext invocationContext, String eventId, LlmRequest llmRequest, - LlmResponse llmResponse) { + LlmResponse llmResponse, + @Nullable Exception error) { span.setAttribute(GEN_AI_SYSTEM, "gcp.vertex.agent"); + span.setAttribute(GEN_AI_OPERATION_NAME, CALL_LLM_OPERATION); llmRequest.model().ifPresent(modelName -> span.setAttribute(GEN_AI_REQUEST_MODEL, modelName)); setInvocationAttributes(span, invocationContext, eventId); @@ -312,6 +310,11 @@ public static void traceCallLlm( setJsonAttribute(span, ADK_LLM_REQUEST, buildLlmRequestForTrace(llmRequest)); setJsonAttribute(span, ADK_LLM_RESPONSE, llmResponse); + if (error != null) { + span.setStatus(StatusCode.ERROR, error.getMessage()); + span.recordException(error); + } + llmRequest .config() .ifPresent( @@ -352,18 +355,45 @@ public static void traceCallLlm( * @param data A list of content objects being sent. */ public static void traceSendData( - InvocationContext invocationContext, String eventId, List data) { - traceWithSpan( - "traceSendData", - span -> { - setInvocationAttributes(span, invocationContext, eventId); - - ImmutableList safeData = - Optional.ofNullable(data).orElse(ImmutableList.of()).stream() - .filter(Objects::nonNull) - .collect(toImmutableList()); - setJsonAttribute(span, ADK_DATA, safeData); - }); + Span span, InvocationContext invocationContext, String eventId, List data) { + if (!span.getSpanContext().isValid()) { + log.trace("traceSendData: No valid span in current context."); + return; + } + setInvocationAttributes(span, invocationContext, eventId); + span.setAttribute(GEN_AI_OPERATION_NAME, SEND_DATA_OPERATION); + + ImmutableList safeData = + Optional.ofNullable(data).orElse(ImmutableList.of()).stream() + .filter(Objects::nonNull) + .collect(toImmutableList()); + setJsonAttribute(span, ADK_DATA, safeData); + } + + /** + * Traces merged tool call events. + * + *

Calling this function is not needed for telemetry purposes. This is provided for preventing + * /debug/trace requests (typically sent by web UI). + * + * @param responseEventId The ID of the response event. + * @param functionResponseEvent The merged response event. + */ + public static void traceMergedToolCalls( + Span span, String responseEventId, Event functionResponseEvent) { + if (!span.getSpanContext().isValid()) { + log.trace("traceMergedToolCalls: No valid span in current context."); + return; + } + span.setAttribute(GEN_AI_OPERATION_NAME, EXECUTE_TOOL_OPERATION); + span.setAttribute(GEN_AI_TOOL_NAME, "(merged tools)"); + span.setAttribute(GEN_AI_TOOL_DESCRIPTION, "(merged tools)"); + span.setAttribute(GEN_AI_TOOL_CALL_ID, responseEventId); + span.setAttribute(ADK_TOOL_CALL_ARGS, "N/A"); + span.setAttribute(ADK_EVENT_ID, responseEventId); + setJsonAttribute(span, ADK_TOOL_RESPONSE, functionResponseEvent); + span.setAttribute(ADK_LLM_REQUEST, "{}"); + span.setAttribute(ADK_LLM_RESPONSE, "{}"); } /** diff --git a/core/src/test/java/com/google/adk/agents/LlmAgentTest.java b/core/src/test/java/com/google/adk/agents/LlmAgentTest.java index a9e7a6f8d..e40a83aa0 100644 --- a/core/src/test/java/com/google/adk/agents/LlmAgentTest.java +++ b/core/src/test/java/com/google/adk/agents/LlmAgentTest.java @@ -494,12 +494,10 @@ public void runAsync_withTools_createsToolSpans() throws InterruptedException { List spans = openTelemetryRule.getSpans(); SpanData agentSpan = findSpanByName(spans, "invoke_agent test agent"); List llmSpans = findSpansByName(spans, "call_llm"); - List toolCallSpans = findSpansByName(spans, "tool_call [echo_tool]"); - List toolResponseSpans = findSpansByName(spans, "tool_response [echo_tool]"); + List toolSpans = findSpansByName(spans, "execute_tool [echo_tool]"); assertThat(llmSpans).hasSize(2); - assertThat(toolCallSpans).hasSize(1); - assertThat(toolResponseSpans).hasSize(1); + assertThat(toolSpans).hasSize(1); String agentSpanId = agentSpan.getSpanContext().getSpanId(); llmSpans.forEach(s -> assertEquals(agentSpanId, s.getParentSpanContext().getSpanId())); @@ -507,9 +505,7 @@ public void runAsync_withTools_createsToolSpans() throws InterruptedException { // The tool calls and responses are children of the first LLM call that produced the function // call. String firstLlmSpanId = llmSpans.get(0).getSpanContext().getSpanId(); - toolCallSpans.forEach(s -> assertEquals(firstLlmSpanId, s.getParentSpanContext().getSpanId())); - toolResponseSpans.forEach( - s -> assertEquals(firstLlmSpanId, s.getParentSpanContext().getSpanId())); + toolSpans.forEach(s -> assertEquals(firstLlmSpanId, s.getParentSpanContext().getSpanId())); } @Test diff --git a/core/src/test/java/com/google/adk/runner/RunnerTest.java b/core/src/test/java/com/google/adk/runner/RunnerTest.java index efd565c16..b68b6ff5f 100644 --- a/core/src/test/java/com/google/adk/runner/RunnerTest.java +++ b/core/src/test/java/com/google/adk/runner/RunnerTest.java @@ -1061,21 +1061,16 @@ public void runAsync_createsToolSpansWithCorrectParent() { List spans = openTelemetryRule.getSpans(); List llmSpans = spans.stream().filter(s -> s.getName().equals("call_llm")).toList(); - List toolCallSpans = - spans.stream().filter(s -> s.getName().equals("tool_call [echo_tool]")).toList(); - List toolResponseSpans = - spans.stream().filter(s -> s.getName().equals("tool_response [echo_tool]")).toList(); + List toolSpans = + spans.stream().filter(s -> s.getName().equals("execute_tool [echo_tool]")).toList(); assertThat(llmSpans).hasSize(2); - assertThat(toolCallSpans).hasSize(1); - assertThat(toolResponseSpans).hasSize(1); + assertThat(toolSpans).hasSize(1); List llmSpanIds = llmSpans.stream().map(s -> s.getSpanContext().getSpanId()).toList(); - String toolCallParentId = toolCallSpans.get(0).getParentSpanContext().getSpanId(); - String toolResponseParentId = toolResponseSpans.get(0).getParentSpanContext().getSpanId(); + String toolParentId = toolSpans.get(0).getParentSpanContext().getSpanId(); - assertThat(toolCallParentId).isEqualTo(toolResponseParentId); - assertThat(llmSpanIds).contains(toolCallParentId); + assertThat(llmSpanIds).contains(toolParentId); } @Test @@ -1101,22 +1096,17 @@ public void runLive_createsToolSpansWithCorrectParent() throws Exception { List spans = openTelemetryRule.getSpans(); List llmSpans = spans.stream().filter(s -> s.getName().equals("call_llm")).toList(); - List toolCallSpans = - spans.stream().filter(s -> s.getName().equals("tool_call [echo_tool]")).toList(); - List toolResponseSpans = - spans.stream().filter(s -> s.getName().equals("tool_response [echo_tool]")).toList(); + List toolSpans = + spans.stream().filter(s -> s.getName().equals("execute_tool [echo_tool]")).toList(); // In runLive, there is one call_llm span for the execution assertThat(llmSpans).hasSize(1); - assertThat(toolCallSpans).hasSize(1); - assertThat(toolResponseSpans).hasSize(1); + assertThat(toolSpans).hasSize(1); List llmSpanIds = llmSpans.stream().map(s -> s.getSpanContext().getSpanId()).toList(); - String toolCallParentId = toolCallSpans.get(0).getParentSpanContext().getSpanId(); - String toolResponseParentId = toolResponseSpans.get(0).getParentSpanContext().getSpanId(); + String toolParentId = toolSpans.get(0).getParentSpanContext().getSpanId(); - assertThat(toolCallParentId).isEqualTo(toolResponseParentId); - assertThat(llmSpanIds).contains(toolCallParentId); + assertThat(llmSpanIds).contains(toolParentId); } @Test diff --git a/core/src/test/java/com/google/adk/telemetry/ContextPropagationTest.java b/core/src/test/java/com/google/adk/telemetry/ContextPropagationTest.java index b13904934..44877e972 100644 --- a/core/src/test/java/com/google/adk/telemetry/ContextPropagationTest.java +++ b/core/src/test/java/com/google/adk/telemetry/ContextPropagationTest.java @@ -18,7 +18,6 @@ import static com.google.common.truth.Truth.assertThat; import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; import com.google.adk.agents.BaseAgent; @@ -100,242 +99,6 @@ public void tearDown() { Tracing.setTracerForTesting(originalTracer); } - @Test - public void testToolCallSpanLinksToParent() { - // Given: Parent span is active - Span parentSpan = tracer.spanBuilder("parent").startSpan(); - - try (Scope scope = parentSpan.makeCurrent()) { - // When: ADK creates tool_call span with setParent(Context.current()) - Span toolCallSpan = - tracer.spanBuilder("tool_call [testTool]").setParent(Context.current()).startSpan(); - - try (Scope toolScope = toolCallSpan.makeCurrent()) { - // Simulate tool execution - } finally { - toolCallSpan.end(); - } - } finally { - parentSpan.end(); - } - - // Then: tool_call should be child of parent - SpanData parentSpanData = findSpanByName("parent"); - SpanData toolCallSpanData = findSpanByName("tool_call [testTool]"); - - // Verify parent-child relationship - assertEquals( - "Tool call should have same trace ID as parent", - parentSpanData.getSpanContext().getTraceId(), - toolCallSpanData.getSpanContext().getTraceId()); - - assertParent(parentSpanData, toolCallSpanData); - } - - @Test - public void testToolCallWithoutParentCreatesRootSpan() { - // Given: No parent span active - // When: ADK creates tool_call span with setParent(Context.current()) - try (Scope s = Context.root().makeCurrent()) { - Span toolCallSpan = - tracer.spanBuilder("tool_call [testTool]").setParent(Context.current()).startSpan(); - - try (Scope scope = toolCallSpan.makeCurrent()) { - // Work - } finally { - toolCallSpan.end(); - } - } - - // Then: Should create root span (backward compatible) - List spans = openTelemetryRule.getSpans(); - assertThat(spans).hasSize(1); - - SpanData toolCallSpanData = spans.get(0); - assertFalse( - "Tool call should be root span when no parent exists", - toolCallSpanData.getParentSpanContext().isValid()); - } - - @Test - public void testNestedSpanHierarchy() { - // Test: parent → invocation → tool_call → tool_response hierarchy - - Span parentSpan = tracer.spanBuilder("parent").startSpan(); - - try (Scope parentScope = parentSpan.makeCurrent()) { - - Span invocationSpan = - tracer.spanBuilder("invocation").setParent(Context.current()).startSpan(); - - try (Scope invocationScope = invocationSpan.makeCurrent()) { - - Span toolCallSpan = - tracer.spanBuilder("tool_call [testTool]").setParent(Context.current()).startSpan(); - - try (Scope toolScope = toolCallSpan.makeCurrent()) { - - Span toolResponseSpan = - tracer - .spanBuilder("tool_response [testTool]") - .setParent(Context.current()) - .startSpan(); - - toolResponseSpan.end(); - } finally { - toolCallSpan.end(); - } - } finally { - invocationSpan.end(); - } - } finally { - parentSpan.end(); - } - - // Verify complete hierarchy - List spans = openTelemetryRule.getSpans(); - // The 4 spans are: "parent", "invocation", "tool_call [testTool]", and "tool_response - // [testTool]". - assertThat(spans).hasSize(4); - - SpanData parentSpanData = findSpanByName("parent"); - String parentTraceId = parentSpanData.getSpanContext().getTraceId(); - - // All spans should have same trace ID - for (SpanData span : openTelemetryRule.getSpans()) { - assertEquals( - "All spans should be in same trace", parentTraceId, span.getSpanContext().getTraceId()); - } - - // Verify parent-child relationships - SpanData invocationSpanData = findSpanByName("invocation"); - SpanData toolCallSpanData = findSpanByName("tool_call [testTool]"); - SpanData toolResponseSpanData = findSpanByName("tool_response [testTool]"); - - // invocation should be child of parent - assertParent(parentSpanData, invocationSpanData); - - // tool_call should be child of invocation - assertParent(invocationSpanData, toolCallSpanData); - - // tool_response should be child of tool_call - assertParent(toolCallSpanData, toolResponseSpanData); - } - - @Test - public void testMultipleSpansInParallel() { - // Test: Multiple tool calls in parallel should all link to same parent - - Span parentSpan = tracer.spanBuilder("parent").startSpan(); - - try (Scope parentScope = parentSpan.makeCurrent()) { - // Simulate parallel tool calls - Span toolCall1 = - tracer.spanBuilder("tool_call [tool1]").setParent(Context.current()).startSpan(); - Span toolCall2 = - tracer.spanBuilder("tool_call [tool2]").setParent(Context.current()).startSpan(); - Span toolCall3 = - tracer.spanBuilder("tool_call [tool3]").setParent(Context.current()).startSpan(); - - toolCall1.end(); - toolCall2.end(); - toolCall3.end(); - } finally { - parentSpan.end(); - } - - // Verify all tool calls link to same parent - SpanData parentSpanData = findSpanByName("parent"); - String parentTraceId = parentSpanData.getSpanContext().getTraceId(); - - // All tool calls should have same trace ID and parent span ID - List toolCallSpans = - openTelemetryRule.getSpans().stream() - .filter(s -> s.getName().startsWith("tool_call")) - .toList(); - - assertThat(toolCallSpans).hasSize(3); - - toolCallSpans.forEach( - span -> { - assertEquals( - "Tool call should have same trace ID as parent", - parentTraceId, - span.getSpanContext().getTraceId()); - assertParent(parentSpanData, span); - }); - } - - @Test - public void testInvokeAgentSpanLinksToInvocation() { - // Test: invoke_agent span should link to invocation span - - Span invocationSpan = tracer.spanBuilder("invocation").startSpan(); - - try (Scope invocationScope = invocationSpan.makeCurrent()) { - Span invokeAgentSpan = - tracer.spanBuilder("invoke_agent test-agent").setParent(Context.current()).startSpan(); - - try (Scope agentScope = invokeAgentSpan.makeCurrent()) { - // Simulate agent work - } finally { - invokeAgentSpan.end(); - } - } finally { - invocationSpan.end(); - } - - SpanData invocationSpanData = findSpanByName("invocation"); - SpanData invokeAgentSpanData = findSpanByName("invoke_agent test-agent"); - - assertParent(invocationSpanData, invokeAgentSpanData); - } - - @Test - public void testCallLlmSpanLinksToAgentRun() { - // Test: call_llm span should link to agent_run span - - Span invokeAgentSpan = tracer.spanBuilder("invoke_agent test-agent").startSpan(); - - try (Scope agentScope = invokeAgentSpan.makeCurrent()) { - Span callLlmSpan = tracer.spanBuilder("call_llm").setParent(Context.current()).startSpan(); - - try (Scope llmScope = callLlmSpan.makeCurrent()) { - // Simulate LLM call - } finally { - callLlmSpan.end(); - } - } finally { - invokeAgentSpan.end(); - } - - List spans = openTelemetryRule.getSpans(); - assertThat(spans).hasSize(2); - - SpanData invokeAgentSpanData = findSpanByName("invoke_agent test-agent"); - SpanData callLlmSpanData = findSpanByName("call_llm"); - - assertParent(invokeAgentSpanData, callLlmSpanData); - } - - @Test - public void testSpanCreatedWithinParentScopeIsCorrectlyParented() { - // Test: Simulates creating a span within the scope of a parent - - Span parentSpan = tracer.spanBuilder("invocation").startSpan(); - try (Scope scope = parentSpan.makeCurrent()) { - Span agentSpan = tracer.spanBuilder("invoke_agent").setParent(Context.current()).startSpan(); - agentSpan.end(); - } finally { - parentSpan.end(); - } - - SpanData parentSpanData = findSpanByName("invocation"); - SpanData agentSpanData = findSpanByName("invoke_agent"); - - assertParent(parentSpanData, agentSpanData); - } - @Test public void testTraceFlowable() throws InterruptedException { Span parentSpan = tracer.spanBuilder("parent").startSpan(); @@ -475,8 +238,14 @@ public void testTraceAgentInvocation() { public void testTraceToolCall() { Span span = tracer.spanBuilder("test").startSpan(); try (Scope scope = span.makeCurrent()) { - Tracing.traceToolCall( - "tool-name", "tool-description", "tool-type", ImmutableMap.of("arg1", "value1")); + Tracing.traceToolExecution( + span, + "tool-name", + "tool-description", + "tool-type", + ImmutableMap.of("arg1", "value1"), + null, + null); } finally { span.end(); } @@ -513,7 +282,14 @@ public void testTraceToolResponse() { .build()) .build())) .build(); - Tracing.traceToolResponse("event-1", functionResponseEvent); + Tracing.traceToolExecution( + span, + "tool-name", + "tool-description", + "tool-type", + ImmutableMap.of(), + functionResponseEvent, + null); } finally { span.end(); } @@ -524,6 +300,10 @@ public void testTraceToolResponse() { assertEquals("execute_tool", attrs.get(AttributeKey.stringKey("gen_ai.operation.name"))); assertEquals("event-1", attrs.get(AttributeKey.stringKey("gcp.vertex.agent.event_id"))); assertEquals("tool-call-id", attrs.get(AttributeKey.stringKey("gen_ai.tool_call.id"))); + assertEquals("tool-name", attrs.get(AttributeKey.stringKey("gen_ai.tool.name"))); + assertEquals("tool-description", attrs.get(AttributeKey.stringKey("gen_ai.tool.description"))); + assertEquals("tool-type", attrs.get(AttributeKey.stringKey("gen_ai.tool.type"))); + assertEquals("{}", attrs.get(AttributeKey.stringKey("gcp.vertex.agent.tool_call_args"))); assertEquals( "{\"result\":\"tool-result\"}", attrs.get(AttributeKey.stringKey("gcp.vertex.agent.tool_response"))); @@ -550,7 +330,8 @@ public void testTraceCallLlm() { .totalTokenCount(30) .build()) .build(); - Tracing.traceCallLlm(span, buildInvocationContext(), "event-1", llmRequest, llmResponse); + Tracing.traceCallLlm( + span, buildInvocationContext(), "event-1", llmRequest, llmResponse, null); } finally { span.end(); } @@ -559,6 +340,7 @@ public void testTraceCallLlm() { SpanData spanData = spans.get(0); Attributes attrs = spanData.getAttributes(); assertEquals("gcp.vertex.agent", attrs.get(AttributeKey.stringKey("gen_ai.system"))); + assertEquals("call_llm", attrs.get(AttributeKey.stringKey("gen_ai.operation.name"))); assertEquals("gemini-pro", attrs.get(AttributeKey.stringKey("gen_ai.request.model"))); assertEquals( "test-invocation-id", attrs.get(AttributeKey.stringKey("gcp.vertex.agent.invocation_id"))); @@ -581,6 +363,7 @@ public void testTraceSendData() { Span span = tracer.spanBuilder("test").startSpan(); try (Scope scope = span.makeCurrent()) { Tracing.traceSendData( + span, buildInvocationContext(), "event-1", ImmutableList.of(Content.fromParts(Part.fromText("hello")))); @@ -591,6 +374,7 @@ public void testTraceSendData() { assertThat(spans).hasSize(1); SpanData spanData = spans.get(0); Attributes attrs = spanData.getAttributes(); + assertEquals("send_data", attrs.get(AttributeKey.stringKey("gen_ai.operation.name"))); assertEquals( "test-invocation-id", attrs.get(AttributeKey.stringKey("gcp.vertex.agent.invocation_id"))); assertEquals("event-1", attrs.get(AttributeKey.stringKey("gcp.vertex.agent.event_id"))); @@ -687,8 +471,7 @@ public void testAgentWithToolCallTraceHierarchy() throws InterruptedException { // invocation // └── invoke_agent test_agent // ├── call_llm - // │ ├── tool_call [search_flights] - // │ └── tool_response [search_flights] + // │ └── execute_tool [search_flights] // └── call_llm SearchFlightsTool searchFlightsTool = new SearchFlightsTool(); @@ -716,8 +499,7 @@ public void testAgentWithToolCallTraceHierarchy() throws InterruptedException { SpanData invocation = findSpanByName("invocation"); SpanData invokeAgent = findSpanByName("invoke_agent test_agent"); - SpanData toolCall = findSpanByName("tool_call [search_flights]"); - SpanData toolResponse = findSpanByName("tool_response [search_flights]"); + SpanData toolResponse = findSpanByName("execute_tool [search_flights]"); List callLlmSpans = openTelemetryRule.getSpans().stream() .filter(s -> s.getName().equals("call_llm")) @@ -733,12 +515,28 @@ public void testAgentWithToolCallTraceHierarchy() throws InterruptedException { assertParent(invocation, invokeAgent); // ├── call_llm 1 assertParent(invokeAgent, callLlm1); - // │ ├── tool_call [search_flights] - assertParent(callLlm1, toolCall); - // │ └── tool_response [search_flights] + // │ └── execute_tool [search_flights] assertParent(callLlm1, toolResponse); // └── call_llm 2 assertParent(invokeAgent, callLlm2); + + // Assert attributes + assertEquals( + "invoke_agent", + invokeAgent.getAttributes().get(AttributeKey.stringKey("gen_ai.operation.name"))); + assertEquals( + "call_llm", callLlm1.getAttributes().get(AttributeKey.stringKey("gen_ai.operation.name"))); + assertEquals( + "execute_tool", + toolResponse.getAttributes().get(AttributeKey.stringKey("gen_ai.operation.name"))); + assertEquals( + "search_flights", + toolResponse.getAttributes().get(AttributeKey.stringKey("gen_ai.tool.name"))); + assertEquals( + "execute_tool", + toolResponse.getAttributes().get(AttributeKey.stringKey("gen_ai.operation.name"))); + assertEquals( + "call_llm", callLlm2.getAttributes().get(AttributeKey.stringKey("gen_ai.operation.name"))); } @Test @@ -748,8 +546,7 @@ public void testNestedAgentTraceHierarchy() throws InterruptedException { // invocation // └── invoke_agent AgentA // ├── call_llm - // │ ├── tool_call [transfer_to_agent] - // │ └── tool_response [transfer_to_agent] + // │ └── execute_tool [transfer_to_agent] // └── invoke_agent AgentB // └── call_llm TestLlm llm = @@ -776,9 +573,8 @@ public void testNestedAgentTraceHierarchy() throws InterruptedException { SpanData invocation = findSpanByName("invocation"); SpanData agentASpan = findSpanByName("invoke_agent AgentA"); - SpanData toolCall = findSpanByName("tool_call [transfer_to_agent]"); + SpanData executeTool = findSpanByName("execute_tool [transfer_to_agent]"); SpanData agentBSpan = findSpanByName("invoke_agent AgentB"); - SpanData toolResponse = findSpanByName("tool_response [transfer_to_agent]"); List callLlmSpans = openTelemetryRule.getSpans().stream() @@ -792,8 +588,7 @@ public void testNestedAgentTraceHierarchy() throws InterruptedException { assertParent(invocation, agentASpan); assertParent(agentASpan, agentACallLlm1); - assertParent(agentACallLlm1, toolCall); - assertParent(agentACallLlm1, toolResponse); + assertParent(agentACallLlm1, executeTool); assertParent(agentASpan, agentBSpan); assertParent(agentBSpan, agentBCallLlm); }