diff --git a/contrib/firestore-session-service/pom.xml b/contrib/firestore-session-service/pom.xml index ed1ecd09b..34b577984 100644 --- a/contrib/firestore-session-service/pom.xml +++ b/contrib/firestore-session-service/pom.xml @@ -14,7 +14,9 @@ See the License for the specific language governing permissions and limitations under the License. --> - + 4.0.0 @@ -49,7 +51,6 @@ com.google.cloud google-cloud-firestore - 3.30.3 com.google.truth diff --git a/contrib/langchain4j/src/main/java/com/google/adk/models/langchain4j/LangChain4j.java b/contrib/langchain4j/src/main/java/com/google/adk/models/langchain4j/LangChain4j.java index 3ccb1e029..97331e7b4 100644 --- a/contrib/langchain4j/src/main/java/com/google/adk/models/langchain4j/LangChain4j.java +++ b/contrib/langchain4j/src/main/java/com/google/adk/models/langchain4j/LangChain4j.java @@ -23,6 +23,7 @@ import com.google.adk.models.BaseLlmConnection; import com.google.adk.models.LlmRequest; import com.google.adk.models.LlmResponse; +import com.google.auto.value.AutoValue; import com.google.genai.types.Blob; import com.google.genai.types.Content; import com.google.genai.types.FunctionCall; @@ -30,11 +31,11 @@ import com.google.genai.types.FunctionDeclaration; import com.google.genai.types.FunctionResponse; import com.google.genai.types.GenerateContentConfig; +import com.google.genai.types.GenerateContentResponseUsageMetadata; import com.google.genai.types.Part; import com.google.genai.types.Schema; import com.google.genai.types.ToolConfig; import com.google.genai.types.Type; -import dev.langchain4j.Experimental; import dev.langchain4j.agent.tool.ToolExecutionRequest; import dev.langchain4j.agent.tool.ToolSpecification; import dev.langchain4j.data.audio.Audio; @@ -52,6 +53,7 @@ import dev.langchain4j.data.pdf.PdfFile; import dev.langchain4j.data.video.Video; import dev.langchain4j.exception.UnsupportedFeatureException; +import dev.langchain4j.model.TokenCountEstimator; import dev.langchain4j.model.chat.ChatModel; import dev.langchain4j.model.chat.StreamingChatModel; import dev.langchain4j.model.chat.request.ChatRequest; @@ -65,6 +67,7 @@ import dev.langchain4j.model.chat.request.json.JsonStringSchema; import dev.langchain4j.model.chat.response.ChatResponse; import dev.langchain4j.model.chat.response.StreamingChatResponseHandler; +import dev.langchain4j.model.output.TokenUsage; import io.reactivex.rxjava3.core.BackpressureStrategy; import io.reactivex.rxjava3.core.Flowable; import java.util.ArrayList; @@ -72,66 +75,101 @@ import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.Objects; import java.util.UUID; +import org.jspecify.annotations.Nullable; -@Experimental -public class LangChain4j extends BaseLlm { +@AutoValue +public abstract class LangChain4j extends BaseLlm { private static final TypeReference> MAP_TYPE_REFERENCE = new TypeReference<>() {}; - private final ChatModel chatModel; - private final StreamingChatModel streamingChatModel; - private final ObjectMapper objectMapper; + LangChain4j() { + super(""); + } + + @Nullable + public abstract ChatModel chatModel(); + + @Nullable + public abstract StreamingChatModel streamingChatModel(); + + public abstract ObjectMapper objectMapper(); + + public abstract String modelName(); + + @Nullable + public abstract TokenCountEstimator tokenCountEstimator(); + + @Override + public String model() { + return modelName(); + } + + public static Builder builder() { + return new AutoValue_LangChain4j.Builder().objectMapper(new ObjectMapper()); + } + + @AutoValue.Builder + public abstract static class Builder { + public abstract Builder chatModel(ChatModel chatModel); + + public abstract Builder streamingChatModel(StreamingChatModel streamingChatModel); + + public abstract Builder tokenCountEstimator(TokenCountEstimator tokenCountEstimator); + + public abstract Builder objectMapper(ObjectMapper objectMapper); + + public abstract Builder modelName(String modelName); + + public abstract LangChain4j build(); + } public LangChain4j(ChatModel chatModel) { - super( - Objects.requireNonNull( - chatModel.defaultRequestParameters().modelName(), "chat model name cannot be null")); - this.chatModel = Objects.requireNonNull(chatModel, "chatModel cannot be null"); - this.streamingChatModel = null; - this.objectMapper = new ObjectMapper(); + this(chatModel, null, null, chatModel.defaultRequestParameters().modelName(), null); } public LangChain4j(ChatModel chatModel, String modelName) { - super(Objects.requireNonNull(modelName, "chat model name cannot be null")); - this.chatModel = Objects.requireNonNull(chatModel, "chatModel cannot be null"); - this.streamingChatModel = null; - this.objectMapper = new ObjectMapper(); + this(chatModel, null, null, modelName, null); } public LangChain4j(StreamingChatModel streamingChatModel) { - super( - Objects.requireNonNull( - streamingChatModel.defaultRequestParameters().modelName(), - "streaming chat model name cannot be null")); - this.chatModel = null; - this.streamingChatModel = - Objects.requireNonNull(streamingChatModel, "streamingChatModel cannot be null"); - this.objectMapper = new ObjectMapper(); + this( + null, + streamingChatModel, + null, + streamingChatModel.defaultRequestParameters().modelName(), + null); } public LangChain4j(StreamingChatModel streamingChatModel, String modelName) { - super(Objects.requireNonNull(modelName, "streaming chat model name cannot be null")); - this.chatModel = null; - this.streamingChatModel = - Objects.requireNonNull(streamingChatModel, "streamingChatModel cannot be null"); - this.objectMapper = new ObjectMapper(); + this(null, streamingChatModel, null, modelName, null); } public LangChain4j(ChatModel chatModel, StreamingChatModel streamingChatModel, String modelName) { - super(Objects.requireNonNull(modelName, "model name cannot be null")); - this.chatModel = Objects.requireNonNull(chatModel, "chatModel cannot be null"); - this.streamingChatModel = - Objects.requireNonNull(streamingChatModel, "streamingChatModel cannot be null"); - this.objectMapper = new ObjectMapper(); + this(chatModel, streamingChatModel, null, modelName, null); + } + + private LangChain4j( + ChatModel chatModel, + StreamingChatModel streamingChatModel, + ObjectMapper objectMapper, + String modelName, + TokenCountEstimator tokenCountEstimator) { + this(); + LangChain4j.builder() + .chatModel(chatModel) + .streamingChatModel(streamingChatModel) + .objectMapper(objectMapper) + .modelName(modelName) + .tokenCountEstimator(tokenCountEstimator) + .build(); } @Override public Flowable generateContent(LlmRequest llmRequest, boolean stream) { if (stream) { - if (this.streamingChatModel == null) { + if (this.streamingChatModel() == null) { return Flowable.error(new IllegalStateException("StreamingChatModel is not configured")); } @@ -139,54 +177,57 @@ public Flowable generateContent(LlmRequest llmRequest, boolean stre return Flowable.create( emitter -> { - streamingChatModel.chat( - chatRequest, - new StreamingChatResponseHandler() { - @Override - public void onPartialResponse(String s) { - emitter.onNext( - LlmResponse.builder().content(Content.fromParts(Part.fromText(s))).build()); - } - - @Override - public void onCompleteResponse(ChatResponse chatResponse) { - if (chatResponse.aiMessage().hasToolExecutionRequests()) { - AiMessage aiMessage = chatResponse.aiMessage(); - toParts(aiMessage).stream() - .map(Part::functionCall) - .forEach( - functionCall -> { - functionCall.ifPresent( - function -> { - emitter.onNext( - LlmResponse.builder() - .content( - Content.fromParts( - Part.fromFunctionCall( - function.name().orElse(""), - function.args().orElse(Map.of())))) - .build()); - }); - }); - } - emitter.onComplete(); - } - - @Override - public void onError(Throwable throwable) { - emitter.onError(throwable); - } - }); + streamingChatModel() + .chat( + chatRequest, + new StreamingChatResponseHandler() { + @Override + public void onPartialResponse(String s) { + emitter.onNext( + LlmResponse.builder() + .content(Content.fromParts(Part.fromText(s))) + .build()); + } + + @Override + public void onCompleteResponse(ChatResponse chatResponse) { + if (chatResponse.aiMessage().hasToolExecutionRequests()) { + AiMessage aiMessage = chatResponse.aiMessage(); + toParts(aiMessage).stream() + .map(Part::functionCall) + .forEach( + functionCall -> { + functionCall.ifPresent( + function -> { + emitter.onNext( + LlmResponse.builder() + .content( + Content.fromParts( + Part.fromFunctionCall( + function.name().orElse(""), + function.args().orElse(Map.of())))) + .build()); + }); + }); + } + emitter.onComplete(); + } + + @Override + public void onError(Throwable throwable) { + emitter.onError(throwable); + } + }); }, BackpressureStrategy.BUFFER); } else { - if (this.chatModel == null) { + if (this.chatModel() == null) { return Flowable.error(new IllegalStateException("ChatModel is not configured")); } ChatRequest chatRequest = toChatRequest(llmRequest); - ChatResponse chatResponse = chatModel.chat(chatRequest); - LlmResponse llmResponse = toLlmResponse(chatResponse); + ChatResponse chatResponse = chatModel().chat(chatRequest); + LlmResponse llmResponse = toLlmResponse(chatResponse, chatRequest); return Flowable.just(llmResponse); } @@ -413,7 +454,7 @@ private AiMessage toAiMessage(Content content) { private String toJson(Object object) { try { - return objectMapper.writeValueAsString(object); + return objectMapper().writeValueAsString(object); } catch (JsonProcessingException e) { throw new RuntimeException(e); } @@ -511,11 +552,38 @@ private JsonSchemaElement toJsonSchemaElement(Schema schema) { } } - private LlmResponse toLlmResponse(ChatResponse chatResponse) { + private LlmResponse toLlmResponse(ChatResponse chatResponse, ChatRequest chatRequest) { Content content = Content.builder().role("model").parts(toParts(chatResponse.aiMessage())).build(); - return LlmResponse.builder().content(content).build(); + LlmResponse.Builder builder = LlmResponse.builder().content(content); + TokenUsage tokenUsage = chatResponse.tokenUsage(); + if (tokenCountEstimator() != null) { + try { + int estimatedInput = + tokenCountEstimator().estimateTokenCountInMessages(chatRequest.messages()); + int estimatedOutput = + tokenCountEstimator().estimateTokenCountInText(chatResponse.aiMessage().text()); + int estimatedTotal = estimatedInput + estimatedOutput; + builder.usageMetadata( + GenerateContentResponseUsageMetadata.builder() + .promptTokenCount(estimatedInput) + .candidatesTokenCount(estimatedOutput) + .totalTokenCount(estimatedTotal) + .build()); + } catch (Exception e) { + e.printStackTrace(); + } + } else if (tokenUsage != null) { + builder.usageMetadata( + GenerateContentResponseUsageMetadata.builder() + .promptTokenCount(tokenUsage.inputTokenCount()) + .candidatesTokenCount(tokenUsage.outputTokenCount()) + .totalTokenCount(tokenUsage.totalTokenCount()) + .build()); + } + + return builder.build(); } private List toParts(AiMessage aiMessage) { @@ -539,14 +607,17 @@ private List toParts(AiMessage aiMessage) { }); return parts; } else { - Part part = Part.builder().text(aiMessage.text()).build(); - return List.of(part); + String text = aiMessage.text(); + if (text == null) { + return List.of(); + } + return List.of(Part.builder().text(text).build()); } } private Map toArgs(ToolExecutionRequest toolExecutionRequest) { try { - return objectMapper.readValue(toolExecutionRequest.arguments(), MAP_TYPE_REFERENCE); + return objectMapper().readValue(toolExecutionRequest.arguments(), MAP_TYPE_REFERENCE); } catch (JsonProcessingException e) { throw new RuntimeException(e); } diff --git a/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jIntegrationTest.java b/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jIntegrationTest.java index 191e48017..5b6d3f3ad 100644 --- a/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jIntegrationTest.java +++ b/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jIntegrationTest.java @@ -62,7 +62,8 @@ void testSimpleAgent() { LlmAgent.builder() .name("science-app") .description("Science teacher agent") - .model(new LangChain4j(claudeModel, CLAUDE_4_6_SONNET)) + .model( + LangChain4j.builder().chatModel(claudeModel).modelName(CLAUDE_4_6_SONNET).build()) .instruction( """ You are a helpful science teacher that explains science concepts @@ -98,7 +99,8 @@ void testSingleAgentWithTools() { LlmAgent.builder() .name("friendly-weather-app") .description("Friend agent that knows about the weather") - .model(new LangChain4j(claudeModel, CLAUDE_4_6_SONNET)) + .model( + LangChain4j.builder().chatModel(claudeModel).modelName(CLAUDE_4_6_SONNET).build()) .instruction( """ You are a friendly assistant. @@ -183,7 +185,7 @@ void testAgentTool() { LlmAgent.builder() .name("friendly-weather-app") .description("Friend agent that knows about the weather") - .model(new LangChain4j(gptModel)) + .model(LangChain4j.builder().chatModel(gptModel).modelName(GPT_4_O_MINI).build()) .instruction( """ You are a friendly assistant. @@ -246,7 +248,7 @@ void testSubAgent() { LlmAgent.builder() .name("greeterAgent") .description("Friendly agent that greets users") - .model(new LangChain4j(gptModel)) + .model(LangChain4j.builder().chatModel(gptModel).modelName(GPT_4_O_MINI).build()) .instruction( """ You are a friendly that greets users. @@ -257,7 +259,7 @@ void testSubAgent() { LlmAgent.builder() .name("farewellAgent") .description("Friendly agent that says goodbye to users") - .model(new LangChain4j(gptModel)) + .model(LangChain4j.builder().chatModel(gptModel).modelName(GPT_4_O_MINI).build()) .instruction( """ You are a friendly that says goodbye to users. @@ -355,7 +357,11 @@ void testSimpleStreamingResponse() { .modelName(CLAUDE_4_6_SONNET) .build(); - LangChain4j lc4jClaude = new LangChain4j(claudeStreamingModel, CLAUDE_4_6_SONNET); + LangChain4j lc4jClaude = + LangChain4j.builder() + .streamingChatModel(claudeStreamingModel) + .modelName(CLAUDE_4_6_SONNET) + .build(); // when Flowable responses = @@ -413,7 +419,11 @@ void testStreamingRunConfig() { When someone greets you, respond with "Hello". If someone asks about the weather, call the `getWeather` function. """) - .model(new LangChain4j(streamingModel, "GPT_4_O_MINI")) + .model( + LangChain4j.builder() + .streamingChatModel(streamingModel) + .modelName("GPT_4_O_MINI") + .build()) // .model(new LangChain4j(streamingModel, // CLAUDE_3_7_SONNET_20250219)) .tools(FunctionTool.create(ToolExample.class, "getWeather")) diff --git a/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jTest.java b/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jTest.java index 076bb79a3..a1ec7a3c2 100644 --- a/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jTest.java +++ b/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jTest.java @@ -19,6 +19,7 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.*; +import com.fasterxml.jackson.databind.ObjectMapper; import com.google.adk.models.LlmRequest; import com.google.adk.models.LlmResponse; import com.google.adk.tools.FunctionTool; @@ -26,6 +27,7 @@ import dev.langchain4j.agent.tool.ToolExecutionRequest; import dev.langchain4j.data.message.AiMessage; import dev.langchain4j.data.message.UserMessage; +import dev.langchain4j.model.TokenCountEstimator; import dev.langchain4j.model.chat.ChatModel; import dev.langchain4j.model.chat.StreamingChatModel; import dev.langchain4j.model.chat.request.ChatRequest; @@ -33,6 +35,7 @@ import dev.langchain4j.model.chat.request.json.JsonStringSchema; import dev.langchain4j.model.chat.response.ChatResponse; import dev.langchain4j.model.chat.response.StreamingChatResponseHandler; +import dev.langchain4j.model.output.TokenUsage; import io.reactivex.rxjava3.core.Flowable; import java.util.ArrayList; import java.util.List; @@ -57,8 +60,26 @@ void setUp() { chatModel = mock(ChatModel.class); streamingChatModel = mock(StreamingChatModel.class); - langChain4j = new LangChain4j(chatModel, MODEL_NAME); - streamingLangChain4j = new LangChain4j(streamingChatModel, MODEL_NAME); + langChain4j = LangChain4j.builder().chatModel(chatModel).modelName(MODEL_NAME).build(); + streamingLangChain4j = + LangChain4j.builder().streamingChatModel(streamingChatModel).modelName(MODEL_NAME).build(); + } + + @Test + void testBuilder() { + ObjectMapper customMapper = new ObjectMapper(); + LangChain4j customLc4j = + LangChain4j.builder() + .chatModel(chatModel) + .streamingChatModel(streamingChatModel) + .objectMapper(customMapper) + .modelName("custom-model") + .build(); + + assertThat(customLc4j.chatModel()).isEqualTo(chatModel); + assertThat(customLc4j.streamingChatModel()).isEqualTo(streamingChatModel); + assertThat(customLc4j.objectMapper()).isEqualTo(customMapper); + assertThat(customLc4j.modelName()).isEqualTo("custom-model"); } @Test @@ -812,4 +833,164 @@ void testGenerateContentWithMcpToolParametersJsonSchemaAsSchema() { assertThat(capturedRequest.toolSpecifications().get(0).name()).isEqualTo("mcpTool"); assertThat(capturedRequest.toolSpecifications().get(0).description()).isEqualTo("An MCP tool"); } + + @Test + @DisplayName( + "Should use TokenCountEstimator to estimate token usage when TokenUsage is not available") + void testTokenCountEstimatorFallback() { + // Given + // Create a mock TokenCountEstimator + final TokenCountEstimator tokenCountEstimator = mock(TokenCountEstimator.class); + when(tokenCountEstimator.estimateTokenCountInMessages(any())).thenReturn(50); // Input tokens + when(tokenCountEstimator.estimateTokenCountInText(any())).thenReturn(20); // Output tokens + + // Create LangChain4j with the TokenCountEstimator using Builder + final LangChain4j langChain4jWithEstimator = + LangChain4j.builder() + .chatModel(chatModel) + .modelName(MODEL_NAME) + .tokenCountEstimator(tokenCountEstimator) + .build(); + + // Create a LlmRequest + final LlmRequest llmRequest = + LlmRequest.builder() + .contents(List.of(Content.fromParts(Part.fromText("What is the weather today?")))) + .build(); + + // Mock ChatResponse WITHOUT TokenUsage (simulating when LLM doesn't provide token counts) + final ChatResponse chatResponse = mock(ChatResponse.class); + final AiMessage aiMessage = AiMessage.from("The weather is sunny today."); + when(chatResponse.aiMessage()).thenReturn(aiMessage); + when(chatResponse.tokenUsage()).thenReturn(null); // No token usage from LLM + when(chatModel.chat(any(ChatRequest.class))).thenReturn(chatResponse); + + // When + final LlmResponse response = + langChain4jWithEstimator.generateContent(llmRequest, false).blockingFirst(); + + // Then + // Verify the response has usage metadata estimated by TokenCountEstimator + assertThat(response).isNotNull(); + assertThat(response.content()).isPresent(); + assertThat(response.content().get().text()).isEqualTo("The weather is sunny today."); + + // IMPORTANT: Verify that token usage was estimated via the TokenCountEstimator + assertThat(response.usageMetadata()).isPresent(); + final GenerateContentResponseUsageMetadata usageMetadata = response.usageMetadata().get(); + assertThat(usageMetadata.promptTokenCount()).isEqualTo(Optional.of(50)); // From estimator + assertThat(usageMetadata.candidatesTokenCount()).isEqualTo(Optional.of(20)); // From estimator + assertThat(usageMetadata.totalTokenCount()).isEqualTo(Optional.of(70)); // 50 + 20 + + // Verify the estimator was actually called + verify(tokenCountEstimator).estimateTokenCountInMessages(any()); + verify(tokenCountEstimator).estimateTokenCountInText("The weather is sunny today."); + } + + @Test + @DisplayName("Should prioritize TokenCountEstimator over TokenUsage when estimator is provided") + void testTokenCountEstimatorPriority() { + // Given + // Create a mock TokenCountEstimator + final TokenCountEstimator tokenCountEstimator = mock(TokenCountEstimator.class); + when(tokenCountEstimator.estimateTokenCountInMessages(any())).thenReturn(100); // From estimator + when(tokenCountEstimator.estimateTokenCountInText(any())).thenReturn(50); // From estimator + + // Create LangChain4j with the TokenCountEstimator using Builder + final LangChain4j langChain4jWithEstimator = + LangChain4j.builder() + .chatModel(chatModel) + .modelName(MODEL_NAME) + .tokenCountEstimator(tokenCountEstimator) + .build(); + + // Create a LlmRequest + final LlmRequest llmRequest = + LlmRequest.builder() + .contents(List.of(Content.fromParts(Part.fromText("What is the weather today?")))) + .build(); + + // Mock ChatResponse WITH actual TokenUsage from the LLM + final ChatResponse chatResponse = mock(ChatResponse.class); + final AiMessage aiMessage = AiMessage.from("The weather is sunny today."); + final TokenUsage actualTokenUsage = new TokenUsage(30, 15, 45); // Actual token counts from LLM + when(chatResponse.aiMessage()).thenReturn(aiMessage); + when(chatResponse.tokenUsage()).thenReturn(actualTokenUsage); // LLM provides token usage + when(chatModel.chat(any(ChatRequest.class))).thenReturn(chatResponse); + + // When + final LlmResponse response = + langChain4jWithEstimator.generateContent(llmRequest, false).blockingFirst(); + + // Then + // IMPORTANT: When TokenCountEstimator is present, it takes priority over TokenUsage + assertThat(response).isNotNull(); + assertThat(response.usageMetadata()).isPresent(); + final GenerateContentResponseUsageMetadata usageMetadata = response.usageMetadata().get(); + assertThat(usageMetadata.promptTokenCount()).isEqualTo(Optional.of(100)); // From estimator + assertThat(usageMetadata.candidatesTokenCount()).isEqualTo(Optional.of(50)); // From estimator + assertThat(usageMetadata.totalTokenCount()).isEqualTo(Optional.of(150)); // 100 + 50 + + // Verify the estimator was called (it takes priority) + verify(tokenCountEstimator).estimateTokenCountInMessages(any()); + verify(tokenCountEstimator).estimateTokenCountInText("The weather is sunny today."); + } + + @Test + @DisplayName("Should not include usageMetadata when TokenUsage is null and no estimator provided") + void testNoUsageMetadataWithoutEstimator() { + // Given + // Create LangChain4j WITHOUT TokenCountEstimator (default behavior) + final LangChain4j langChain4jNoEstimator = + LangChain4j.builder().chatModel(chatModel).modelName(MODEL_NAME).build(); + + // Create a LlmRequest + final LlmRequest llmRequest = + LlmRequest.builder() + .contents(List.of(Content.fromParts(Part.fromText("Hello, world!")))) + .build(); + + // Mock ChatResponse WITHOUT TokenUsage + final ChatResponse chatResponse = mock(ChatResponse.class); + final AiMessage aiMessage = AiMessage.from("Hello! How can I help you?"); + when(chatResponse.aiMessage()).thenReturn(aiMessage); + when(chatResponse.tokenUsage()).thenReturn(null); // No token usage from LLM + when(chatModel.chat(any(ChatRequest.class))).thenReturn(chatResponse); + + // When + final LlmResponse response = + langChain4jNoEstimator.generateContent(llmRequest, false).blockingFirst(); + + // Then + // Verify the response does NOT have usage metadata + assertThat(response).isNotNull(); + assertThat(response.content()).isPresent(); + assertThat(response.content().get().text()).isEqualTo("Hello! How can I help you?"); + + // IMPORTANT: usageMetadata should be empty when no TokenUsage and no estimator + assertThat(response.usageMetadata()).isEmpty(); + } + + @Test + @DisplayName("Should handle null AiMessage text without throwing NPE") + void testGenerateContentWithNullAiMessageText() { + // Given + final LlmRequest llmRequest = + LlmRequest.builder().contents(List.of(Content.fromParts(Part.fromText("Hello")))).build(); + + final ChatResponse chatResponse = mock(ChatResponse.class); + final AiMessage aiMessage = mock(AiMessage.class); + when(aiMessage.text()).thenReturn(null); + when(aiMessage.hasToolExecutionRequests()).thenReturn(false); + when(chatResponse.aiMessage()).thenReturn(aiMessage); + when(chatModel.chat(any(ChatRequest.class))).thenReturn(chatResponse); + + // When + final Flowable responseFlowable = langChain4j.generateContent(llmRequest, false); + final LlmResponse response = responseFlowable.blockingFirst(); + // Then - no NPE thrown, and content has no text parts + assertThat(response).isNotNull(); + assertThat(response.content()).isPresent(); + assertThat(response.content().get().parts().orElse(List.of())).isEmpty(); + } } diff --git a/contrib/spring-ai/DOCUMENT-GEMINI.md b/contrib/spring-ai/DOCUMENT-GEMINI.md deleted file mode 100644 index 393562528..000000000 --- a/contrib/spring-ai/DOCUMENT-GEMINI.md +++ /dev/null @@ -1,86 +0,0 @@ -# Documentation for the ADK Spring AI Library - -## 📖 Overview -The `google-adk-spring-ai` library provides an integration layer between the Google Agent Development Kit (ADK) and the Spring AI project. It allows developers to use Spring AI's `ChatModel`, `StreamingChatModel`, and `EmbeddingModel` as `BaseLlm` and `Embedding` implementations within the ADK framework. - -The library handles the conversion between ADK's request/response formats and Spring AI's prompt/chat response formats. It also includes auto-configuration to automatically expose Spring AI models as ADK `SpringAI` and `SpringAIEmbedding` beans in a Spring Boot application. - -## 🛠️ Building -To include this library in your project, use the following Maven coordinates: - -```xml - - com.google.adk - google-adk-spring-ai - 0.3.1-SNAPSHOT - -``` - -You will also need to include a dependency for the specific Spring AI model you want to use, for example: -```xml - - org.springframework.ai - spring-ai-openai - -``` - -## 🚀 Usage -The primary way to use this library is through Spring Boot auto-configuration. By including the `google-adk-spring-ai` dependency and a Spring AI model dependency (e.g., `spring-ai-openai`), the library will automatically create a `SpringAI` bean. This bean can then be injected and used as a `BaseLlm` in the ADK. - -**Example `application.properties`:** -```properties -# OpenAI configuration -spring.ai.openai.api-key=${OPENAI_API_KEY} -spring.ai.openai.chat.options.model=gpt-4o-mini -spring.ai.openai.chat.options.temperature=0.7 - -# ADK Spring AI configuration -adk.spring-ai.model=gpt-4o-mini -adk.spring-ai.validation.enabled=true -``` - -**Example usage in a Spring service:** -```java -import com.google.adk.models.BaseLlm; -import com.google.adk.models.LlmRequest; -import com.google.adk.models.LlmResponse; -import org.springframework.beans.factory.annotation.Autowired; -import org.springframework.stereotype.Service; -import reactor.core.publisher.Mono; - -@Service -public class MyAgentService { - - private final BaseLlm llm; - - @Autowired - public MyAgentService(BaseLlm llm) { - this.llm = llm; - } - - public Mono generateResponse(String prompt) { - LlmRequest request = LlmRequest.builder() - .addText(prompt) - .build(); - return Mono.from(llm.generateContent(request)) - .map(llmResponse -> llmResponse.content().get().parts().get(0).text().get()); - } -} -``` - -## 📚 API Reference -### Key Classes -- **`SpringAI`**: The main class that wraps a Spring AI `ChatModel` and/or `StreamingChatModel` and implements the ADK `BaseLlm` interface. - - **Methods**: - - `generateContent(LlmRequest llmRequest, boolean stream)`: Generates content, either streaming or non-streaming, by calling the underlying Spring AI model. It converts the ADK `LlmRequest` to a Spring AI `Prompt` and the `ChatResponse` back to an ADK `LlmResponse`. - -- **`SpringAIEmbedding`**: Wraps a Spring AI `EmbeddingModel` to be used for generating embeddings within the ADK framework. - -- **`SpringAIAutoConfiguration`**: The Spring Boot auto-configuration class that automatically discovers and configures `SpringAI` and `SpringAIEmbedding` beans based on the `ChatModel`, `StreamingChatModel`, and `EmbeddingModel` beans present in the application context. - -- **`SpringAIProperties`**: A configuration properties class (`@ConfigurationProperties("adk.spring-ai")`) that allows for customization of the Spring AI integration. - - **Properties**: - - `model`: The model name to use. - - `validation.enabled`: Whether to enable configuration validation. - - `validation.fail-fast`: Whether to fail fast on validation errors. - - `observability.enabled`: Whether to enable observability features. diff --git a/contrib/spring-ai/README.md b/contrib/spring-ai/README.md index 588b3d122..0ce7de4fe 100644 --- a/contrib/spring-ai/README.md +++ b/contrib/spring-ai/README.md @@ -18,21 +18,21 @@ To use ADK Java with the Spring AI integration in your application, add the foll com.google.adk google-adk - 0.3.1-SNAPSHOT + 1.0.1-rc.1-SNAPSHOT com.google.adk google-adk-spring-ai - 0.3.1-SNAPSHOT + 1.0.1-rc.1-SNAPSHOT org.springframework.ai spring-ai-bom - 1.1.1 + 2.0.0-M3 pom import @@ -109,14 +109,14 @@ Add the Spring AI provider dependencies for the AI services you want to use: org.springframework.boot spring-boot-starter-parent - 3.2.0 + 4.0.2 17 - 1.1.1 - 0.3.1-SNAPSHOT + 2.0.0-M3 + 1.0.1-rc.1-SNAPSHOT @@ -271,7 +271,7 @@ public class MyAdkSpringAiApplication { .anthropicApi(anthropicApi) .build(); - return new SpringAI(chatModel, "claude-3-5-sonnet-20241022"); + return new SpringAI(chatModel, "claude-sonnet-4-6"); } @Bean @@ -312,7 +312,7 @@ spring: api-key: ${ANTHROPIC_API_KEY} chat: options: - model: claude-3-5-sonnet-20241022 + model: claude-sonnet-4-6 temperature: 0.7 # ADK Spring AI Configuration @@ -365,13 +365,13 @@ The main adapter class that implements `BaseLlm` and wraps Spring AI `ChatModel` **Usage:** ```java // With ChatModel only -SpringAI springAI = new SpringAI(chatModel, "claude-sonnet-4-20250514"); +SpringAI springAI = new SpringAI(chatModel, "claude-sonnet-4-6"); // With both ChatModel and StreamingChatModel -SpringAI springAI = new SpringAI(chatModel, streamingChatModel, "claude-sonnet-4-20250514"); +SpringAI springAI = new SpringAI(chatModel, streamingChatModel, "claude-sonnet-4-6"); // With observability configuration -SpringAI springAI = new SpringAI(chatModel, "claude-sonnet-4-20250514", observabilityConfig); +SpringAI springAI = new SpringAI(chatModel, "claude-sonnet-4-6", observabilityConfig); ``` #### 2. MessageConverter (MessageConverter.java) @@ -533,7 +533,7 @@ The library works with any Spring AI provider: - Features: Chat, streaming, function calling, embeddings 2. **Anthropic** (`spring-ai-anthropic`) - - Models: Claude 3.5 Sonnet, Claude 3 Haiku + - Models: Claude 4.x Sonnet, Claude 4.x Haiku - Features: Chat, streaming, function calling - **Note:** Requires proper function schema registration @@ -563,7 +563,7 @@ The library works with any Spring AI provider: #### Anthropic - **Function Calling:** Requires explicit schema registration using `inputSchema()` method -- **Model Names:** Use full model names like `claude-3-5-sonnet-20241022` +- **Model Names:** Use full model names like `claude-sonnet-4-6` - **API Key:** Requires `ANTHROPIC_API_KEY` environment variable #### OpenAI diff --git a/contrib/spring-ai/src/main/java/com/google/adk/models/springai/MessageConverter.java b/contrib/spring-ai/src/main/java/com/google/adk/models/springai/MessageConverter.java index 036a898bb..3983b08a5 100644 --- a/contrib/spring-ai/src/main/java/com/google/adk/models/springai/MessageConverter.java +++ b/contrib/spring-ai/src/main/java/com/google/adk/models/springai/MessageConverter.java @@ -221,8 +221,7 @@ private List handleUserContent(Content content) { } catch (Exception e) { // Log warning but continue processing other parts // In production, consider proper logging framework - System.err.println( - "Warning: Failed to parse media mime type: " + blob.mimeType().get()); + System.err.println("Warning: Failed to process media part: " + e.getMessage()); } } } else if (part.fileData().isPresent()) { @@ -235,19 +234,14 @@ private List handleUserContent(Content content) { URI uri = URI.create(fileData.fileUri().get()); mediaList.add(new Media(mimeType, uri)); } catch (Exception e) { - System.err.println( - "Warning: Failed to parse media mime type: " + fileData.mimeType().get()); + System.err.println("Warning: Failed to process media part: " + e.getMessage()); } } } } List messages = new ArrayList<>(); - // Create UserMessage with text - // TODO: Media attachments support - UserMessage constructors with media are private in Spring - // AI 1.1.0 - // For now, only text content is supported - messages.add(new UserMessage(textBuilder.toString())); + messages.add(UserMessage.builder().text(textBuilder.toString()).media(mediaList).build()); messages.addAll(toolResponseMessages); return messages; diff --git a/contrib/spring-ai/src/main/java/com/google/adk/models/springai/ToolConverter.java b/contrib/spring-ai/src/main/java/com/google/adk/models/springai/ToolConverter.java index 95dafadb4..4012ee5d6 100644 --- a/contrib/spring-ai/src/main/java/com/google/adk/models/springai/ToolConverter.java +++ b/contrib/spring-ai/src/main/java/com/google/adk/models/springai/ToolConverter.java @@ -23,6 +23,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Set; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.ai.tool.ToolCallback; @@ -172,6 +173,17 @@ public List convertToSpringAiTools(Map tools) { } catch (Exception e) { logger.error("Error serializing schema to JSON: {}", e.getMessage(), e); } + } else if (declaration.parametersJsonSchema().isPresent()) { + callbackBuilder.inputType(Map.class); + try { + String schemaJson = + new com.fasterxml.jackson.databind.ObjectMapper() + .writeValueAsString(declaration.parametersJsonSchema().get()); + callbackBuilder.inputSchema(schemaJson); + logger.debug("Set input schema JSON from parametersJsonSchema: {}", schemaJson); + } catch (Exception e) { + logger.error("Error serializing parametersJsonSchema to JSON: {}", e.getMessage(), e); + } } toolCallbacks.add(callbackBuilder.build()); @@ -187,45 +199,63 @@ public List convertToSpringAiTools(Map tools) { */ private Map processArguments( Map args, FunctionDeclaration declaration) { - // If the arguments already match the expected format, return as-is if (declaration.parameters().isPresent()) { var schema = declaration.parameters().get(); if (schema.properties().isPresent()) { - var expectedParams = schema.properties().get().keySet(); - - // Check if all expected parameters are present at the top level - boolean allParamsPresent = expectedParams.stream().allMatch(args::containsKey); - if (allParamsPresent) { - return args; + return normalizeArguments(args, schema.properties().get().keySet()); + } + } else if (declaration.parametersJsonSchema().isPresent()) { + try { + @SuppressWarnings("unchecked") + Map schemaMap = + new com.fasterxml.jackson.databind.ObjectMapper() + .convertValue(declaration.parametersJsonSchema().get(), Map.class); + Object propertiesObj = schemaMap.get("properties"); + if (propertiesObj instanceof Map) { + @SuppressWarnings("unchecked") + Set expectedParams = ((Map) propertiesObj).keySet(); + return normalizeArguments(args, expectedParams); } + } catch (Exception e) { + logger.warn( + "Error processing parametersJsonSchema for argument mapping: {}", e.getMessage()); + } + } - // Check if arguments are nested under a single key (common pattern) - if (args.size() == 1) { - var singleValue = args.values().iterator().next(); - if (singleValue instanceof Map) { - @SuppressWarnings("unchecked") - Map nestedArgs = (Map) singleValue; - boolean allNestedParamsPresent = - expectedParams.stream().allMatch(nestedArgs::containsKey); - if (allNestedParamsPresent) { - return nestedArgs; - } - } - } + // If no processing worked, return original args and let ADK handle the error + return args; + } - // Check if we have a single parameter function and got a direct value - if (expectedParams.size() == 1) { - String expectedParam = expectedParams.iterator().next(); - if (args.size() == 1 && !args.containsKey(expectedParam)) { - // Try to map the single value to the expected parameter name - Object singleValue = args.values().iterator().next(); - return Map.of(expectedParam, singleValue); - } + private Map normalizeArguments( + Map args, Set expectedParams) { + // Check if all expected parameters are present at the top level + boolean allParamsPresent = expectedParams.stream().allMatch(args::containsKey); + if (allParamsPresent) { + return args; + } + + // Check if arguments are nested under a single key (common pattern) + if (args.size() == 1) { + var singleValue = args.values().iterator().next(); + if (singleValue instanceof Map) { + @SuppressWarnings("unchecked") + Map nestedArgs = (Map) singleValue; + boolean allNestedParamsPresent = expectedParams.stream().allMatch(nestedArgs::containsKey); + if (allNestedParamsPresent) { + return nestedArgs; } } } - // If no processing worked, return original args and let ADK handle the error + // Check if we have a single parameter function and got a direct value + if (expectedParams.size() == 1) { + String expectedParam = expectedParams.iterator().next(); + if (args.size() == 1 && !args.containsKey(expectedParam)) { + Object singleValue = args.values().iterator().next(); + return Map.of(expectedParam, singleValue); + } + } + return args; } diff --git a/contrib/spring-ai/src/test/java/com/google/adk/models/springai/MessageConverterTest.java b/contrib/spring-ai/src/test/java/com/google/adk/models/springai/MessageConverterTest.java index a57644b5d..b861a71f2 100644 --- a/contrib/spring-ai/src/test/java/com/google/adk/models/springai/MessageConverterTest.java +++ b/contrib/spring-ai/src/test/java/com/google/adk/models/springai/MessageConverterTest.java @@ -60,7 +60,9 @@ void testToLlmPromptWithUserMessage() { assertThat(prompt.getInstructions()).hasSize(1); Message message = prompt.getInstructions().get(0); assertThat(message).isInstanceOf(UserMessage.class); - assertThat(((UserMessage) message).getText()).isEqualTo("Hello, how are you?"); + UserMessage userMessage = (UserMessage) message; + assertThat(userMessage.getText()).isEqualTo("Hello, how are you?"); + assertThat(userMessage.getMedia()).isEmpty(); } @Test @@ -444,4 +446,181 @@ void testCombineMultipleSystemMessagesForGeminiCompatibility() { assertThat(secondMessage).isInstanceOf(UserMessage.class); assertThat(((UserMessage) secondMessage).getText()).isEqualTo("Hello world"); } + + @Test + void testUserMessageWithInlineMediaData() { + // Test conversion of ADK Content with inline media (image bytes) to Spring AI UserMessage + byte[] imageData = "fake-image-data".getBytes(); + String mimeType = "image/png"; + + Content userContent = + Content.builder() + .role("user") + .parts( + List.of( + Part.fromText("What's in this image?"), + Part.builder() + .inlineData( + com.google.genai.types.Blob.builder() + .mimeType(mimeType) + .data(imageData) + .build()) + .build())) + .build(); + + LlmRequest request = LlmRequest.builder().contents(List.of(userContent)).build(); + + Prompt prompt = messageConverter.toLlmPrompt(request); + + assertThat(prompt.getInstructions()).hasSize(1); + Message message = prompt.getInstructions().get(0); + assertThat(message).isInstanceOf(UserMessage.class); + + UserMessage userMessage = (UserMessage) message; + assertThat(userMessage.getText()).isEqualTo("What's in this image?"); + assertThat(userMessage.getMedia()).hasSize(1); + org.springframework.ai.content.Media media = userMessage.getMedia().get(0); + assertThat(media.getMimeType().toString()).isEqualTo(mimeType); + assertThat(media.getData()).isInstanceOf(byte[].class); + byte[] actualData = (byte[]) media.getData(); + assertThat(actualData).isEqualTo(imageData); + } + + @Test + void testUserMessageWithFileMediaData() { + // Test conversion of ADK Content with file-based media (URI) to Spring AI UserMessage + String fileUri = "gs://bucket/image.jpg"; + String mimeType = "image/jpeg"; + + Content userContent = + Content.builder() + .role("user") + .parts( + List.of( + Part.fromText("Analyze this image"), + Part.builder() + .fileData( + com.google.genai.types.FileData.builder() + .mimeType(mimeType) + .fileUri(fileUri) + .build()) + .build())) + .build(); + + LlmRequest request = LlmRequest.builder().contents(List.of(userContent)).build(); + + Prompt prompt = messageConverter.toLlmPrompt(request); + + assertThat(prompt.getInstructions()).hasSize(1); + Message message = prompt.getInstructions().get(0); + assertThat(message).isInstanceOf(UserMessage.class); + + UserMessage userMessage = (UserMessage) message; + assertThat(userMessage.getText()).isEqualTo("Analyze this image"); + assertThat(userMessage.getMedia()).hasSize(1); + org.springframework.ai.content.Media media = userMessage.getMedia().get(0); + assertThat(media.getMimeType().toString()).isEqualTo(mimeType); + assertThat(media.getData()).isInstanceOf(String.class); + String actualUri = (String) media.getData(); + assertThat(actualUri).isEqualTo(fileUri); + } + + @Test + void testUserMessageWithMultipleMediaAttachments() { + // Test conversion with multiple media attachments + byte[] image1 = "image1-data".getBytes(); + byte[] image2 = "image2-data".getBytes(); + + Content userContent = + Content.builder() + .role("user") + .parts( + List.of( + Part.fromText("Compare these images"), + Part.builder() + .inlineData( + com.google.genai.types.Blob.builder() + .mimeType("image/png") + .data(image1) + .build()) + .build(), + Part.builder() + .inlineData( + com.google.genai.types.Blob.builder() + .mimeType("image/jpeg") + .data(image2) + .build()) + .build())) + .build(); + + LlmRequest request = LlmRequest.builder().contents(List.of(userContent)).build(); + + Prompt prompt = messageConverter.toLlmPrompt(request); + + assertThat(prompt.getInstructions()).hasSize(1); + UserMessage userMessage = (UserMessage) prompt.getInstructions().get(0); + assertThat(userMessage.getText()).isEqualTo("Compare these images"); + assertThat(userMessage.getMedia()).hasSize(2); + } + + @Test + void testUserMessageWithInvalidMimeTypeGracefullySkipsMediaPart() { + // Test that an invalid MIME type string causes the media part to be skipped gracefully + byte[] imageData = "fake-image-data".getBytes(); + + Content userContent = + Content.builder() + .role("user") + .parts( + List.of( + Part.fromText("What's in this image?"), + Part.builder() + .inlineData( + com.google.genai.types.Blob.builder() + .mimeType("invalid/mime/type!!!") // invalid MIME type + .data(imageData) + .build()) + .build())) + .build(); + + LlmRequest request = LlmRequest.builder().contents(List.of(userContent)).build(); + + // Should not throw — invalid MIME type is silently skipped + Prompt prompt = messageConverter.toLlmPrompt(request); + + assertThat(prompt.getInstructions()).hasSize(1); + UserMessage userMessage = (UserMessage) prompt.getInstructions().get(0); + assertThat(userMessage.getText()).isEqualTo("What's in this image?"); + // Media part is skipped due to invalid MIME type + assertThat(userMessage.getMedia()).isEmpty(); + } + + @Test + void testUserMessageWithMediaOnly() { + // Test conversion with media but no text + byte[] imageData = "image-only".getBytes(); + + Content userContent = + Content.builder() + .role("user") + .parts( + List.of( + Part.builder() + .inlineData( + com.google.genai.types.Blob.builder() + .mimeType("image/png") + .data(imageData) + .build()) + .build())) + .build(); + + LlmRequest request = LlmRequest.builder().contents(List.of(userContent)).build(); + + Prompt prompt = messageConverter.toLlmPrompt(request); + + assertThat(prompt.getInstructions()).hasSize(1); + UserMessage userMessage = (UserMessage) prompt.getInstructions().get(0); + assertThat(userMessage.getText()).isEmpty(); + assertThat(userMessage.getMedia()).hasSize(1); + } } diff --git a/contrib/spring-ai/src/test/java/com/google/adk/models/springai/ToolConverterArgumentProcessingTest.java b/contrib/spring-ai/src/test/java/com/google/adk/models/springai/ToolConverterArgumentProcessingTest.java index 301a145e0..77b988837 100644 --- a/contrib/spring-ai/src/test/java/com/google/adk/models/springai/ToolConverterArgumentProcessingTest.java +++ b/contrib/spring-ai/src/test/java/com/google/adk/models/springai/ToolConverterArgumentProcessingTest.java @@ -115,6 +115,90 @@ private Map invokeProcessArguments( return (Map) method.invoke(converter, args, declaration); } + @Test + void testArgumentProcessingWithParametersJsonSchema_correctFormat() throws Exception { + ToolConverter converter = new ToolConverter(); + Method processArguments = getProcessArgumentsMethod(converter); + + com.google.genai.types.FunctionDeclaration declaration = + com.google.genai.types.FunctionDeclaration.builder() + .name("getWeatherInfo") + .description("Get weather information") + .parametersJsonSchema( + Map.of( + "type", "object", "properties", Map.of("location", Map.of("type", "string")))) + .build(); + + Map correctArgs = Map.of("location", "San Francisco"); + Map processedArgs = + invokeProcessArguments(processArguments, converter, correctArgs, declaration); + + assertThat(processedArgs).isEqualTo(correctArgs); + } + + @Test + void testArgumentProcessingWithParametersJsonSchema_nestedFormat() throws Exception { + ToolConverter converter = new ToolConverter(); + Method processArguments = getProcessArgumentsMethod(converter); + + com.google.genai.types.FunctionDeclaration declaration = + com.google.genai.types.FunctionDeclaration.builder() + .name("getWeatherInfo") + .description("Get weather information") + .parametersJsonSchema( + Map.of( + "type", "object", "properties", Map.of("location", Map.of("type", "string")))) + .build(); + + Map nestedArgs = Map.of("args", Map.of("location", "San Francisco")); + Map processedArgs = + invokeProcessArguments(processArguments, converter, nestedArgs, declaration); + + assertThat(processedArgs).containsEntry("location", "San Francisco"); + } + + @Test + void testArgumentProcessingWithParametersJsonSchema_directValue() throws Exception { + ToolConverter converter = new ToolConverter(); + Method processArguments = getProcessArgumentsMethod(converter); + + com.google.genai.types.FunctionDeclaration declaration = + com.google.genai.types.FunctionDeclaration.builder() + .name("getWeatherInfo") + .description("Get weather information") + .parametersJsonSchema( + Map.of( + "type", "object", "properties", Map.of("location", Map.of("type", "string")))) + .build(); + + Map directValueArgs = Map.of("value", "San Francisco"); + Map processedArgs = + invokeProcessArguments(processArguments, converter, directValueArgs, declaration); + + assertThat(processedArgs).containsEntry("location", "San Francisco"); + } + + @Test + void testArgumentProcessingWithParametersJsonSchema_noMatch() throws Exception { + ToolConverter converter = new ToolConverter(); + Method processArguments = getProcessArgumentsMethod(converter); + + com.google.genai.types.FunctionDeclaration declaration = + com.google.genai.types.FunctionDeclaration.builder() + .name("getWeatherInfo") + .description("Get weather information") + .parametersJsonSchema( + Map.of( + "type", "object", "properties", Map.of("location", Map.of("type", "string")))) + .build(); + + Map wrongArgs = Map.of("city", "San Francisco", "country", "USA"); + Map processedArgs = + invokeProcessArguments(processArguments, converter, wrongArgs, declaration); + + assertThat(processedArgs).isEqualTo(wrongArgs); + } + public static class WeatherTools { public static Map getWeatherInfo(String location) { return Map.of( diff --git a/contrib/spring-ai/src/test/java/com/google/adk/models/springai/ToolConverterTest.java b/contrib/spring-ai/src/test/java/com/google/adk/models/springai/ToolConverterTest.java index 231c8e1fe..1f3044159 100644 --- a/contrib/spring-ai/src/test/java/com/google/adk/models/springai/ToolConverterTest.java +++ b/contrib/spring-ai/src/test/java/com/google/adk/models/springai/ToolConverterTest.java @@ -26,6 +26,7 @@ import java.util.Optional; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.springframework.ai.tool.ToolCallback; class ToolConverterTest { @@ -178,4 +179,37 @@ void testToolMetadata() { assertThat(metadata.getDescription()).isEqualTo("Test description"); assertThat(metadata.getDeclaration()).isEqualTo(function); } + + @Test + void testConvertToSpringAiToolsWithParametersJsonSchema() { + Map jsonSchema = + Map.of( + "type", + "object", + "properties", + Map.of("location", Map.of("type", "string", "description", "City name")), + "required", + List.of("location")); + + FunctionDeclaration function = + FunctionDeclaration.builder() + .name("get_weather") + .description("Get weather for a location") + .parametersJsonSchema(jsonSchema) + .build(); + + BaseTool testTool = + new BaseTool("get_weather", "Get weather for a location") { + @Override + public Optional declaration() { + return Optional.of(function); + } + }; + + Map tools = Map.of("get_weather", testTool); + List toolCallbacks = toolConverter.convertToSpringAiTools(tools); + + assertThat(toolCallbacks).hasSize(1); + assertThat(toolCallbacks.get(0).getToolDefinition().name()).isEqualTo("get_weather"); + } } diff --git a/core/src/main/java/com/google/adk/agents/ParallelAgent.java b/core/src/main/java/com/google/adk/agents/ParallelAgent.java index f30d951aa..2593ec13a 100644 --- a/core/src/main/java/com/google/adk/agents/ParallelAgent.java +++ b/core/src/main/java/com/google/adk/agents/ParallelAgent.java @@ -16,11 +16,14 @@ package com.google.adk.agents; import static com.google.common.base.Strings.isNullOrEmpty; -import static com.google.common.collect.ImmutableList.toImmutableList; import com.google.adk.agents.ConfigAgentUtils.ConfigurationException; import com.google.adk.events.Event; +import com.google.errorprone.annotations.CanIgnoreReturnValue; import io.reactivex.rxjava3.core.Flowable; +import io.reactivex.rxjava3.core.Scheduler; +import io.reactivex.rxjava3.schedulers.Schedulers; +import java.util.ArrayList; import java.util.List; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -35,6 +38,7 @@ public class ParallelAgent extends BaseAgent { private static final Logger logger = LoggerFactory.getLogger(ParallelAgent.class); + private final Scheduler scheduler; /** * Constructor for ParallelAgent. @@ -44,24 +48,35 @@ public class ParallelAgent extends BaseAgent { * @param subAgents The list of sub-agents to run in parallel. * @param beforeAgentCallback Optional callback before the agent runs. * @param afterAgentCallback Optional callback after the agent runs. + * @param scheduler The scheduler to use for parallel execution. */ private ParallelAgent( String name, String description, List subAgents, List beforeAgentCallback, - List afterAgentCallback) { + List afterAgentCallback, + Scheduler scheduler) { super(name, description, subAgents, beforeAgentCallback, afterAgentCallback); + this.scheduler = scheduler; } /** Builder for {@link ParallelAgent}. */ public static class Builder extends BaseAgent.Builder { + private Scheduler scheduler = Schedulers.io(); + + @CanIgnoreReturnValue + public Builder scheduler(Scheduler scheduler) { + this.scheduler = scheduler; + return this; + } + @Override public ParallelAgent build() { return new ParallelAgent( - name, description, subAgents, beforeAgentCallback, afterAgentCallback); + name, description, subAgents, beforeAgentCallback, afterAgentCallback, scheduler); } } @@ -129,10 +144,11 @@ protected Flowable runAsyncImpl(InvocationContext invocationContext) { } var updatedInvocationContext = setBranchForCurrentAgent(this, invocationContext); - return Flowable.merge( - currentSubAgents.stream() - .map(subAgent -> subAgent.runAsync(updatedInvocationContext)) - .collect(toImmutableList())); + List> agentFlowables = new ArrayList<>(); + for (BaseAgent subAgent : currentSubAgents) { + agentFlowables.add(subAgent.runAsync(updatedInvocationContext).subscribeOn(scheduler)); + } + return Flowable.merge(agentFlowables); } /** diff --git a/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java b/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java index d4fe1b838..aa62b9f31 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java @@ -233,7 +233,7 @@ private Flowable callLlm( callLlmContext) .doOnSubscribe( s -> - Tracing.traceCallLlm( + traceCallLlm( span, context, eventForCallbackUsage.id(), @@ -520,6 +520,7 @@ public Flowable runLive(InvocationContext invocationContext) { .doOnComplete( () -> Tracing.traceSendData( + Span.current(), invocationContext, eventIdForSendData, llmRequestAfterPreprocess.contents())) @@ -529,6 +530,7 @@ public Flowable runLive(InvocationContext invocationContext) { span.setStatus(StatusCode.ERROR, error.getMessage()); span.recordException(error); Tracing.traceSendData( + Span.current(), invocationContext, eventIdForSendData, llmRequestAfterPreprocess.contents()); @@ -706,6 +708,19 @@ private Flowable buildPostprocessingEvents( return processorEvents.concatWith(Flowable.just(modelResponseEvent)).concatWith(functionEvents); } + /** + * Traces an LLM call without an associated exception. This is an overload for {@link + * Tracing#traceCallLlm} for successful calls. + */ + private void traceCallLlm( + Span span, + InvocationContext context, + String eventId, + LlmRequest llmRequest, + LlmResponse llmResponse) { + Tracing.traceCallLlm(span, context, eventId, llmRequest, llmResponse, null); + } + private Event buildModelResponseEvent( Event baseEventForLlmResponse, LlmRequest llmRequest, LlmResponse llmResponse) { Event.Builder eventBuilder = diff --git a/core/src/main/java/com/google/adk/flows/llmflows/Functions.java b/core/src/main/java/com/google/adk/flows/llmflows/Functions.java index 84a8141ea..0b0e5b4d5 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/Functions.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/Functions.java @@ -178,8 +178,12 @@ public static Maybe handleFunctionCalls( if (events.size() > 1) { return Maybe.just(mergedEvent) - .doOnSuccess(event -> Tracing.traceToolResponse(event.id(), event)) - .compose(Tracing.trace("tool_response").setParent(parentContext)); + .compose( + Tracing.trace("execute_tool (merged)") + .setParent(parentContext) + .onSuccess( + (span, event) -> + Tracing.traceMergedToolCalls(span, event.id(), event))); } return Maybe.just(mergedEvent); }); @@ -269,10 +273,8 @@ private static Function> getFunctionCallMapper( tool, toolContext, functionCall, - functionArgs, - parentContext) - : callTool( - tool, functionArgs, toolContext, parentContext)) + functionArgs) + : callTool(tool, functionArgs, toolContext)) .compose(Tracing.withContext(parentContext))); return postProcessFunctionResult( @@ -296,8 +298,7 @@ private static Maybe> processFunctionLive( BaseTool tool, ToolContext toolContext, FunctionCall functionCall, - Map args, - Context parentContext) { + Map args) { // Case 1: Handle a call to stopStreaming if (functionCall.name().get().equals("stopStreaming") && args.containsKey("functionName")) { String functionNameToStop = (String) args.get("functionName"); @@ -365,7 +366,7 @@ private static Maybe> processFunctionLive( } // Case 3: Fallback for regular, non-streaming tools - return callTool(tool, args, toolContext, parentContext); + return callTool(tool, args, toolContext); } public static Set getLongRunningFunctionCalls( @@ -426,12 +427,22 @@ private static Maybe postProcessFunctionResult( Event event = buildResponseEvent( tool, finalFunctionResult, toolContext, invocationContext); - Tracing.traceToolResponse(event.id(), event); return Maybe.just(event); }); }) .compose( - Tracing.trace("tool_response [" + tool.name() + "]").setParent(parentContext)); + Tracing.trace("execute_tool [" + tool.name() + "]") + .setParent(parentContext) + .onSuccess( + (span, event) -> + Tracing.traceToolExecution( + span, + tool.name(), + tool.description(), + tool.getClass().getSimpleName(), + functionArgs, + event, + null))); } private static Optional mergeParallelFunctionResponseEvents( @@ -579,17 +590,10 @@ private static Maybe> maybeInvokeAfterToolCall( } private static Maybe> callTool( - BaseTool tool, Map args, ToolContext toolContext, Context parentContext) { + BaseTool tool, Map args, ToolContext toolContext) { return tool.runAsync(args, toolContext) .toMaybe() - .doOnSubscribe( - d -> - Tracing.traceToolCall( - tool.name(), tool.description(), tool.getClass().getSimpleName(), args)) .doOnError(t -> Span.current().recordException(t)) - .compose( - Tracing.>trace("tool_call [" + tool.name() + "]") - .setParent(parentContext)) .onErrorResumeNext( e -> Maybe.error( diff --git a/core/src/main/java/com/google/adk/models/Claude.java b/core/src/main/java/com/google/adk/models/Claude.java index ebb786e35..01feda1d4 100644 --- a/core/src/main/java/com/google/adk/models/Claude.java +++ b/core/src/main/java/com/google/adk/models/Claude.java @@ -31,8 +31,7 @@ import com.anthropic.models.messages.ToolUnion; import com.anthropic.models.messages.ToolUseBlockParam; import com.fasterxml.jackson.core.type.TypeReference; -import com.fasterxml.jackson.databind.ObjectMapper; -import com.fasterxml.jackson.datatype.jdk8.Jdk8Module; +import com.google.adk.JsonBaseModel; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.genai.types.Content; @@ -170,9 +169,22 @@ private ContentBlockParam partToAnthropicMessageBlock(Part part) { .build()); } else if (part.functionResponse().isPresent()) { String content = ""; - if (part.functionResponse().get().response().isPresent() - && part.functionResponse().get().response().get().getOrDefault("result", null) != null) { - content = part.functionResponse().get().response().get().get("result").toString(); + if (part.functionResponse().get().response().isPresent()) { + Map responseData = part.functionResponse().get().response().get(); + + Object contentObj = responseData.get("content"); + Object resultObj = responseData.get("result"); + + if (contentObj instanceof List list && !list.isEmpty()) { + // Native MCP format: list of content blocks + content = extractMcpContentBlocks(list); + } else if (resultObj != null) { + // ADK tool result object + content = resultObj instanceof String s ? s : serializeToJson(resultObj); + } else if (!responseData.isEmpty()) { + // Fallback: arbitrary JSON structure + content = serializeToJson(responseData); + } } return ContentBlockParam.ofToolResult( ToolResultBlockParam.builder() @@ -184,6 +196,30 @@ private ContentBlockParam partToAnthropicMessageBlock(Part part) { throw new UnsupportedOperationException("Not supported yet."); } + private String extractMcpContentBlocks(List list) { + List textBlocks = new ArrayList<>(); + for (Object item : list) { + if (item instanceof Map m && "text".equals(m.get("type"))) { + Object textObj = m.get("text"); + textBlocks.add(textObj != null ? String.valueOf(textObj) : ""); + } else if (item instanceof String s) { + textBlocks.add(s); + } else { + textBlocks.add(serializeToJson(item)); + } + } + return String.join("\n", textBlocks); + } + + private String serializeToJson(Object obj) { + try { + return JsonBaseModel.getMapper().writeValueAsString(obj); + } catch (Exception e) { + logger.warn("Failed to serialize object to JSON", e); + return String.valueOf(obj); + } + } + private void updateTypeString(Map valueDict) { if (valueDict == null) { return; @@ -221,10 +257,9 @@ private Tool functionDeclarationToAnthropicTool(FunctionDeclaration functionDecl .get() .forEach( (key, schema) -> { - ObjectMapper objectMapper = new ObjectMapper(); - objectMapper.registerModule(new Jdk8Module()); Map schemaMap = - objectMapper.convertValue(schema, new TypeReference>() {}); + JsonBaseModel.getMapper() + .convertValue(schema, new TypeReference>() {}); updateTypeString(schemaMap); properties.put(key, schemaMap); }); diff --git a/core/src/main/java/com/google/adk/runner/Runner.java b/core/src/main/java/com/google/adk/runner/Runner.java index 1f7d924ab..2bfbca881 100644 --- a/core/src/main/java/com/google/adk/runner/Runner.java +++ b/core/src/main/java/com/google/adk/runner/Runner.java @@ -425,36 +425,6 @@ public Flowable runAsync(String userId, String sessionId, Content newMess return runAsync(userId, sessionId, newMessage, RunConfig.builder().build()); } - /** - * See {@link #runAsync(Session, Content, RunConfig, Map)}. - * - * @deprecated Use runAsync with sessionId. - */ - @Deprecated(since = "0.4.0", forRemoval = true) - public Flowable runAsync(Session session, Content newMessage, RunConfig runConfig) { - return runAsync(session, newMessage, runConfig, /* stateDelta= */ null); - } - - /** - * Runs the agent asynchronously using a provided Session object. - * - * @param session The session to run the agent in. - * @param newMessage The new message from the user to process. - * @param runConfig Configuration for the agent run. - * @param stateDelta Optional map of state updates to merge into the session for this run. - * @return A Flowable stream of {@link Event} objects generated by the agent during execution. - * @deprecated Use runAsync with sessionId. - */ - @Deprecated(since = "0.4.0", forRemoval = true) - public Flowable runAsync( - Session session, - Content newMessage, - RunConfig runConfig, - @Nullable Map stateDelta) { - return runAsyncImpl(session, newMessage, runConfig, stateDelta) - .compose(Tracing.trace("invocation")); - } - /** * Runs the agent asynchronously using a provided Session object. * @@ -559,7 +529,7 @@ private Flowable runAgentWithFreshSession( contextWithUpdatedSession .agent() .runAsync(contextWithUpdatedSession) - .flatMap( + .concatMap( agentEvent -> this.sessionService .appendEvent(updatedSession, agentEvent) @@ -735,18 +705,6 @@ protected Flowable runLiveImpl( }); } - /** - * Runs the agent asynchronously with a default user ID. - * - * @return stream of generated events. - */ - @Deprecated(since = "0.5.0", forRemoval = true) - public Flowable runWithSessionId( - String sessionId, Content newMessage, RunConfig runConfig) { - // TODO(b/410859954): Add user_id to getter or method signature. Assuming "tmp-user" for now. - return this.runAsync("tmp-user", sessionId, newMessage, runConfig); - } - /** * Checks if the agent and its parent chain allow transfer up the tree. * diff --git a/core/src/main/java/com/google/adk/sessions/VertexAiSessionService.java b/core/src/main/java/com/google/adk/sessions/VertexAiSessionService.java index 4336f96c9..99e7e3479 100644 --- a/core/src/main/java/com/google/adk/sessions/VertexAiSessionService.java +++ b/core/src/main/java/com/google/adk/sessions/VertexAiSessionService.java @@ -128,15 +128,17 @@ public Single listSessions(String appName, String userId) .map( listSessionsResponseMap -> parseListSessionsResponse(listSessionsResponseMap, appName, userId)) - .defaultIfEmpty(ListSessionsResponse.builder().build()); + .defaultIfEmpty(ListSessionsResponse.builder().sessions(new ArrayList<>()).build()); } private ListSessionsResponse parseListSessionsResponse( JsonNode listSessionsResponseMap, String appName, String userId) { + JsonNode sessionsNode = listSessionsResponseMap.get("sessions"); + if (sessionsNode == null || sessionsNode.isNull() || sessionsNode.isEmpty()) { + return ListSessionsResponse.builder().build(); + } List> apiSessions = - objectMapper.convertValue( - listSessionsResponseMap.get("sessions"), - new TypeReference>>() {}); + objectMapper.convertValue(sessionsNode, new TypeReference>>() {}); List sessions = new ArrayList<>(); for (Map apiSession : apiSessions) { @@ -172,7 +174,7 @@ public Single listEvents(String appName, String userId, Stri private ListEventsResponse parseListEventsResponse(JsonNode listEventsResponse) { JsonNode sessionEventsNode = listEventsResponse.get("sessionEvents"); if (sessionEventsNode == null || sessionEventsNode.isEmpty()) { - return ListEventsResponse.builder().events(new ArrayList<>()).build(); + return ListEventsResponse.builder().build(); } return ListEventsResponse.builder() .events( diff --git a/core/src/main/java/com/google/adk/telemetry/Tracing.java b/core/src/main/java/com/google/adk/telemetry/Tracing.java index 215e317e1..589215073 100644 --- a/core/src/main/java/com/google/adk/telemetry/Tracing.java +++ b/core/src/main/java/com/google/adk/telemetry/Tracing.java @@ -33,6 +33,7 @@ import io.opentelemetry.api.GlobalOpenTelemetry; import io.opentelemetry.api.common.AttributeKey; import io.opentelemetry.api.trace.Span; +import io.opentelemetry.api.trace.StatusCode; import io.opentelemetry.api.trace.Tracer; import io.opentelemetry.context.Context; import io.opentelemetry.context.Scope; @@ -61,6 +62,7 @@ import java.util.function.BiConsumer; import java.util.function.Consumer; import java.util.function.Supplier; +import org.jspecify.annotations.Nullable; import org.reactivestreams.Publisher; import org.reactivestreams.Subscriber; import org.reactivestreams.Subscription; @@ -77,6 +79,11 @@ public class Tracing { private static final Logger log = LoggerFactory.getLogger(Tracing.class); + private static final String INVOKE_AGENT_OPERATION = "invoke_agent"; + private static final String EXECUTE_TOOL_OPERATION = "execute_tool"; + private static final String SEND_DATA_OPERATION = "send_data"; + private static final String CALL_LLM_OPERATION = "call_llm"; + private static final AttributeKey> GEN_AI_RESPONSE_FINISH_REASONS = AttributeKey.stringArrayKey("gen_ai.response.finish_reasons"); @@ -134,15 +141,6 @@ public class Tracing { private Tracing() {} - private static void traceWithSpan(String methodName, Consumer traceAction) { - Span span = Span.current(); - if (!span.getSpanContext().isValid()) { - log.trace("{}: No valid span in current context.", methodName); - return; - } - traceAction.accept(span); - } - private static void setInvocationAttributes( Span span, InvocationContext invocationContext, String eventId) { span.setAttribute(ADK_INVOCATION_ID, invocationContext.invocationId()); @@ -159,12 +157,6 @@ private static void setInvocationAttributes( } } - private static void setToolExecutionAttributes(Span span) { - span.setAttribute(GEN_AI_OPERATION_NAME, "execute_tool"); - span.setAttribute(ADK_LLM_REQUEST, "{}"); - span.setAttribute(ADK_LLM_RESPONSE, "{}"); - } - private static void setJsonAttribute(Span span, AttributeKey key, Object value) { if (!CAPTURE_MESSAGE_CONTENT_IN_SPANS) { span.setAttribute(key, "{}"); @@ -198,7 +190,7 @@ public static void setTracerForTesting(Tracer tracer) { */ public static void traceAgentInvocation( Span span, String agentName, String agentDescription, InvocationContext invocationContext) { - span.setAttribute(GEN_AI_OPERATION_NAME, "invoke_agent"); + span.setAttribute(GEN_AI_OPERATION_NAME, INVOKE_AGENT_OPERATION); span.setAttribute(GEN_AI_AGENT_DESCRIPTION, agentDescription); span.setAttribute(GEN_AI_AGENT_NAME, agentName); if (invocationContext.session() != null && invocationContext.session().id() != null) { @@ -207,58 +199,62 @@ public static void traceAgentInvocation( } /** - * Traces tool call arguments. - * - * @param args The arguments to the tool call. - */ - public static void traceToolCall( - String toolName, String toolDescription, String toolType, Map args) { - traceWithSpan( - "traceToolCall", - span -> { - setToolExecutionAttributes(span); - span.setAttribute(GEN_AI_TOOL_NAME, toolName); - span.setAttribute(GEN_AI_TOOL_DESCRIPTION, toolDescription); - span.setAttribute(GEN_AI_TOOL_TYPE, toolType); - - setJsonAttribute(span, ADK_TOOL_CALL_ARGS, args); - }); - } - - /** - * Traces tool response event. + * Traces a tool execution, including its arguments, response, and any potential error. * - * @param eventId The ID of the event. - * @param functionResponseEvent The function response event. + * @param span The span representing the tool execution. + * @param toolName The name of the tool. + * @param toolDescription The tool's description. + * @param toolType The tool's type (e.g., "FunctionTool"). + * @param args The arguments passed to the tool. + * @param functionResponseEvent The event containing the tool's response, if successful. + * @param error The exception thrown during execution, if any. */ - public static void traceToolResponse(String eventId, Event functionResponseEvent) { - traceWithSpan( - "traceToolResponse", - span -> { - setToolExecutionAttributes(span); - span.setAttribute(ADK_EVENT_ID, eventId); - - FunctionResponse functionResponse = - functionResponseEvent.functionResponses().stream().findFirst().orElse(null); - - String toolCallId = ""; - Object toolResponse = ""; - if (functionResponse != null) { - toolCallId = functionResponse.id().orElse(toolCallId); - if (functionResponse.response().isPresent()) { - toolResponse = functionResponse.response().get(); - } - } - - span.setAttribute(GEN_AI_TOOL_CALL_ID, toolCallId); + public static void traceToolExecution( + Span span, + String toolName, + String toolDescription, + String toolType, + Map args, + @Nullable Event functionResponseEvent, + @Nullable Exception error) { + span.setAttribute(GEN_AI_OPERATION_NAME, EXECUTE_TOOL_OPERATION); + span.setAttribute(GEN_AI_TOOL_NAME, toolName); + span.setAttribute(GEN_AI_TOOL_DESCRIPTION, toolDescription); + span.setAttribute(GEN_AI_TOOL_TYPE, toolType); + + setJsonAttribute(span, ADK_TOOL_CALL_ARGS, args); + + if (functionResponseEvent != null) { + span.setAttribute(ADK_EVENT_ID, functionResponseEvent.id()); + FunctionResponse functionResponse = + functionResponseEvent.functionResponses().stream().findFirst().orElse(null); + + String toolCallId = ""; + Object toolResponse = ""; + if (functionResponse != null) { + toolCallId = functionResponse.id().orElse(toolCallId); + if (functionResponse.response().isPresent()) { + toolResponse = functionResponse.response().get(); + } + } + span.setAttribute(GEN_AI_TOOL_CALL_ID, toolCallId); + Object finalToolResponse = + (toolResponse instanceof Map) ? toolResponse : ImmutableMap.of("result", toolResponse); + setJsonAttribute(span, ADK_TOOL_RESPONSE, finalToolResponse); + } else { + // Set placeholder if no response event is available (e.g., due to an error) + span.setAttribute(GEN_AI_TOOL_CALL_ID, ""); + setJsonAttribute(span, ADK_TOOL_RESPONSE, "{}"); + } - Object finalToolResponse = - (toolResponse instanceof Map) - ? toolResponse - : ImmutableMap.of("result", toolResponse); + // Also set empty LLM attributes for UI compatibility, like in traceToolResponse + span.setAttribute(ADK_LLM_REQUEST, "{}"); + span.setAttribute(ADK_LLM_RESPONSE, "{}"); - setJsonAttribute(span, ADK_TOOL_RESPONSE, finalToolResponse); - }); + if (error != null) { + span.setStatus(StatusCode.ERROR, error.getMessage()); + span.recordException(error); + } } /** @@ -303,8 +299,10 @@ public static void traceCallLlm( InvocationContext invocationContext, String eventId, LlmRequest llmRequest, - LlmResponse llmResponse) { + LlmResponse llmResponse, + @Nullable Exception error) { span.setAttribute(GEN_AI_SYSTEM, "gcp.vertex.agent"); + span.setAttribute(GEN_AI_OPERATION_NAME, CALL_LLM_OPERATION); llmRequest.model().ifPresent(modelName -> span.setAttribute(GEN_AI_REQUEST_MODEL, modelName)); setInvocationAttributes(span, invocationContext, eventId); @@ -312,6 +310,11 @@ public static void traceCallLlm( setJsonAttribute(span, ADK_LLM_REQUEST, buildLlmRequestForTrace(llmRequest)); setJsonAttribute(span, ADK_LLM_RESPONSE, llmResponse); + if (error != null) { + span.setStatus(StatusCode.ERROR, error.getMessage()); + span.recordException(error); + } + llmRequest .config() .ifPresent( @@ -352,18 +355,45 @@ public static void traceCallLlm( * @param data A list of content objects being sent. */ public static void traceSendData( - InvocationContext invocationContext, String eventId, List data) { - traceWithSpan( - "traceSendData", - span -> { - setInvocationAttributes(span, invocationContext, eventId); - - ImmutableList safeData = - Optional.ofNullable(data).orElse(ImmutableList.of()).stream() - .filter(Objects::nonNull) - .collect(toImmutableList()); - setJsonAttribute(span, ADK_DATA, safeData); - }); + Span span, InvocationContext invocationContext, String eventId, List data) { + if (!span.getSpanContext().isValid()) { + log.trace("traceSendData: No valid span in current context."); + return; + } + setInvocationAttributes(span, invocationContext, eventId); + span.setAttribute(GEN_AI_OPERATION_NAME, SEND_DATA_OPERATION); + + ImmutableList safeData = + Optional.ofNullable(data).orElse(ImmutableList.of()).stream() + .filter(Objects::nonNull) + .collect(toImmutableList()); + setJsonAttribute(span, ADK_DATA, safeData); + } + + /** + * Traces merged tool call events. + * + *

Calling this function is not needed for telemetry purposes. This is provided for preventing + * /debug/trace requests (typically sent by web UI). + * + * @param responseEventId The ID of the response event. + * @param functionResponseEvent The merged response event. + */ + public static void traceMergedToolCalls( + Span span, String responseEventId, Event functionResponseEvent) { + if (!span.getSpanContext().isValid()) { + log.trace("traceMergedToolCalls: No valid span in current context."); + return; + } + span.setAttribute(GEN_AI_OPERATION_NAME, EXECUTE_TOOL_OPERATION); + span.setAttribute(GEN_AI_TOOL_NAME, "(merged tools)"); + span.setAttribute(GEN_AI_TOOL_DESCRIPTION, "(merged tools)"); + span.setAttribute(GEN_AI_TOOL_CALL_ID, responseEventId); + span.setAttribute(ADK_TOOL_CALL_ARGS, "N/A"); + span.setAttribute(ADK_EVENT_ID, responseEventId); + setJsonAttribute(span, ADK_TOOL_RESPONSE, functionResponseEvent); + span.setAttribute(ADK_LLM_REQUEST, "{}"); + span.setAttribute(ADK_LLM_RESPONSE, "{}"); } /** diff --git a/core/src/main/java/com/google/adk/tools/SetModelResponseTool.java b/core/src/main/java/com/google/adk/tools/SetModelResponseTool.java index e23d6414a..3b0e411b4 100644 --- a/core/src/main/java/com/google/adk/tools/SetModelResponseTool.java +++ b/core/src/main/java/com/google/adk/tools/SetModelResponseTool.java @@ -16,6 +16,7 @@ package com.google.adk.tools; +import com.google.adk.SchemaUtils; import com.google.genai.types.FunctionDeclaration; import com.google.genai.types.Schema; import io.reactivex.rxjava3.core.Single; @@ -58,6 +59,10 @@ public Optional declaration() { public Single> runAsync(Map args, ToolContext toolContext) { // This tool is a marker for the final response, it doesn't do anything but return its arguments // which will be captured as the final result. - return Single.just(args); + return Single.fromCallable( + () -> { + SchemaUtils.validateMapOnSchema(args, outputSchema, /* isInput= */ false); + return args; + }); } } diff --git a/core/src/main/java/com/google/adk/tools/mcp/AbstractMcpTool.java b/core/src/main/java/com/google/adk/tools/mcp/AbstractMcpTool.java index d9c28e501..3b0c3d70a 100644 --- a/core/src/main/java/com/google/adk/tools/mcp/AbstractMcpTool.java +++ b/core/src/main/java/com/google/adk/tools/mcp/AbstractMcpTool.java @@ -16,7 +16,6 @@ package com.google.adk.tools.mcp; -import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; import com.google.adk.tools.BaseTool; @@ -24,13 +23,9 @@ import com.google.common.collect.ImmutableMap; import com.google.genai.types.FunctionDeclaration; import io.modelcontextprotocol.spec.McpSchema.CallToolResult; -import io.modelcontextprotocol.spec.McpSchema.Content; import io.modelcontextprotocol.spec.McpSchema.JsonSchema; -import io.modelcontextprotocol.spec.McpSchema.TextContent; import io.modelcontextprotocol.spec.McpSchema.Tool; import io.modelcontextprotocol.spec.McpSchema.ToolAnnotations; -import java.util.ArrayList; -import java.util.List; import java.util.Map; import java.util.Optional; @@ -116,51 +111,6 @@ protected static Map wrapCallResult( return ImmutableMap.of("error", "MCP framework error: CallToolResult was null"); } - List contents = callResult.content(); - Boolean isToolError = callResult.isError(); - - if (isToolError != null && isToolError) { - String errorMessage = "Tool execution failed."; - if (contents != null - && !contents.isEmpty() - && contents.get(0) instanceof TextContent textContent) { - if (textContent.text() != null && !textContent.text().isEmpty()) { - errorMessage += " Details: " + textContent.text(); - } - } - return ImmutableMap.of("error", errorMessage); - } - - if (contents == null || contents.isEmpty()) { - return ImmutableMap.of(); - } - - List textOutputs = new ArrayList<>(); - for (Content content : contents) { - if (content instanceof TextContent textContent) { - if (textContent.text() != null) { - textOutputs.add(textContent.text()); - } - } - } - - if (textOutputs.isEmpty()) { - return ImmutableMap.of( - "error", - "Tool '" + mcpToolName + "' returned content that is not TextContent.", - "content_details", - contents.toString()); - } - - List> resultMaps = new ArrayList<>(); - for (String textOutput : textOutputs) { - try { - resultMaps.add( - objectMapper.readValue(textOutput, new TypeReference>() {})); - } catch (JsonProcessingException e) { - resultMaps.add(ImmutableMap.of("text", textOutput)); - } - } - return ImmutableMap.of("text_output", resultMaps); + return objectMapper.convertValue(callResult, new TypeReference>() {}); } } diff --git a/core/src/test/java/com/google/adk/agents/LlmAgentTest.java b/core/src/test/java/com/google/adk/agents/LlmAgentTest.java index a9e7a6f8d..e40a83aa0 100644 --- a/core/src/test/java/com/google/adk/agents/LlmAgentTest.java +++ b/core/src/test/java/com/google/adk/agents/LlmAgentTest.java @@ -494,12 +494,10 @@ public void runAsync_withTools_createsToolSpans() throws InterruptedException { List spans = openTelemetryRule.getSpans(); SpanData agentSpan = findSpanByName(spans, "invoke_agent test agent"); List llmSpans = findSpansByName(spans, "call_llm"); - List toolCallSpans = findSpansByName(spans, "tool_call [echo_tool]"); - List toolResponseSpans = findSpansByName(spans, "tool_response [echo_tool]"); + List toolSpans = findSpansByName(spans, "execute_tool [echo_tool]"); assertThat(llmSpans).hasSize(2); - assertThat(toolCallSpans).hasSize(1); - assertThat(toolResponseSpans).hasSize(1); + assertThat(toolSpans).hasSize(1); String agentSpanId = agentSpan.getSpanContext().getSpanId(); llmSpans.forEach(s -> assertEquals(agentSpanId, s.getParentSpanContext().getSpanId())); @@ -507,9 +505,7 @@ public void runAsync_withTools_createsToolSpans() throws InterruptedException { // The tool calls and responses are children of the first LLM call that produced the function // call. String firstLlmSpanId = llmSpans.get(0).getSpanContext().getSpanId(); - toolCallSpans.forEach(s -> assertEquals(firstLlmSpanId, s.getParentSpanContext().getSpanId())); - toolResponseSpans.forEach( - s -> assertEquals(firstLlmSpanId, s.getParentSpanContext().getSpanId())); + toolSpans.forEach(s -> assertEquals(firstLlmSpanId, s.getParentSpanContext().getSpanId())); } @Test diff --git a/core/src/test/java/com/google/adk/agents/ParallelAgentTest.java b/core/src/test/java/com/google/adk/agents/ParallelAgentTest.java index a6afb5793..e51240c45 100644 --- a/core/src/test/java/com/google/adk/agents/ParallelAgentTest.java +++ b/core/src/test/java/com/google/adk/agents/ParallelAgentTest.java @@ -25,7 +25,10 @@ import com.google.genai.types.Content; import com.google.genai.types.Part; import io.reactivex.rxjava3.core.Flowable; +import io.reactivex.rxjava3.core.Scheduler; import io.reactivex.rxjava3.schedulers.Schedulers; +import io.reactivex.rxjava3.schedulers.TestScheduler; +import io.reactivex.rxjava3.subscribers.TestSubscriber; import java.util.List; import org.junit.Test; import org.junit.runner.RunWith; @@ -36,10 +39,16 @@ public final class ParallelAgentTest { static class TestingAgent extends BaseAgent { private final long delayMillis; + private final Scheduler scheduler; private TestingAgent(String name, String description, long delayMillis) { + this(name, description, delayMillis, Schedulers.computation()); + } + + private TestingAgent(String name, String description, long delayMillis, Scheduler scheduler) { super(name, description, ImmutableList.of(), null, null); this.delayMillis = delayMillis; + this.scheduler = scheduler; } @Override @@ -55,7 +64,7 @@ protected Flowable runAsyncImpl(InvocationContext invocationContext) { .build()); if (delayMillis > 0) { - return event.delay(delayMillis, MILLISECONDS, Schedulers.computation()); + return event.delay(delayMillis, MILLISECONDS, scheduler); } return event; } @@ -110,4 +119,79 @@ public void runAsync_noSubAgents_returnsEmptyFlowable() { assertThat(events).isEmpty(); } + + static class BlockingAgent extends BaseAgent { + private final long sleepMillis; + + private BlockingAgent(String name, long sleepMillis) { + super(name, "Blocking Agent", ImmutableList.of(), null, null); + this.sleepMillis = sleepMillis; + } + + @Override + protected Flowable runAsyncImpl(InvocationContext invocationContext) { + return Flowable.fromCallable( + () -> { + Thread.sleep(sleepMillis); + return Event.builder() + .author(name()) + .branch(invocationContext.branch().orElse(null)) + .invocationId(invocationContext.invocationId()) + .content(Content.fromParts(Part.fromText("Done"))) + .build(); + }); + } + + @Override + protected Flowable runLiveImpl(InvocationContext invocationContext) { + throw new UnsupportedOperationException("Not implemented"); + } + } + + @Test + public void runAsync_blockingSubAgents_shouldExecuteInParallel() { + long sleepTime = 1000; + BlockingAgent agent1 = new BlockingAgent("agent1", sleepTime); + BlockingAgent agent2 = new BlockingAgent("agent2", sleepTime); + + ParallelAgent parallelAgent = + ParallelAgent.builder().name("parallel_agent").subAgents(agent1, agent2).build(); + + InvocationContext invocationContext = createInvocationContext(parallelAgent); + + long startTime = System.currentTimeMillis(); + List events = parallelAgent.runAsync(invocationContext).toList().blockingGet(); + long duration = System.currentTimeMillis() - startTime; + + assertThat(events).hasSize(2); + // If parallel, duration should be less than 1.5 * sleepTime (1500ms). + assertThat(duration).isAtLeast(sleepTime); + assertThat(duration).isLessThan((long) (1.5 * sleepTime)); + } + + @Test + public void runAsync_withTestScheduler_usesVirtualTime() { + TestScheduler testScheduler = new TestScheduler(); + long delayMillis = 1000; + TestingAgent agent = + new TestingAgent("delayed_agent", "Delayed Agent", delayMillis, testScheduler); + + ParallelAgent parallelAgent = + ParallelAgent.builder() + .name("parallel_agent") + .subAgents(agent) + .scheduler(testScheduler) + .build(); + + InvocationContext invocationContext = createInvocationContext(parallelAgent); + + TestSubscriber testSubscriber = parallelAgent.runAsync(invocationContext).test(); + + testScheduler.advanceTimeBy(delayMillis - 100, MILLISECONDS); + testSubscriber.assertNoValues(); + testSubscriber.assertNotComplete(); + testScheduler.advanceTimeBy(200, MILLISECONDS); + testSubscriber.assertValueCount(1); + testSubscriber.assertComplete(); + } } diff --git a/core/src/test/java/com/google/adk/models/ClaudeTest.java b/core/src/test/java/com/google/adk/models/ClaudeTest.java new file mode 100644 index 000000000..677d40627 --- /dev/null +++ b/core/src/test/java/com/google/adk/models/ClaudeTest.java @@ -0,0 +1,97 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.models; + +import static com.google.common.truth.Truth.assertThat; + +import com.anthropic.client.AnthropicClient; +import com.anthropic.models.messages.ContentBlockParam; +import com.anthropic.models.messages.ToolResultBlockParam; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.genai.types.FunctionResponse; +import com.google.genai.types.Part; +import java.lang.reflect.Method; +import java.util.Map; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.Mockito; + +@RunWith(JUnit4.class) +public final class ClaudeTest { + + private Claude claude; + private Method partToAnthropicMessageBlockMethod; + + @Before + public void setUp() throws Exception { + AnthropicClient mockClient = Mockito.mock(AnthropicClient.class); + claude = new Claude("claude-3-opus", mockClient); + + // Access private method for testing the extraction logic + partToAnthropicMessageBlockMethod = + Claude.class.getDeclaredMethod("partToAnthropicMessageBlock", Part.class); + partToAnthropicMessageBlockMethod.setAccessible(true); + } + + @Test + public void testPartToAnthropicMessageBlock_mcpNativeFormat() throws Exception { + Map responseData = + ImmutableMap.of( + "content", + ImmutableList.of(ImmutableMap.of("type", "text", "text", "Extracted native MCP text"))); + FunctionResponse funcParam = + FunctionResponse.builder().name("test_tool").response(responseData).id("call_123").build(); + Part part = Part.builder().functionResponse(funcParam).build(); + + ContentBlockParam result = + (ContentBlockParam) partToAnthropicMessageBlockMethod.invoke(claude, part); + + ToolResultBlockParam toolResult = result.asToolResult(); + assertThat(toolResult.content().get().asString()).isEqualTo("Extracted native MCP text"); + } + + @Test + public void testPartToAnthropicMessageBlock_legacyResultKey() throws Exception { + Map responseData = ImmutableMap.of("result", "Legacy result text"); + FunctionResponse funcParam = + FunctionResponse.builder().name("test_tool").response(responseData).id("call_123").build(); + Part part = Part.builder().functionResponse(funcParam).build(); + + ContentBlockParam result = + (ContentBlockParam) partToAnthropicMessageBlockMethod.invoke(claude, part); + + ToolResultBlockParam toolResult = result.asToolResult(); + assertThat(toolResult.content().get().asString()).isEqualTo("Legacy result text"); + } + + @Test + public void testPartToAnthropicMessageBlock_jsonFallback() throws Exception { + Map responseData = ImmutableMap.of("custom_key", "custom_value"); + FunctionResponse funcParam = + FunctionResponse.builder().name("test_tool").response(responseData).id("call_123").build(); + Part part = Part.builder().functionResponse(funcParam).build(); + + ContentBlockParam result = + (ContentBlockParam) partToAnthropicMessageBlockMethod.invoke(claude, part); + + ToolResultBlockParam toolResult = result.asToolResult(); + assertThat(toolResult.content().get().asString()).contains("\"custom_key\":\"custom_value\""); + } +} diff --git a/core/src/test/java/com/google/adk/runner/RunnerTest.java b/core/src/test/java/com/google/adk/runner/RunnerTest.java index a3e21cb73..b68b6ff5f 100644 --- a/core/src/test/java/com/google/adk/runner/RunnerTest.java +++ b/core/src/test/java/com/google/adk/runner/RunnerTest.java @@ -26,6 +26,7 @@ import static com.google.common.truth.Truth.assertThat; import static java.nio.charset.StandardCharsets.UTF_8; import static java.util.Arrays.stream; +import static java.util.concurrent.TimeUnit.MILLISECONDS; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.CALLS_REAL_METHODS; import static org.mockito.Mockito.mock; @@ -33,6 +34,7 @@ import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import com.google.adk.agents.BaseAgent; import com.google.adk.agents.InvocationContext; import com.google.adk.agents.LiveRequestQueue; import com.google.adk.agents.LlmAgent; @@ -43,6 +45,7 @@ import com.google.adk.flows.llmflows.Functions; import com.google.adk.models.LlmResponse; import com.google.adk.plugins.BasePlugin; +import com.google.adk.sessions.BaseSessionService; import com.google.adk.sessions.Session; import com.google.adk.sessions.SessionKey; import com.google.adk.summarizer.EventsCompactionConfig; @@ -851,6 +854,45 @@ public void beforeRunCallback_withStateDelta_seesMergedState() { assertThat(sessionInCallback.state()).containsEntry("number", 123); } + @Test + public void runAsync_ensureEventsAreAppendedInOrder() throws Exception { + Event event1 = TestUtils.createEvent("1"); + Event event2 = TestUtils.createEvent("2"); + BaseAgent mockAgent = TestUtils.createSubAgent("test agent", event1, event2); + + BaseSessionService mockSessionService = mock(BaseSessionService.class); + + when(mockSessionService.getSession(any(), any(), any(), any())).thenReturn(Maybe.just(session)); + when(mockSessionService.appendEvent(any(), any())) + .thenAnswer( + invocation -> { + Event eventArg = invocation.getArgument(1); + Single result = Single.just(eventArg); + if (eventArg.id().equals("1")) { + // Artificially delay the first event to ensure it is appended first. + return result.delay(100, MILLISECONDS); + } + return result; + }); + + Runner mockRunner = + Runner.builder() + .agent(mockAgent) + .appName("test") + .sessionService(mockSessionService) + .build(); + + List results = + mockRunner + .runAsync("user", session.id(), createContent("user message")) + .toList() + .blockingGet(); + + assertThat(simplifyEvents(results)) + .containsExactly("author: content for event 1", "author: content for event 2") + .inOrder(); + } + private Content createContent(String text) { return Content.builder().parts(Part.builder().text(text).build()).build(); } @@ -1019,21 +1061,16 @@ public void runAsync_createsToolSpansWithCorrectParent() { List spans = openTelemetryRule.getSpans(); List llmSpans = spans.stream().filter(s -> s.getName().equals("call_llm")).toList(); - List toolCallSpans = - spans.stream().filter(s -> s.getName().equals("tool_call [echo_tool]")).toList(); - List toolResponseSpans = - spans.stream().filter(s -> s.getName().equals("tool_response [echo_tool]")).toList(); + List toolSpans = + spans.stream().filter(s -> s.getName().equals("execute_tool [echo_tool]")).toList(); assertThat(llmSpans).hasSize(2); - assertThat(toolCallSpans).hasSize(1); - assertThat(toolResponseSpans).hasSize(1); + assertThat(toolSpans).hasSize(1); List llmSpanIds = llmSpans.stream().map(s -> s.getSpanContext().getSpanId()).toList(); - String toolCallParentId = toolCallSpans.get(0).getParentSpanContext().getSpanId(); - String toolResponseParentId = toolResponseSpans.get(0).getParentSpanContext().getSpanId(); + String toolParentId = toolSpans.get(0).getParentSpanContext().getSpanId(); - assertThat(toolCallParentId).isEqualTo(toolResponseParentId); - assertThat(llmSpanIds).contains(toolCallParentId); + assertThat(llmSpanIds).contains(toolParentId); } @Test @@ -1059,22 +1096,17 @@ public void runLive_createsToolSpansWithCorrectParent() throws Exception { List spans = openTelemetryRule.getSpans(); List llmSpans = spans.stream().filter(s -> s.getName().equals("call_llm")).toList(); - List toolCallSpans = - spans.stream().filter(s -> s.getName().equals("tool_call [echo_tool]")).toList(); - List toolResponseSpans = - spans.stream().filter(s -> s.getName().equals("tool_response [echo_tool]")).toList(); + List toolSpans = + spans.stream().filter(s -> s.getName().equals("execute_tool [echo_tool]")).toList(); // In runLive, there is one call_llm span for the execution assertThat(llmSpans).hasSize(1); - assertThat(toolCallSpans).hasSize(1); - assertThat(toolResponseSpans).hasSize(1); + assertThat(toolSpans).hasSize(1); List llmSpanIds = llmSpans.stream().map(s -> s.getSpanContext().getSpanId()).toList(); - String toolCallParentId = toolCallSpans.get(0).getParentSpanContext().getSpanId(); - String toolResponseParentId = toolResponseSpans.get(0).getParentSpanContext().getSpanId(); + String toolParentId = toolSpans.get(0).getParentSpanContext().getSpanId(); - assertThat(toolCallParentId).isEqualTo(toolResponseParentId); - assertThat(llmSpanIds).contains(toolCallParentId); + assertThat(llmSpanIds).contains(toolParentId); } @Test diff --git a/core/src/test/java/com/google/adk/sessions/MockApiAnswer.java b/core/src/test/java/com/google/adk/sessions/MockApiAnswer.java index 111b1dce3..743bcee8d 100644 --- a/core/src/test/java/com/google/adk/sessions/MockApiAnswer.java +++ b/core/src/test/java/com/google/adk/sessions/MockApiAnswer.java @@ -36,14 +36,25 @@ class MockApiAnswer implements Answer { private final Map sessionMap; private final Map eventMap; + private final String rawApiResponse; MockApiAnswer(Map sessionMap, Map eventMap) { this.sessionMap = sessionMap; this.eventMap = eventMap; + this.rawApiResponse = null; + } + + MockApiAnswer(String rawApiResponse) { + this.sessionMap = null; + this.eventMap = null; + this.rawApiResponse = rawApiResponse; } @Override public ApiResponse answer(InvocationOnMock invocation) throws Throwable { + if (rawApiResponse != null) { + return responseWithBody(rawApiResponse); + } String httpMethod = invocation.getArgument(0); String path = invocation.getArgument(1); if (httpMethod.equals("POST") && SESSIONS_REGEX.matcher(path).matches()) { diff --git a/core/src/test/java/com/google/adk/sessions/VertexAiSessionServiceTest.java b/core/src/test/java/com/google/adk/sessions/VertexAiSessionServiceTest.java index def4faf4c..dd62263d7 100644 --- a/core/src/test/java/com/google/adk/sessions/VertexAiSessionServiceTest.java +++ b/core/src/test/java/com/google/adk/sessions/VertexAiSessionServiceTest.java @@ -37,7 +37,6 @@ public class VertexAiSessionServiceTest { private static final ObjectMapper mapper = JsonBaseModel.getMapper(); - private static final String MOCK_SESSION_STRING_1 = """ { @@ -319,6 +318,24 @@ public void listSessions_empty() { .isEmpty(); } + @Test + public void listSessions_missingSessionsField_returnsEmpty() { + when(mockApiClient.request("GET", "reasoningEngines/123/sessions?filter=user_id=userX", "")) + .thenAnswer(new MockApiAnswer("{}")); + + assertThat(vertexAiSessionService.listSessions("123", "userX").blockingGet().sessions()) + .isEmpty(); + } + + @Test + public void listSessions_nullSessionsField_returnsEmpty() { + when(mockApiClient.request("GET", "reasoningEngines/123/sessions?filter=user_id=userY", "")) + .thenAnswer(new MockApiAnswer("{\"sessions\": null}")); + + assertThat(vertexAiSessionService.listSessions("123", "userY").blockingGet().sessions()) + .isEmpty(); + } + @Test public void listEvents_empty() { assertThat(vertexAiSessionService.listEvents("789", "user1", "3").blockingGet().events()) diff --git a/core/src/test/java/com/google/adk/telemetry/ContextPropagationTest.java b/core/src/test/java/com/google/adk/telemetry/ContextPropagationTest.java index b13904934..44877e972 100644 --- a/core/src/test/java/com/google/adk/telemetry/ContextPropagationTest.java +++ b/core/src/test/java/com/google/adk/telemetry/ContextPropagationTest.java @@ -18,7 +18,6 @@ import static com.google.common.truth.Truth.assertThat; import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; import com.google.adk.agents.BaseAgent; @@ -100,242 +99,6 @@ public void tearDown() { Tracing.setTracerForTesting(originalTracer); } - @Test - public void testToolCallSpanLinksToParent() { - // Given: Parent span is active - Span parentSpan = tracer.spanBuilder("parent").startSpan(); - - try (Scope scope = parentSpan.makeCurrent()) { - // When: ADK creates tool_call span with setParent(Context.current()) - Span toolCallSpan = - tracer.spanBuilder("tool_call [testTool]").setParent(Context.current()).startSpan(); - - try (Scope toolScope = toolCallSpan.makeCurrent()) { - // Simulate tool execution - } finally { - toolCallSpan.end(); - } - } finally { - parentSpan.end(); - } - - // Then: tool_call should be child of parent - SpanData parentSpanData = findSpanByName("parent"); - SpanData toolCallSpanData = findSpanByName("tool_call [testTool]"); - - // Verify parent-child relationship - assertEquals( - "Tool call should have same trace ID as parent", - parentSpanData.getSpanContext().getTraceId(), - toolCallSpanData.getSpanContext().getTraceId()); - - assertParent(parentSpanData, toolCallSpanData); - } - - @Test - public void testToolCallWithoutParentCreatesRootSpan() { - // Given: No parent span active - // When: ADK creates tool_call span with setParent(Context.current()) - try (Scope s = Context.root().makeCurrent()) { - Span toolCallSpan = - tracer.spanBuilder("tool_call [testTool]").setParent(Context.current()).startSpan(); - - try (Scope scope = toolCallSpan.makeCurrent()) { - // Work - } finally { - toolCallSpan.end(); - } - } - - // Then: Should create root span (backward compatible) - List spans = openTelemetryRule.getSpans(); - assertThat(spans).hasSize(1); - - SpanData toolCallSpanData = spans.get(0); - assertFalse( - "Tool call should be root span when no parent exists", - toolCallSpanData.getParentSpanContext().isValid()); - } - - @Test - public void testNestedSpanHierarchy() { - // Test: parent → invocation → tool_call → tool_response hierarchy - - Span parentSpan = tracer.spanBuilder("parent").startSpan(); - - try (Scope parentScope = parentSpan.makeCurrent()) { - - Span invocationSpan = - tracer.spanBuilder("invocation").setParent(Context.current()).startSpan(); - - try (Scope invocationScope = invocationSpan.makeCurrent()) { - - Span toolCallSpan = - tracer.spanBuilder("tool_call [testTool]").setParent(Context.current()).startSpan(); - - try (Scope toolScope = toolCallSpan.makeCurrent()) { - - Span toolResponseSpan = - tracer - .spanBuilder("tool_response [testTool]") - .setParent(Context.current()) - .startSpan(); - - toolResponseSpan.end(); - } finally { - toolCallSpan.end(); - } - } finally { - invocationSpan.end(); - } - } finally { - parentSpan.end(); - } - - // Verify complete hierarchy - List spans = openTelemetryRule.getSpans(); - // The 4 spans are: "parent", "invocation", "tool_call [testTool]", and "tool_response - // [testTool]". - assertThat(spans).hasSize(4); - - SpanData parentSpanData = findSpanByName("parent"); - String parentTraceId = parentSpanData.getSpanContext().getTraceId(); - - // All spans should have same trace ID - for (SpanData span : openTelemetryRule.getSpans()) { - assertEquals( - "All spans should be in same trace", parentTraceId, span.getSpanContext().getTraceId()); - } - - // Verify parent-child relationships - SpanData invocationSpanData = findSpanByName("invocation"); - SpanData toolCallSpanData = findSpanByName("tool_call [testTool]"); - SpanData toolResponseSpanData = findSpanByName("tool_response [testTool]"); - - // invocation should be child of parent - assertParent(parentSpanData, invocationSpanData); - - // tool_call should be child of invocation - assertParent(invocationSpanData, toolCallSpanData); - - // tool_response should be child of tool_call - assertParent(toolCallSpanData, toolResponseSpanData); - } - - @Test - public void testMultipleSpansInParallel() { - // Test: Multiple tool calls in parallel should all link to same parent - - Span parentSpan = tracer.spanBuilder("parent").startSpan(); - - try (Scope parentScope = parentSpan.makeCurrent()) { - // Simulate parallel tool calls - Span toolCall1 = - tracer.spanBuilder("tool_call [tool1]").setParent(Context.current()).startSpan(); - Span toolCall2 = - tracer.spanBuilder("tool_call [tool2]").setParent(Context.current()).startSpan(); - Span toolCall3 = - tracer.spanBuilder("tool_call [tool3]").setParent(Context.current()).startSpan(); - - toolCall1.end(); - toolCall2.end(); - toolCall3.end(); - } finally { - parentSpan.end(); - } - - // Verify all tool calls link to same parent - SpanData parentSpanData = findSpanByName("parent"); - String parentTraceId = parentSpanData.getSpanContext().getTraceId(); - - // All tool calls should have same trace ID and parent span ID - List toolCallSpans = - openTelemetryRule.getSpans().stream() - .filter(s -> s.getName().startsWith("tool_call")) - .toList(); - - assertThat(toolCallSpans).hasSize(3); - - toolCallSpans.forEach( - span -> { - assertEquals( - "Tool call should have same trace ID as parent", - parentTraceId, - span.getSpanContext().getTraceId()); - assertParent(parentSpanData, span); - }); - } - - @Test - public void testInvokeAgentSpanLinksToInvocation() { - // Test: invoke_agent span should link to invocation span - - Span invocationSpan = tracer.spanBuilder("invocation").startSpan(); - - try (Scope invocationScope = invocationSpan.makeCurrent()) { - Span invokeAgentSpan = - tracer.spanBuilder("invoke_agent test-agent").setParent(Context.current()).startSpan(); - - try (Scope agentScope = invokeAgentSpan.makeCurrent()) { - // Simulate agent work - } finally { - invokeAgentSpan.end(); - } - } finally { - invocationSpan.end(); - } - - SpanData invocationSpanData = findSpanByName("invocation"); - SpanData invokeAgentSpanData = findSpanByName("invoke_agent test-agent"); - - assertParent(invocationSpanData, invokeAgentSpanData); - } - - @Test - public void testCallLlmSpanLinksToAgentRun() { - // Test: call_llm span should link to agent_run span - - Span invokeAgentSpan = tracer.spanBuilder("invoke_agent test-agent").startSpan(); - - try (Scope agentScope = invokeAgentSpan.makeCurrent()) { - Span callLlmSpan = tracer.spanBuilder("call_llm").setParent(Context.current()).startSpan(); - - try (Scope llmScope = callLlmSpan.makeCurrent()) { - // Simulate LLM call - } finally { - callLlmSpan.end(); - } - } finally { - invokeAgentSpan.end(); - } - - List spans = openTelemetryRule.getSpans(); - assertThat(spans).hasSize(2); - - SpanData invokeAgentSpanData = findSpanByName("invoke_agent test-agent"); - SpanData callLlmSpanData = findSpanByName("call_llm"); - - assertParent(invokeAgentSpanData, callLlmSpanData); - } - - @Test - public void testSpanCreatedWithinParentScopeIsCorrectlyParented() { - // Test: Simulates creating a span within the scope of a parent - - Span parentSpan = tracer.spanBuilder("invocation").startSpan(); - try (Scope scope = parentSpan.makeCurrent()) { - Span agentSpan = tracer.spanBuilder("invoke_agent").setParent(Context.current()).startSpan(); - agentSpan.end(); - } finally { - parentSpan.end(); - } - - SpanData parentSpanData = findSpanByName("invocation"); - SpanData agentSpanData = findSpanByName("invoke_agent"); - - assertParent(parentSpanData, agentSpanData); - } - @Test public void testTraceFlowable() throws InterruptedException { Span parentSpan = tracer.spanBuilder("parent").startSpan(); @@ -475,8 +238,14 @@ public void testTraceAgentInvocation() { public void testTraceToolCall() { Span span = tracer.spanBuilder("test").startSpan(); try (Scope scope = span.makeCurrent()) { - Tracing.traceToolCall( - "tool-name", "tool-description", "tool-type", ImmutableMap.of("arg1", "value1")); + Tracing.traceToolExecution( + span, + "tool-name", + "tool-description", + "tool-type", + ImmutableMap.of("arg1", "value1"), + null, + null); } finally { span.end(); } @@ -513,7 +282,14 @@ public void testTraceToolResponse() { .build()) .build())) .build(); - Tracing.traceToolResponse("event-1", functionResponseEvent); + Tracing.traceToolExecution( + span, + "tool-name", + "tool-description", + "tool-type", + ImmutableMap.of(), + functionResponseEvent, + null); } finally { span.end(); } @@ -524,6 +300,10 @@ public void testTraceToolResponse() { assertEquals("execute_tool", attrs.get(AttributeKey.stringKey("gen_ai.operation.name"))); assertEquals("event-1", attrs.get(AttributeKey.stringKey("gcp.vertex.agent.event_id"))); assertEquals("tool-call-id", attrs.get(AttributeKey.stringKey("gen_ai.tool_call.id"))); + assertEquals("tool-name", attrs.get(AttributeKey.stringKey("gen_ai.tool.name"))); + assertEquals("tool-description", attrs.get(AttributeKey.stringKey("gen_ai.tool.description"))); + assertEquals("tool-type", attrs.get(AttributeKey.stringKey("gen_ai.tool.type"))); + assertEquals("{}", attrs.get(AttributeKey.stringKey("gcp.vertex.agent.tool_call_args"))); assertEquals( "{\"result\":\"tool-result\"}", attrs.get(AttributeKey.stringKey("gcp.vertex.agent.tool_response"))); @@ -550,7 +330,8 @@ public void testTraceCallLlm() { .totalTokenCount(30) .build()) .build(); - Tracing.traceCallLlm(span, buildInvocationContext(), "event-1", llmRequest, llmResponse); + Tracing.traceCallLlm( + span, buildInvocationContext(), "event-1", llmRequest, llmResponse, null); } finally { span.end(); } @@ -559,6 +340,7 @@ public void testTraceCallLlm() { SpanData spanData = spans.get(0); Attributes attrs = spanData.getAttributes(); assertEquals("gcp.vertex.agent", attrs.get(AttributeKey.stringKey("gen_ai.system"))); + assertEquals("call_llm", attrs.get(AttributeKey.stringKey("gen_ai.operation.name"))); assertEquals("gemini-pro", attrs.get(AttributeKey.stringKey("gen_ai.request.model"))); assertEquals( "test-invocation-id", attrs.get(AttributeKey.stringKey("gcp.vertex.agent.invocation_id"))); @@ -581,6 +363,7 @@ public void testTraceSendData() { Span span = tracer.spanBuilder("test").startSpan(); try (Scope scope = span.makeCurrent()) { Tracing.traceSendData( + span, buildInvocationContext(), "event-1", ImmutableList.of(Content.fromParts(Part.fromText("hello")))); @@ -591,6 +374,7 @@ public void testTraceSendData() { assertThat(spans).hasSize(1); SpanData spanData = spans.get(0); Attributes attrs = spanData.getAttributes(); + assertEquals("send_data", attrs.get(AttributeKey.stringKey("gen_ai.operation.name"))); assertEquals( "test-invocation-id", attrs.get(AttributeKey.stringKey("gcp.vertex.agent.invocation_id"))); assertEquals("event-1", attrs.get(AttributeKey.stringKey("gcp.vertex.agent.event_id"))); @@ -687,8 +471,7 @@ public void testAgentWithToolCallTraceHierarchy() throws InterruptedException { // invocation // └── invoke_agent test_agent // ├── call_llm - // │ ├── tool_call [search_flights] - // │ └── tool_response [search_flights] + // │ └── execute_tool [search_flights] // └── call_llm SearchFlightsTool searchFlightsTool = new SearchFlightsTool(); @@ -716,8 +499,7 @@ public void testAgentWithToolCallTraceHierarchy() throws InterruptedException { SpanData invocation = findSpanByName("invocation"); SpanData invokeAgent = findSpanByName("invoke_agent test_agent"); - SpanData toolCall = findSpanByName("tool_call [search_flights]"); - SpanData toolResponse = findSpanByName("tool_response [search_flights]"); + SpanData toolResponse = findSpanByName("execute_tool [search_flights]"); List callLlmSpans = openTelemetryRule.getSpans().stream() .filter(s -> s.getName().equals("call_llm")) @@ -733,12 +515,28 @@ public void testAgentWithToolCallTraceHierarchy() throws InterruptedException { assertParent(invocation, invokeAgent); // ├── call_llm 1 assertParent(invokeAgent, callLlm1); - // │ ├── tool_call [search_flights] - assertParent(callLlm1, toolCall); - // │ └── tool_response [search_flights] + // │ └── execute_tool [search_flights] assertParent(callLlm1, toolResponse); // └── call_llm 2 assertParent(invokeAgent, callLlm2); + + // Assert attributes + assertEquals( + "invoke_agent", + invokeAgent.getAttributes().get(AttributeKey.stringKey("gen_ai.operation.name"))); + assertEquals( + "call_llm", callLlm1.getAttributes().get(AttributeKey.stringKey("gen_ai.operation.name"))); + assertEquals( + "execute_tool", + toolResponse.getAttributes().get(AttributeKey.stringKey("gen_ai.operation.name"))); + assertEquals( + "search_flights", + toolResponse.getAttributes().get(AttributeKey.stringKey("gen_ai.tool.name"))); + assertEquals( + "execute_tool", + toolResponse.getAttributes().get(AttributeKey.stringKey("gen_ai.operation.name"))); + assertEquals( + "call_llm", callLlm2.getAttributes().get(AttributeKey.stringKey("gen_ai.operation.name"))); } @Test @@ -748,8 +546,7 @@ public void testNestedAgentTraceHierarchy() throws InterruptedException { // invocation // └── invoke_agent AgentA // ├── call_llm - // │ ├── tool_call [transfer_to_agent] - // │ └── tool_response [transfer_to_agent] + // │ └── execute_tool [transfer_to_agent] // └── invoke_agent AgentB // └── call_llm TestLlm llm = @@ -776,9 +573,8 @@ public void testNestedAgentTraceHierarchy() throws InterruptedException { SpanData invocation = findSpanByName("invocation"); SpanData agentASpan = findSpanByName("invoke_agent AgentA"); - SpanData toolCall = findSpanByName("tool_call [transfer_to_agent]"); + SpanData executeTool = findSpanByName("execute_tool [transfer_to_agent]"); SpanData agentBSpan = findSpanByName("invoke_agent AgentB"); - SpanData toolResponse = findSpanByName("tool_response [transfer_to_agent]"); List callLlmSpans = openTelemetryRule.getSpans().stream() @@ -792,8 +588,7 @@ public void testNestedAgentTraceHierarchy() throws InterruptedException { assertParent(invocation, agentASpan); assertParent(agentASpan, agentACallLlm1); - assertParent(agentACallLlm1, toolCall); - assertParent(agentACallLlm1, toolResponse); + assertParent(agentACallLlm1, executeTool); assertParent(agentASpan, agentBSpan); assertParent(agentBSpan, agentBCallLlm); } diff --git a/core/src/test/java/com/google/adk/tools/SetModelResponseToolTest.java b/core/src/test/java/com/google/adk/tools/SetModelResponseToolTest.java new file mode 100644 index 000000000..64b600af9 --- /dev/null +++ b/core/src/test/java/com/google/adk/tools/SetModelResponseToolTest.java @@ -0,0 +1,123 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.tools; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.genai.types.FunctionDeclaration; +import com.google.genai.types.Schema; +import java.util.Map; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public final class SetModelResponseToolTest { + + @Test + public void declaration_returnsCorrectFunctionDeclaration() { + Schema outputSchema = + Schema.builder() + .type("OBJECT") + .properties(ImmutableMap.of("field1", Schema.builder().type("STRING").build())) + .required(ImmutableList.of("field1")) + .build(); + + SetModelResponseTool tool = new SetModelResponseTool(outputSchema); + FunctionDeclaration declaration = tool.declaration().get(); + + assertThat(declaration.name()).hasValue("set_model_response"); + assertThat(declaration.description()).isPresent(); + assertThat(declaration.description().get()).contains("Set your final response"); + assertThat(declaration.parameters()).hasValue(outputSchema); + } + + @Test + public void runAsync_returnsArgs() { + Schema outputSchema = + Schema.builder() + .type("OBJECT") + .properties(ImmutableMap.of("field1", Schema.builder().type("STRING").build())) + .build(); + + SetModelResponseTool tool = new SetModelResponseTool(outputSchema); + Map args = ImmutableMap.of("field1", "value1"); + + Map result = tool.runAsync(args, null).blockingGet(); + + assertThat(result).isEqualTo(args); + } + + @Test + public void runAsync_validatesArgs() { + Schema outputSchema = + Schema.builder() + .type("OBJECT") + .properties(ImmutableMap.of("field1", Schema.builder().type("STRING").build())) + .required(ImmutableList.of("field1")) + .build(); + + SetModelResponseTool tool = new SetModelResponseTool(outputSchema); + Map invalidArgs = ImmutableMap.of("field2", "value2"); + + // Should throw validation error + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, () -> tool.runAsync(invalidArgs, null).blockingGet()); + + assertThat(exception).hasMessageThat().contains("does not match agent output schema"); + } + + @Test + public void runAsync_validatesComplexArgs() { + Schema complexSchema = + Schema.builder() + .type("OBJECT") + .properties( + ImmutableMap.of( + "id", + Schema.builder().type("INTEGER").build(), + "tags", + Schema.builder() + .type("ARRAY") + .items(Schema.builder().type("STRING").build()) + .build(), + "metadata", + Schema.builder() + .type("OBJECT") + .properties(ImmutableMap.of("key", Schema.builder().type("STRING").build())) + .build())) + .required(ImmutableList.of("id", "tags", "metadata")) + .build(); + + SetModelResponseTool tool = new SetModelResponseTool(complexSchema); + Map complexArgs = + ImmutableMap.of( + "id", 123, + "tags", ImmutableList.of("tag1", "tag2"), + "metadata", ImmutableMap.of("key", "value")); + + Map result = tool.runAsync(complexArgs, null).blockingGet(); + + assertThat(result).containsEntry("id", 123); + assertThat(result).containsEntry("tags", ImmutableList.of("tag1", "tag2")); + assertThat(result).containsEntry("metadata", ImmutableMap.of("key", "value")); + } +} diff --git a/core/src/test/java/com/google/adk/tools/mcp/AbstractMcpToolTest.java b/core/src/test/java/com/google/adk/tools/mcp/AbstractMcpToolTest.java new file mode 100644 index 000000000..e8d9ea631 --- /dev/null +++ b/core/src/test/java/com/google/adk/tools/mcp/AbstractMcpToolTest.java @@ -0,0 +1,62 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.tools.mcp; + +import static com.google.common.truth.Truth.assertThat; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.common.collect.ImmutableList; +import io.modelcontextprotocol.spec.McpSchema.CallToolResult; +import io.modelcontextprotocol.spec.McpSchema.TextContent; +import java.util.List; +import java.util.Map; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public final class AbstractMcpToolTest { + + private ObjectMapper objectMapper; + + @Before + public void setUp() { + objectMapper = new ObjectMapper(); + } + + @Test + public void testWrapCallResult_success() { + CallToolResult result = + CallToolResult.builder() + .content(ImmutableList.of(new TextContent("success"))) + .isError(false) + .build(); + + Map map = AbstractMcpTool.wrapCallResult(objectMapper, "my_tool", result); + + assertThat(map).containsKey("content"); + List content = (List) map.get("content"); + assertThat(content).hasSize(1); + + Map contentItem = (Map) content.get(0); + assertThat(contentItem).containsEntry("type", "text"); + assertThat(contentItem).containsEntry("text", "success"); + + assertThat(map).containsEntry("isError", false); + } +} diff --git a/pom.xml b/pom.xml index 2b2453228..cdef19f4d 100644 --- a/pom.xml +++ b/pom.xml @@ -54,7 +54,7 @@ 1.51.0 0.17.2 2.47.0 - 1.43.0 + 1.44.0 4.33.5 5.11.4 5.20.0 @@ -65,7 +65,7 @@ 0.18.1 3.41.0 3.9.0 - 1.11.0 + 1.12.2 2.0.17 1.4.5 1.0.0 diff --git a/tutorials/city-time-weather/pom.xml b/tutorials/city-time-weather/pom.xml index aeb110cf6..19ef08a2d 100644 --- a/tutorials/city-time-weather/pom.xml +++ b/tutorials/city-time-weather/pom.xml @@ -36,16 +36,6 @@ com.google.adk google-adk-dev ${project.version} - - - ch.qos.logback - logback-classic - - - - - org.slf4j - slf4j-simple