diff --git a/.release-please-manifest.json b/.release-please-manifest.json index a1961ec9a..09a252282 100644 --- a/.release-please-manifest.json +++ b/.release-please-manifest.json @@ -1,3 +1,3 @@ { - ".": "1.1.0" + ".": "1.2.0" } diff --git a/CHANGELOG.md b/CHANGELOG.md index 43ffbbfae..c9ae0d827 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,27 @@ # Changelog +## [1.2.0](https://github.com/google/adk-java/compare/v1.1.0...v1.2.0) (2026-04-24) + + +### Features + +* Add telemetry headers ([4009905](https://github.com/google/adk-java/commit/40099057e2b59f34e868da4c34dcd9c1194b2fde)) +* Adding functionality to support customer content formating ([52323b4](https://github.com/google/adk-java/commit/52323b44c89f233e2dd794aee33df8ba5318790e)) +* Allowing McpAsycToolset Builder to take in a McpSessionManager ([78766c1](https://github.com/google/adk-java/commit/78766c179192ff8e560502e0365b45f87ecac433)) +* Forward state delta from all events to parent session instead of just the last event ([f4cd1b7](https://github.com/google/adk-java/commit/f4cd1b754b62fcbf82da22aabc695911d416e51a)) +* Implement BigQuery auto-schema upgrade and view creation ([14027d1](https://github.com/google/adk-java/commit/14027d1545237675a507706d792825356575f73c)) +* Make BigQueryAgentAnalyticsPlugin state per-invocation ([629c390](https://github.com/google/adk-java/commit/629c390de9ca0ec49cba18a0689d299f9261c1fa)) +* Support ChatCompletionChunk to LlmResponse conversion ([589328e](https://github.com/google/adk-java/commit/589328ea747ad4a994223af5789320e171ea2aa7)) +* Support plugins in Java AgentTool similar to Python's implementation ([02a08a1](https://github.com/google/adk-java/commit/02a08a10f087975491d55a29329d6011362925ce)) + + +### Bug Fixes + +* Allow BuiltInCodeExecutor for Gemini 3 models ([1a3dd61](https://github.com/google/adk-java/commit/1a3dd612217a05e2f8fff69720087ed1136a09ab)) +* Fix ADK Runner race condition for sequential tool execution ([69680bb](https://github.com/google/adk-java/commit/69680bbeae11578199eca4efcaf5ecddea2dd552)) +* Fix ADK Runner race condition for sequential tool execution ([9031cad](https://github.com/google/adk-java/commit/9031cadc0e53cad8e4fe141e1d9d2bb19a431a12)) +* Removing deprecated Optional methods ([8ef99f9](https://github.com/google/adk-java/commit/8ef99f999c11c1dbf3331563a0566e14188a68f2)) + ## [1.1.0](https://github.com/google/adk-java/compare/v1.0.0...v1.1.0) (2026-04-10) diff --git a/README.md b/README.md index a2337bf55..107a6967b 100644 --- a/README.md +++ b/README.md @@ -50,13 +50,13 @@ If you're using Maven, add the following to your dependencies: com.google.adk google-adk - 1.1.0 + 1.2.0 com.google.adk google-adk-dev - 1.1.0 + 1.2.0 ``` diff --git a/a2a/pom.xml b/a2a/pom.xml index a756ac22f..3e0b049d6 100644 --- a/a2a/pom.xml +++ b/a2a/pom.xml @@ -5,7 +5,7 @@ com.google.adk google-adk-parent - 1.1.0 + 1.2.1-SNAPSHOT google-adk-a2a diff --git a/contrib/firestore-session-service/pom.xml b/contrib/firestore-session-service/pom.xml index 264c20eee..121877444 100644 --- a/contrib/firestore-session-service/pom.xml +++ b/contrib/firestore-session-service/pom.xml @@ -20,7 +20,7 @@ com.google.adk google-adk-parent - 1.1.0 + 1.2.1-SNAPSHOT ../../pom.xml diff --git a/contrib/langchain4j/pom.xml b/contrib/langchain4j/pom.xml index a970f7996..b7e4cb56f 100644 --- a/contrib/langchain4j/pom.xml +++ b/contrib/langchain4j/pom.xml @@ -20,7 +20,7 @@ com.google.adk google-adk-parent - 1.1.0 + 1.2.1-SNAPSHOT ../../pom.xml diff --git a/contrib/planners/pom.xml b/contrib/planners/pom.xml index 5666d9cbb..1f9afa17a 100644 --- a/contrib/planners/pom.xml +++ b/contrib/planners/pom.xml @@ -20,7 +20,7 @@ com.google.adk google-adk-parent - 1.1.0 + 1.2.1-SNAPSHOT ../../pom.xml diff --git a/contrib/samples/a2a_basic/pom.xml b/contrib/samples/a2a_basic/pom.xml index e497a2e0d..1e7af90ae 100644 --- a/contrib/samples/a2a_basic/pom.xml +++ b/contrib/samples/a2a_basic/pom.xml @@ -5,7 +5,7 @@ com.google.adk google-adk-samples - 1.1.0 + 1.2.1-SNAPSHOT .. diff --git a/contrib/samples/a2a_server/pom.xml b/contrib/samples/a2a_server/pom.xml index b1ef659b7..b1414eff4 100644 --- a/contrib/samples/a2a_server/pom.xml +++ b/contrib/samples/a2a_server/pom.xml @@ -5,7 +5,7 @@ com.google.adk google-adk-samples - 1.1.0 + 1.2.1-SNAPSHOT .. diff --git a/contrib/samples/configagent/pom.xml b/contrib/samples/configagent/pom.xml index f4a536eca..097323363 100644 --- a/contrib/samples/configagent/pom.xml +++ b/contrib/samples/configagent/pom.xml @@ -5,7 +5,7 @@ com.google.adk google-adk-samples - 1.1.0 + 1.2.1-SNAPSHOT .. diff --git a/contrib/samples/helloworld/pom.xml b/contrib/samples/helloworld/pom.xml index 61dc3b5b7..4e6ad4892 100644 --- a/contrib/samples/helloworld/pom.xml +++ b/contrib/samples/helloworld/pom.xml @@ -20,7 +20,7 @@ com.google.adk google-adk-samples - 1.1.0 + 1.2.1-SNAPSHOT .. diff --git a/contrib/samples/mcpfilesystem/pom.xml b/contrib/samples/mcpfilesystem/pom.xml index a1d7cac53..f4ad43c84 100644 --- a/contrib/samples/mcpfilesystem/pom.xml +++ b/contrib/samples/mcpfilesystem/pom.xml @@ -20,7 +20,7 @@ com.google.adk google-adk-parent - 1.1.0 + 1.2.1-SNAPSHOT ../../.. diff --git a/contrib/samples/pom.xml b/contrib/samples/pom.xml index affd6d9c5..d9ce06aa7 100644 --- a/contrib/samples/pom.xml +++ b/contrib/samples/pom.xml @@ -5,7 +5,7 @@ com.google.adk google-adk-parent - 1.1.0 + 1.2.1-SNAPSHOT ../.. diff --git a/contrib/spring-ai/pom.xml b/contrib/spring-ai/pom.xml index 057e1c9ad..7e8c61a8a 100644 --- a/contrib/spring-ai/pom.xml +++ b/contrib/spring-ai/pom.xml @@ -20,7 +20,7 @@ com.google.adk google-adk-parent - 1.1.0 + 1.2.1-SNAPSHOT ../../pom.xml diff --git a/core/pom.xml b/core/pom.xml index 9430eef4e..53fd51883 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -20,7 +20,7 @@ com.google.adk google-adk-parent - 1.1.0 + 1.2.1-SNAPSHOT google-adk diff --git a/core/src/main/java/com/google/adk/Version.java b/core/src/main/java/com/google/adk/Version.java index 4c5a6b1a9..2816d6763 100644 --- a/core/src/main/java/com/google/adk/Version.java +++ b/core/src/main/java/com/google/adk/Version.java @@ -22,7 +22,7 @@ */ public final class Version { // Don't touch this, release-please should keep it up to date. - public static final String JAVA_ADK_VERSION = "1.1.0"; // x-release-please-released-version + public static final String JAVA_ADK_VERSION = "1.2.0"; // x-release-please-released-version private Version() {} } diff --git a/core/src/main/java/com/google/adk/codeexecutors/BuiltInCodeExecutor.java b/core/src/main/java/com/google/adk/codeexecutors/BuiltInCodeExecutor.java index 972082dde..ef9078e4d 100644 --- a/core/src/main/java/com/google/adk/codeexecutors/BuiltInCodeExecutor.java +++ b/core/src/main/java/com/google/adk/codeexecutors/BuiltInCodeExecutor.java @@ -43,7 +43,7 @@ public CodeExecutionResult executeCode( /** Pre-process the LLM request for Gemini 2.0+ models to use the code execution tool. */ public void processLlmRequest(LlmRequest.Builder llmRequestBuilder) { LlmRequest llmRequest = llmRequestBuilder.build(); - if (ModelNameUtils.isGemini2Model(llmRequest.model().orElse(null))) { + if (llmRequest.model().map(ModelNameUtils::isGemini2OrAbove).orElse(false)) { GenerateContentConfig.Builder configBuilder = llmRequest.config().map(c -> c.toBuilder()).orElseGet(GenerateContentConfig::builder); ImmutableList.Builder toolsBuilder = ImmutableList.builder(); diff --git a/core/src/main/java/com/google/adk/flows/llmflows/Functions.java b/core/src/main/java/com/google/adk/flows/llmflows/Functions.java index 0b0e5b4d5..bc810f28f 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/Functions.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/Functions.java @@ -33,7 +33,6 @@ import com.google.adk.tools.BaseTool; import com.google.adk.tools.FunctionTool; import com.google.adk.tools.ToolContext; -import com.google.common.base.VerifyException; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.genai.types.Content; @@ -141,9 +140,12 @@ public static Maybe handleFunctionCalls( Map toolConfirmations) { ImmutableList functionCalls = functionCallEvent.functionCalls(); + List validFunctionCalls = new ArrayList<>(); for (FunctionCall functionCall : functionCalls) { if (!tools.containsKey(functionCall.name().get())) { - throw new VerifyException("Tool not found: " + functionCall.name().get()); + logger.warn("Tool not found: {}", functionCall.name().get()); + } else { + validFunctionCalls.add(functionCall); } } @@ -154,10 +156,10 @@ public static Maybe handleFunctionCalls( Observable functionResponseEventsObservable; if (invocationContext.runConfig().toolExecutionMode() == ToolExecutionMode.SEQUENTIAL) { functionResponseEventsObservable = - Observable.fromIterable(functionCalls).concatMapMaybe(functionCallMapper); + Observable.fromIterable(validFunctionCalls).concatMapMaybe(functionCallMapper); } else { functionResponseEventsObservable = - Observable.fromIterable(functionCalls) + Observable.fromIterable(validFunctionCalls) .concatMapEager(call -> functionCallMapper.apply(call).toObservable()); } return functionResponseEventsObservable @@ -209,9 +211,12 @@ public static Maybe handleFunctionCallsLive( Map toolConfirmations) { ImmutableList functionCalls = functionCallEvent.functionCalls(); + List validFunctionCalls = new ArrayList<>(); for (FunctionCall functionCall : functionCalls) { if (!tools.containsKey(functionCall.name().get())) { - throw new VerifyException("Tool not found: " + functionCall.name().get()); + logger.warn("Tool not found: {}", functionCall.name().get()); + } else { + validFunctionCalls.add(functionCall); } } @@ -222,10 +227,10 @@ public static Maybe handleFunctionCallsLive( Observable responseEventsObservable; if (invocationContext.runConfig().toolExecutionMode() == ToolExecutionMode.SEQUENTIAL) { responseEventsObservable = - Observable.fromIterable(functionCalls).concatMapMaybe(functionCallMapper); + Observable.fromIterable(validFunctionCalls).concatMapMaybe(functionCallMapper); } else { responseEventsObservable = - Observable.fromIterable(functionCalls) + Observable.fromIterable(validFunctionCalls) .concatMapEager(call -> functionCallMapper.apply(call).toObservable()); } @@ -238,7 +243,7 @@ public static Maybe handleFunctionCallsLive( if (events.isEmpty()) { return Maybe.empty(); } - return Maybe.just(Functions.mergeParallelFunctionResponseEvents(events).orElse(null)); + return Maybe.fromOptional(Functions.mergeParallelFunctionResponseEvents(events)); }); } diff --git a/core/src/main/java/com/google/adk/models/chat/ChatCompletionsCommon.java b/core/src/main/java/com/google/adk/models/chat/ChatCompletionsCommon.java index cd5b4d7bf..e26546313 100644 --- a/core/src/main/java/com/google/adk/models/chat/ChatCompletionsCommon.java +++ b/core/src/main/java/com/google/adk/models/chat/ChatCompletionsCommon.java @@ -136,7 +136,7 @@ public FunctionCall toFunctionCall(@Nullable String toolCallId) { if (name != null) { fcBuilder.name(name); } - if (arguments != null) { + if (arguments != null && !arguments.isEmpty()) { try { Map args = objectMapper.readValue(arguments, new TypeReference>() {}); diff --git a/core/src/main/java/com/google/adk/models/chat/ChatCompletionsResponse.java b/core/src/main/java/com/google/adk/models/chat/ChatCompletionsResponse.java index c52389aa3..9645016a9 100644 --- a/core/src/main/java/com/google/adk/models/chat/ChatCompletionsResponse.java +++ b/core/src/main/java/com/google/adk/models/chat/ChatCompletionsResponse.java @@ -19,16 +19,26 @@ import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; import com.google.adk.models.LlmResponse; +import com.google.common.collect.ImmutableList; import com.google.genai.types.Content; import com.google.genai.types.CustomMetadata; import com.google.genai.types.FinishReason; import com.google.genai.types.FinishReason.Known; +import com.google.genai.types.FunctionCall; import com.google.genai.types.GenerateContentResponseUsageMetadata; import com.google.genai.types.Part; -import java.util.ArrayList; +import java.util.HashMap; import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import java.util.TreeMap; import org.jspecify.annotations.Nullable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; /** * Data Transfer Objects for Chat Completion and Chat Completion Chunk API responses. @@ -36,10 +46,61 @@ *

See https://developers.openai.com/api/reference/resources/chat */ @JsonIgnoreProperties(ignoreUnknown = true) -final class ChatCompletionsResponse { +public final class ChatCompletionsResponse { private ChatCompletionsResponse() {} + static @Nullable FinishReason mapFinishReason(String reason) { + if (reason == null) { + return null; + } + return switch (reason) { + case "stop", "tool_calls" -> new FinishReason(Known.STOP.toString()); + case "length" -> new FinishReason(Known.MAX_TOKENS.toString()); + case "content_filter" -> new FinishReason(Known.SAFETY.toString()); + default -> new FinishReason(Known.OTHER.toString()); + }; + } + + static @Nullable GenerateContentResponseUsageMetadata mapUsage(Usage usage) { + if (usage == null) { + return null; + } + GenerateContentResponseUsageMetadata.Builder builder = + GenerateContentResponseUsageMetadata.builder(); + if (usage.promptTokens != null) { + builder.promptTokenCount(usage.promptTokens); + } + if (usage.completionTokens != null) { + builder.candidatesTokenCount(usage.completionTokens); + } + if (usage.totalTokens != null) { + builder.totalTokenCount(usage.totalTokens); + } + if (usage.thoughtsTokenCount != null) { + builder.thoughtsTokenCount(usage.thoughtsTokenCount); + } else if (usage.completionTokensDetails != null + && usage.completionTokensDetails.reasoningTokens != null) { + builder.thoughtsTokenCount(usage.completionTokensDetails.reasoningTokens); + } + return builder.build(); + } + + /** + * Maps the chat role string to the model role string. + * + * @param role the chat role string, or {@code null}. + * @return the model role string, or the input role if it doesn't match the assistant role. + */ + static @Nullable String mapRole(@Nullable String role) { + if (role == null) { + return null; + } + return role.equals(ChatCompletionsCommon.ROLE_ASSISTANT) + ? ChatCompletionsCommon.ROLE_MODEL + : role; + } + /** * See * https://developers.openai.com/api/reference/resources/chat#(resource)%20chat.completions%20%3E%20(model)%20chat_completion%20%3E%20(schema) @@ -95,49 +156,10 @@ public LlmResponse toLlmResponse() { builder.usageMetadata(mapUsage(usage)); } - List customMetadataList = buildCustomMetadata(); + ImmutableList customMetadataList = buildCustomMetadata(); return builder.customMetadata(customMetadataList).build(); } - /** - * Maps the finish reason string to a {@link FinishReason}. - * - * @param reason the finish reason string. - * @return the {@link FinishReason}, or {@code null} if the input reason is null. - */ - private @Nullable FinishReason mapFinishReason(String reason) { - if (reason == null) { - return null; - } - return switch (reason) { - case "stop", "tool_calls" -> new FinishReason(Known.STOP.toString()); - case "length" -> new FinishReason(Known.MAX_TOKENS.toString()); - case "content_filter" -> new FinishReason(Known.SAFETY.toString()); - default -> new FinishReason(Known.OTHER.toString()); - }; - } - - private GenerateContentResponseUsageMetadata mapUsage(Usage usage) { - GenerateContentResponseUsageMetadata.Builder builder = - GenerateContentResponseUsageMetadata.builder(); - if (usage.promptTokens != null) { - builder.promptTokenCount(usage.promptTokens); - } - if (usage.completionTokens != null) { - builder.candidatesTokenCount(usage.completionTokens); - } - if (usage.totalTokens != null) { - builder.totalTokenCount(usage.totalTokens); - } - if (usage.thoughtsTokenCount != null) { - builder.thoughtsTokenCount(usage.thoughtsTokenCount); - } else if (usage.completionTokensDetails != null - && usage.completionTokensDetails.reasoningTokens != null) { - builder.thoughtsTokenCount(usage.completionTokensDetails.reasoningTokens); - } - return builder.build(); - } - /** * Maps the chosen completion to a {@link Content} object. * @@ -152,14 +174,8 @@ private Content mapChoiceToContent(@Nullable Choice choice) { return contentBuilder.build(); } - private String mapRole(@Nullable String role) { - return (role != null && role.equals(ChatCompletionsCommon.ROLE_ASSISTANT)) - ? ChatCompletionsCommon.ROLE_MODEL - : role; - } - - private List mapMessageToParts(Message message) { - List parts = new ArrayList<>(); + private ImmutableList mapMessageToParts(Message message) { + ImmutableList.Builder parts = ImmutableList.builder(); if (message.content != null) { parts.add(Part.fromText(message.content)); } @@ -169,18 +185,19 @@ private List mapMessageToParts(Message message) { if (message.toolCalls != null) { parts.addAll(mapToolCallsToParts(message.toolCalls)); } - return parts; + return parts.build(); } - private List mapToolCallsToParts(List toolCalls) { - List parts = new ArrayList<>(); + private ImmutableList mapToolCallsToParts( + List toolCalls) { + ImmutableList.Builder parts = ImmutableList.builder(); for (ChatCompletionsCommon.ToolCall toolCall : toolCalls) { Part part = toolCall.toPart(); if (part != null) { parts.add(part); } } - return parts; + return parts.build(); } /** @@ -188,8 +205,8 @@ private List mapToolCallsToParts(List tool * * @return a list of {@link CustomMetadata}, which will be empty if no relevant fields are set. */ - private List buildCustomMetadata() { - List customMetadataList = new ArrayList<>(); + private ImmutableList buildCustomMetadata() { + ImmutableList.Builder customMetadataList = ImmutableList.builder(); if (id != null) { customMetadataList.add( CustomMetadata.builder() @@ -225,7 +242,7 @@ private List buildCustomMetadata() { .stringValue(serviceTier) .build()); } - return customMetadataList; + return customMetadataList.build(); } } @@ -489,4 +506,342 @@ static class Audio { /** See class definition for more details. */ public String transcript; } + + /** Accumulates chunks into a final response. */ + static class ChatCompletionChunkCollection { + private static final ObjectMapper objectMapper = new ObjectMapper(); + private static final Logger logger = + LoggerFactory.getLogger(ChatCompletionChunkCollection.class); + + private final StringBuilder contentParts = new StringBuilder(); + private final Map toolCallParts = new TreeMap<>(); + private final Map toolCallArgsAccumulator = new HashMap<>(); + private String role = ""; + private String model = ""; + private Usage usage; + private final Map customMetadataMap = new HashMap<>(); + + private ImmutableList getCustomMetadataList() { + ImmutableList.Builder list = ImmutableList.builder(); + for (Entry entry : customMetadataMap.entrySet()) { + list.add( + CustomMetadata.builder().key(entry.getKey()).stringValue(entry.getValue()).build()); + } + return list.build(); + } + + /** + * Processes a single chunk of a chat completion response. + * + * @param chunk the chunk to process, or {@code null}. + * @return a list of {@link LlmResponse} objects generated from this chunk. + */ + public ImmutableList processChunk(ChatCompletionChunk chunk) { + if (chunk == null) { + return ImmutableList.of(); + } + + updateState(chunk); + + ImmutableList.Builder responses = ImmutableList.builder(); + if (chunk.choices == null || chunk.choices.isEmpty()) { + addGenericResponseIfSet(responses); + return responses.build(); + } + + // The ADK only supports n=1 choices. If more than 1 choice is returned, all choices + // after the first will be dropped. + if (chunk.choices.size() > 1) { + logger.error( + "Multiple choices found in streaming response but only the first one will be used."); + } + ChunkChoice choice = chunk.choices.get(0); + + ImmutableList chunkParts = mapDeltaToParts(choice); + + responses.add(buildPartialResponse(chunkParts)); + + if (choice.finishReason != null && !choice.finishReason.isEmpty()) { + responses.add(buildFinalResponse(choice)); + } + + return responses.build(); + } + + /** + * Updates the internal state (model, usage, metadata) from the chunk. + * + * @param chunk the chunk to read from. + */ + private void updateState(ChatCompletionChunk chunk) { + if (chunk.model != null) { + this.model = chunk.model; + } + if (chunk.usage != null) { + this.usage = chunk.usage; + } + + if (chunk.id != null) { + customMetadataMap.put(ChatCompletionsCommon.METADATA_KEY_ID, chunk.id); + } + if (chunk.created != null) { + customMetadataMap.put(ChatCompletionsCommon.METADATA_KEY_CREATED, chunk.created.toString()); + } + if (chunk.object != null) { + customMetadataMap.put(ChatCompletionsCommon.METADATA_KEY_OBJECT, chunk.object); + } + if (chunk.systemFingerprint != null) { + customMetadataMap.put( + ChatCompletionsCommon.METADATA_KEY_SYSTEM_FINGERPRINT, chunk.systemFingerprint); + } + if (chunk.serviceTier != null) { + customMetadataMap.put(ChatCompletionsCommon.METADATA_KEY_SERVICE_TIER, chunk.serviceTier); + } + } + + /** + * Adds a generic response to the list if usage or metadata is set but choices are empty. + * + * @param responses the list to add to. + */ + private void addGenericResponseIfSet(ImmutableList.Builder responses) { + if (this.usage != null || !customMetadataMap.isEmpty()) { + responses.add( + LlmResponse.builder() + .partial(true) + .modelVersion(this.model) + .usageMetadata(mapUsage(this.usage)) + .customMetadata(getCustomMetadataList()) + .build()); + } + } + + /** + * Maps the choice's delta to a list of parts and updates state. + * + * @param choice the choice to map. + * @return a list of {@link Part}s for this chunk. + */ + private ImmutableList mapDeltaToParts(ChunkChoice choice) { + ImmutableList.Builder chunkParts = ImmutableList.builder(); + if (choice.delta != null) { + updateRole(choice.delta.role); + appendContent(choice.delta.content, chunkParts); + appendRefusal(choice.delta.refusal, chunkParts); + appendToolCalls(choice.delta.toolCalls, chunkParts); + } + return chunkParts.build(); + } + + /** + * Updates the accumulated role if the delta contains a valid role. + * + * @param deltaRole the role string from the delta, or {@code null}. + */ + private void updateRole(@Nullable String deltaRole) { + if (deltaRole != null && !deltaRole.isEmpty()) { + String mapped = ChatCompletionsResponse.mapRole(deltaRole); + if (mapped != null) { + this.role = mapped; + } + } + } + + /** + * Appends content to the accumulator and adds it to the chunk parts. + * + * @param content the content string, or {@code null}. + * @param chunkParts the list of parts for this chunk. + */ + private void appendContent(@Nullable String content, ImmutableList.Builder chunkParts) { + if (content != null && !content.isEmpty()) { + contentParts.append(content); + chunkParts.add(Part.fromText(content)); + } + } + + /** + * Appends refusal to the accumulator and adds it to the chunk parts. + * + * @param refusal the refusal string, or {@code null}. + * @param chunkParts the list of parts for this chunk. + */ + private void appendRefusal(@Nullable String refusal, ImmutableList.Builder chunkParts) { + if (refusal != null && !refusal.isEmpty()) { + if (contentParts.length() > 0) { + contentParts.append("\n"); + } + contentParts.append(refusal); + chunkParts.add(Part.fromText(refusal)); + } + } + + /** + * Appends tool calls to the accumulator and adds them to the chunk parts. + * + * @param toolCalls the list of tool calls, or {@code null}. + * @param chunkParts the list of parts for this chunk. + */ + private void appendToolCalls( + @Nullable List toolCalls, + ImmutableList.Builder chunkParts) { + if (toolCalls != null) { + for (ChatCompletionsCommon.ToolCall toolCall : toolCalls) { + Part p = upsertToolCall(toolCall); + if (p != null) { + chunkParts.add(p); + } + } + } + } + + /** + * Builds a partial {@link LlmResponse} for the current chunk parts. + * + * @param chunkParts the parts for this chunk. + * @return the partial response. + */ + private LlmResponse buildPartialResponse(List chunkParts) { + return LlmResponse.builder() + .partial(true) + .content(Content.builder().role(this.role).parts(chunkParts).build()) + .modelVersion(this.model) + .usageMetadata(mapUsage(this.usage)) + .customMetadata(getCustomMetadataList()) + .build(); + } + + /** + * Builds the final {@link LlmResponse} with all accumulated content. + * + * @param choice the choice containing the finish reason. + * @return the final response. + */ + private LlmResponse buildFinalResponse(ChunkChoice choice) { + return LlmResponse.builder() + .content(Content.builder().role(this.role).parts(getContentParts()).build()) + .finishReason(ChatCompletionsResponse.mapFinishReason(choice.finishReason)) + .modelVersion(this.model) + .usageMetadata(mapUsage(this.usage)) + .customMetadata(getCustomMetadataList()) + .build(); + } + + /** + * Upserts a tool call from a chunk into the collection and returns the part for this chunk. + * + * @param toolCall the tool call from the chunk. + * @return the {@link Part} to emit for this chunk, or {@code null} if it cannot be converted. + */ + private Part upsertToolCall(ChatCompletionsCommon.ToolCall toolCall) { + int index = toolCall.index != null ? toolCall.index : toolCallParts.size(); + + initializeToolCallState(index); + updateAccumulatedToolCall(index, toolCall); + + return buildChunkToolCallPart(toolCall); + } + + /** + * Initializes the state for a new tool call index if it doesn't exist. + * + * @param index the index of the tool call. + */ + private void initializeToolCallState(int index) { + if (!toolCallParts.containsKey(index)) { + toolCallParts.put( + index, Part.builder().functionCall(FunctionCall.builder().build()).build()); + toolCallArgsAccumulator.put(index, new StringBuilder()); + } + } + + /** + * Updates the accumulated tool call state with data from the chunk. + * + * @param index the index of the tool call. + * @param toolCall the tool call from the chunk. + */ + private void updateAccumulatedToolCall(int index, ChatCompletionsCommon.ToolCall toolCall) { + Part part = toolCallParts.get(index); + FunctionCall.Builder fcBuilder = + part.functionCall().isPresent() + ? part.functionCall().get().toBuilder() + : FunctionCall.builder(); + + if (toolCall.id != null) { + fcBuilder.id(toolCall.id); + } + + appendFunctionDetails(fcBuilder, toolCall.function, index); + + part = toolCall.applyThoughtSignature(part); + Part updatedPart = part.toBuilder().functionCall(fcBuilder.build()).build(); + toolCallParts.put(index, updatedPart); + } + + private void appendFunctionDetails( + FunctionCall.Builder fcBuilder, ChatCompletionsCommon.Function function, int index) { + if (function == null) { + return; + } + if (function.name != null) { + fcBuilder.name(function.name); + } + if (function.arguments != null) { + toolCallArgsAccumulator.get(index).append(function.arguments); + } + } + + /** + * Builds the {@link Part} for the current chunk's tool call. + * + * @param toolCall the tool call from the chunk. + * @return the {@link Part} for this chunk. + */ + private Part buildChunkToolCallPart(ChatCompletionsCommon.ToolCall toolCall) { + Part chunkPart = toolCall.toPart(); + if (chunkPart == null) { + FunctionCall.Builder chunkFcBuilder = FunctionCall.builder(); + if (toolCall.id != null) { + chunkFcBuilder.id(toolCall.id); + } + chunkPart = Part.builder().functionCall(chunkFcBuilder.build()).build(); + chunkPart = toolCall.applyThoughtSignature(chunkPart); + } + return chunkPart; + } + + private ImmutableList getContentParts() { + ImmutableList.Builder parts = ImmutableList.builder(); + if (contentParts.length() > 0) { + parts.add(Part.fromText(contentParts.toString())); + } + + // If a server sends keys 0 and 2 but not 1 then squash the indices and + // return parts at indices 0 and 1. + ImmutableList sortedKeys = ImmutableList.sortedCopyOf(toolCallParts.keySet()); + + for (int index : sortedKeys) { + Part part = toolCallParts.get(index); + if (part != null && part.functionCall().isPresent()) { + FunctionCall fc = part.functionCall().get(); + StringBuilder argsSb = toolCallArgsAccumulator.get(index); + if (argsSb != null && argsSb.length() > 0) { + try { + Map args = + objectMapper.readValue( + argsSb.toString(), new TypeReference>() {}); + fc = fc.toBuilder().args(args).build(); + part = part.toBuilder().functionCall(fc).build(); + } catch (JsonProcessingException e) { + throw new IllegalArgumentException( + "Failed to parse final tool call arguments: " + argsSb, e); + } + } + } + parts.add(part); + } + return parts.build(); + } + } } diff --git a/core/src/main/java/com/google/adk/plugins/agentanalytics/BatchProcessor.java b/core/src/main/java/com/google/adk/plugins/agentanalytics/BatchProcessor.java index 924ad228e..68d3f4c6d 100644 --- a/core/src/main/java/com/google/adk/plugins/agentanalytics/BatchProcessor.java +++ b/core/src/main/java/com/google/adk/plugins/agentanalytics/BatchProcessor.java @@ -249,7 +249,7 @@ private void populateVector(FieldVector vector, int index, Object value) { @Override public void close() { - if (this.queue != null && !this.queue.isEmpty()) { + while (this.queue != null && !this.queue.isEmpty()) { this.flush(); } if (this.allocator != null) { diff --git a/core/src/main/java/com/google/adk/plugins/agentanalytics/BigQueryAgentAnalyticsPlugin.java b/core/src/main/java/com/google/adk/plugins/agentanalytics/BigQueryAgentAnalyticsPlugin.java index 3c673b140..5f8222e70 100644 --- a/core/src/main/java/com/google/adk/plugins/agentanalytics/BigQueryAgentAnalyticsPlugin.java +++ b/core/src/main/java/com/google/adk/plugins/agentanalytics/BigQueryAgentAnalyticsPlugin.java @@ -17,11 +17,11 @@ package com.google.adk.plugins.agentanalytics; import static com.google.adk.plugins.agentanalytics.BigQueryUtils.createAnalyticsViews; +import static com.google.adk.plugins.agentanalytics.BigQueryUtils.getVersionHeaderValue; import static com.google.adk.plugins.agentanalytics.BigQueryUtils.maybeUpgradeSchema; import static com.google.adk.plugins.agentanalytics.JsonFormatter.convertToJsonNode; import static com.google.adk.plugins.agentanalytics.JsonFormatter.smartTruncate; import static com.google.adk.plugins.agentanalytics.JsonFormatter.toJavaObject; -import static java.util.concurrent.TimeUnit.MILLISECONDS; import com.google.adk.agents.BaseAgent; import com.google.adk.agents.CallbackContext; @@ -41,8 +41,7 @@ import com.google.adk.tools.ToolContext; import com.google.adk.tools.mcp.AbstractMcpTool; import com.google.adk.utils.AgentEnums.AgentOrigin; -import com.google.api.gax.core.FixedCredentialsProvider; -import com.google.api.gax.retrying.RetrySettings; +import com.google.api.gax.rpc.FixedHeaderProvider; import com.google.auth.oauth2.GoogleCredentials; import com.google.cloud.bigquery.BigQuery; import com.google.cloud.bigquery.BigQueryException; @@ -53,11 +52,7 @@ import com.google.cloud.bigquery.Table; import com.google.cloud.bigquery.TableId; import com.google.cloud.bigquery.TableInfo; -import com.google.cloud.bigquery.storage.v1.BigQueryWriteClient; -import com.google.cloud.bigquery.storage.v1.BigQueryWriteSettings; -import com.google.cloud.bigquery.storage.v1.StreamWriter; import com.google.common.annotations.VisibleForTesting; -import com.google.common.base.VerifyException; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.genai.types.Content; @@ -70,10 +65,6 @@ import java.util.HashMap; import java.util.Map; import java.util.Optional; -import java.util.concurrent.Executors; -import java.util.concurrent.ScheduledExecutorService; -import java.util.concurrent.ThreadFactory; -import java.util.concurrent.atomic.AtomicLong; import java.util.logging.Level; import java.util.logging.Logger; import org.jspecify.annotations.Nullable; @@ -88,7 +79,6 @@ public class BigQueryAgentAnalyticsPlugin extends BasePlugin { Logger.getLogger(BigQueryAgentAnalyticsPlugin.class.getName()); private static final ImmutableList DEFAULT_AUTH_SCOPES = ImmutableList.of("https://www.googleapis.com/auth/cloud-platform"); - private static final AtomicLong threadCounter = new AtomicLong(0); private static final ImmutableMap HITL_EVENT_TYPES = ImmutableMap.of( "adk_request_credential", @@ -100,11 +90,8 @@ public class BigQueryAgentAnalyticsPlugin extends BasePlugin { private final BigQueryLoggerConfig config; private final BigQuery bigQuery; - private final BigQueryWriteClient writeClient; - private final ScheduledExecutorService executor; private final Object tableEnsuredLock = new Object(); - @VisibleForTesting final BatchProcessor batchProcessor; - @VisibleForTesting final TraceManager traceManager; + private final PluginState state; private volatile boolean tableEnsured = false; public BigQueryAgentAnalyticsPlugin(BigQueryLoggerConfig config) throws IOException { @@ -113,32 +100,20 @@ public BigQueryAgentAnalyticsPlugin(BigQueryLoggerConfig config) throws IOExcept public BigQueryAgentAnalyticsPlugin(BigQueryLoggerConfig config, BigQuery bigQuery) throws IOException { + this(config, bigQuery, new PluginState(config)); + } + + BigQueryAgentAnalyticsPlugin(BigQueryLoggerConfig config, BigQuery bigQuery, PluginState state) { super("bigquery_agent_analytics"); this.config = config; this.bigQuery = bigQuery; - ThreadFactory threadFactory = - r -> new Thread(r, "bq-analytics-plugin-" + threadCounter.getAndIncrement()); - this.executor = Executors.newScheduledThreadPool(1, threadFactory); - this.writeClient = createWriteClient(config); - this.traceManager = createTraceManager(); - - if (config.enabled()) { - StreamWriter writer = createWriter(config); - this.batchProcessor = - new BatchProcessor( - writer, - config.batchSize(), - config.batchFlushInterval(), - config.queueMaxSize(), - executor); - this.batchProcessor.start(); - } else { - this.batchProcessor = null; - } + this.state = state; } private static BigQuery createBigQuery(BigQueryLoggerConfig config) throws IOException { BigQueryOptions.Builder builder = BigQueryOptions.newBuilder(); + builder.setHeaderProvider( + FixedHeaderProvider.create(ImmutableMap.of("user-agent", getVersionHeaderValue()))); if (config.credentials() != null) { builder.setCredentials(config.credentials()); } else { @@ -194,7 +169,7 @@ private void ensureTableExists(BigQuery bigQuery, BigQueryLoggerConfig config) { try { if (config.createViews()) { - var unused = executor.submit(() -> createAnalyticsViews(bigQuery, config)); + var unused = state.getExecutor().submit(() -> createAnalyticsViews(bigQuery, config)); } } catch (RuntimeException e) { logger.log(Level.WARNING, "Failed to create/update BigQuery views for table: " + tableId, e); @@ -209,48 +184,6 @@ private void processBigQueryException(BigQueryException e, String logMessage) { } } - protected BigQueryWriteClient createWriteClient(BigQueryLoggerConfig config) throws IOException { - if (config.credentials() != null) { - return BigQueryWriteClient.create( - BigQueryWriteSettings.newBuilder() - .setCredentialsProvider(FixedCredentialsProvider.create(config.credentials())) - .build()); - } - return BigQueryWriteClient.create(); - } - - protected String getStreamName(BigQueryLoggerConfig config) { - return String.format( - "projects/%s/datasets/%s/tables/%s/streams/_default", - config.projectId(), config.datasetId(), config.tableName()); - } - - protected StreamWriter createWriter(BigQueryLoggerConfig config) { - BigQueryLoggerConfig.RetryConfig retryConfig = config.retryConfig(); - RetrySettings retrySettings = - RetrySettings.newBuilder() - .setMaxAttempts(retryConfig.maxRetries()) - .setInitialRetryDelay( - org.threeten.bp.Duration.ofMillis(retryConfig.initialDelay().toMillis())) - .setRetryDelayMultiplier(retryConfig.multiplier()) - .setMaxRetryDelay(org.threeten.bp.Duration.ofMillis(retryConfig.maxDelay().toMillis())) - .build(); - - String streamName = getStreamName(config); - try { - return StreamWriter.newBuilder(streamName, writeClient) - .setRetrySettings(retrySettings) - .setWriterSchema(BigQuerySchema.getArrowSchema()) - .build(); - } catch (Exception e) { - throw new VerifyException("Failed to create StreamWriter for " + streamName, e); - } - } - - protected TraceManager createTraceManager() { - return new TraceManager(); - } - private void logEvent( String eventType, InvocationContext invocationContext, @@ -265,7 +198,7 @@ private void logEvent( Object content, boolean isContentTruncated, Optional eventData) { - if (!config.enabled() || batchProcessor == null) { + if (!config.enabled()) { return; } if (!config.eventAllowlist().isEmpty() && !config.eventAllowlist().contains(eventType)) { @@ -274,6 +207,23 @@ private void logEvent( if (config.eventDenylist().contains(eventType)) { return; } + if (state.isProcessed(invocationContext.invocationId())) { + return; + } + if (config.contentFormatter() != null && content != null) { + try { + content = config.contentFormatter().apply(content, eventType); + } catch (RuntimeException e) { + + logger.log( + Level.WARNING, + "Failed to format content for invocation ID: " + invocationContext.invocationId(), + e); + content = null; // Fail-closed to avoid leaking unmasked sensitive data + } + } + String invocationId = invocationContext.invocationId(); + BatchProcessor processor = state.getBatchProcessor(invocationId); // Ensure table exists before logging. ensureTableExistsOnce(); // Log common fields @@ -285,10 +235,12 @@ private void logEvent( row.put("invocation_id", invocationContext.invocationId()); row.put("user_id", invocationContext.userId()); // Parse and log content - ParsedContent parsedContent = JsonFormatter.parse(content, config.maxContentLength()); - row.put("content_parts", parsedContent.parts()); - row.put("content", parsedContent.content()); - row.put("is_truncated", isContentTruncated || parsedContent.isTruncated()); + if (content != null) { + ParsedContent parsedContent = JsonFormatter.parse(content, config.maxContentLength()); + row.put("content_parts", parsedContent.parts()); + row.put("content", parsedContent.content()); + row.put("is_truncated", isContentTruncated || parsedContent.isTruncated()); + } EventData data = eventData.orElse(EventData.builder().build()); row.put("status", data.status()); @@ -301,11 +253,12 @@ private void logEvent( row.put("attributes", convertToJsonNode(getAttributes(data, invocationContext))); addTraceDetails(row, invocationContext, eventData); - batchProcessor.append(row); + processor.append(row); } private void addTraceDetails( Map row, InvocationContext invocationContext, Optional eventData) { + TraceManager traceManager = state.getTraceManager(invocationContext.invocationId()); String traceId = eventData .flatMap(EventData::traceIdOverride) @@ -336,7 +289,7 @@ private void addTraceDetails( private Map getAttributes( EventData eventData, InvocationContext invocationContext) { Map attributes = new HashMap<>(eventData.extraAttributes()); - + TraceManager traceManager = state.getTraceManager(invocationContext.invocationId()); attributes.put("root_agent_name", traceManager.getRootAgentName()); eventData.model().ifPresent(m -> attributes.put("model", m)); eventData.modelVersion().ifPresent(mv -> attributes.put("model_version", mv)); @@ -362,7 +315,10 @@ private Map getAttributes( } attributes.put("session_metadata", sessionMeta); } catch (RuntimeException e) { - // Ignore session enrichment errors as in Python. + logger.log( + Level.WARNING, + "Failed to log session metadata for invocation ID: " + invocationContext.invocationId(), + e); } } @@ -375,25 +331,17 @@ private Map getAttributes( @Override public Completable close() { - if (batchProcessor != null) { - batchProcessor.close(); - } - if (writeClient != null) { - writeClient.close(); - } - try { - executor.shutdown(); - if (!executor.awaitTermination(config.shutdownTimeout().toMillis(), MILLISECONDS)) { - executor.shutdownNow(); - } - } catch (InterruptedException e) { - executor.shutdownNow(); - Thread.currentThread().interrupt(); - } + state.close(); return Completable.complete(); } + @VisibleForTesting + PluginState getState() { + return state; + } + private Optional getCompletedEventData(InvocationContext invocationContext) { + TraceManager traceManager = state.getTraceManager(invocationContext.invocationId()); String traceId = traceManager.getTraceId(invocationContext); // Pop the invocation span from the trace manager. Optional popped = traceManager.popSpan(); @@ -426,7 +374,12 @@ public Maybe onUserMessageCallback( InvocationContext invocationContext, Content userMessage) { return Maybe.fromAction( () -> { - traceManager.ensureInvocationSpan(invocationContext); + if (state.isProcessed(invocationContext.invocationId())) { + return; + } + state + .getTraceManager(invocationContext.invocationId()) + .ensureInvocationSpan(invocationContext); logEvent("USER_MESSAGE_RECEIVED", invocationContext, userMessage, Optional.empty()); if (userMessage.parts().isPresent()) { for (Part part : userMessage.parts().get()) { @@ -454,6 +407,9 @@ public Maybe onUserMessageCallback( public Maybe onEventCallback(InvocationContext invocationContext, Event event) { return Maybe.fromAction( () -> { + if (state.isProcessed(invocationContext.invocationId())) { + return; + } EventData.Builder eventDataBuilder = EventData.builder() .setExtraAttributes( @@ -510,9 +466,16 @@ public Maybe onEventCallback(InvocationContext invocationContext, Event e @Override public Maybe beforeRunCallback(InvocationContext invocationContext) { - traceManager.ensureInvocationSpan(invocationContext); return Maybe.fromAction( - () -> logEvent("INVOCATION_STARTING", invocationContext, null, Optional.empty())); + () -> { + if (state.isProcessed(invocationContext.invocationId())) { + return; + } + state + .getTraceManager(invocationContext.invocationId()) + .ensureInvocationSpan(invocationContext); + logEvent("INVOCATION_STARTING", invocationContext, null, Optional.empty()); + }); } @Override @@ -524,8 +487,17 @@ public Completable afterRunCallback(InvocationContext invocationContext) { invocationContext, null, getCompletedEventData(invocationContext)); - batchProcessor.flush(); - traceManager.clearStack(); + // Mark invocation ID as processed to avoid memory leaks. + state.markProcessed(invocationContext.invocationId()); + BatchProcessor processor = state.removeProcessor(invocationContext.invocationId()); + if (processor != null) { + processor.flush(); + processor.close(); + } + TraceManager traceManager = state.removeTraceManager(invocationContext.invocationId()); + if (traceManager != null) { + traceManager.clearStack(); + } }); } @@ -533,7 +505,12 @@ public Completable afterRunCallback(InvocationContext invocationContext) { public Maybe beforeAgentCallback(BaseAgent agent, CallbackContext callbackContext) { return Maybe.fromAction( () -> { - traceManager.pushSpan("agent:" + agent.name()); + if (state.isProcessed(callbackContext.invocationContext().invocationId())) { + return; + } + state + .getTraceManager(callbackContext.invocationContext().invocationId()) + .pushSpan("agent:" + agent.name()); logEvent("AGENT_STARTING", callbackContext.invocationContext(), null, Optional.empty()); }); } @@ -563,6 +540,9 @@ public Maybe beforeModelCallback( CallbackContext callbackContext, LlmRequest.Builder llmRequest) { return Maybe.fromAction( () -> { + if (state.isProcessed(callbackContext.invocationContext().invocationId())) { + return; + } Map attributes = new HashMap<>(); Map llmConfig = new HashMap<>(); LlmRequest req = llmRequest.build(); @@ -622,7 +602,9 @@ public Maybe beforeModelCallback( .setModel(req.model().orElse("")) .setExtraAttributes(attributes) .build(); - traceManager.pushSpan("llm_request"); + state + .getTraceManager(callbackContext.invocationContext().invocationId()) + .pushSpan("llm_request"); logEvent("LLM_REQUEST", callbackContext.invocationContext(), req, Optional.of(eventData)); }); } @@ -632,6 +614,11 @@ public Maybe afterModelCallback( CallbackContext callbackContext, LlmResponse llmResponse) { return Maybe.fromAction( () -> { + if (state.isProcessed(callbackContext.invocationContext().invocationId())) { + return; + } + TraceManager traceManager = + state.getTraceManager(callbackContext.invocationContext().invocationId()); // TODO(b/495809488): Add formatting of the content ParsedContent parsedContent = JsonFormatter.parse(llmResponse.content().orElse(null), config.maxContentLength()); @@ -728,6 +715,11 @@ public Maybe onModelErrorCallback( CallbackContext callbackContext, LlmRequest.Builder llmRequest, Throwable error) { return Maybe.fromAction( () -> { + if (state.isProcessed(callbackContext.invocationContext().invocationId())) { + return; + } + TraceManager traceManager = + state.getTraceManager(callbackContext.invocationContext().invocationId()); InvocationContext invocationContext = callbackContext.invocationContext(); Optional popped = traceManager.popSpan(); String spanId = popped.map(RecordData::spanId).orElse(null); @@ -758,11 +750,14 @@ public Maybe> beforeToolCallback( BaseTool tool, Map toolArgs, ToolContext toolContext) { return Maybe.fromAction( () -> { + if (state.isProcessed(toolContext.invocationContext().invocationId())) { + return; + } TruncationResult res = smartTruncate(toolArgs, config.maxContentLength()); ImmutableMap contentMap = ImmutableMap.of( "tool_origin", getToolOrigin(tool), "tool", tool.name(), "args", res.node()); - traceManager.pushSpan("tool"); + state.getTraceManager(toolContext.invocationContext().invocationId()).pushSpan("tool"); logEvent("TOOL_STARTING", toolContext.invocationContext(), contentMap, Optional.empty()); }); } @@ -775,6 +770,14 @@ public Maybe> afterToolCallback( Map result) { return Maybe.fromAction( () -> { + if (state.isProcessed(toolContext.invocationContext().invocationId())) { + return; + } + state + .getTraceManager(toolContext.invocationContext().invocationId()) + .ensureInvocationSpan(toolContext.invocationContext()); + TraceManager traceManager = + state.getTraceManager(toolContext.invocationContext().invocationId()); Optional popped = traceManager.popSpan(); TruncationResult truncationResult = smartTruncate(result, config.maxContentLength()); ImmutableMap contentMap = @@ -812,6 +815,11 @@ public Maybe> onToolErrorCallback( BaseTool tool, Map toolArgs, ToolContext toolContext, Throwable error) { return Maybe.fromAction( () -> { + if (state.isProcessed(toolContext.invocationContext().invocationId())) { + return; + } + TraceManager traceManager = + state.getTraceManager(toolContext.invocationContext().invocationId()); Optional popped = traceManager.popSpan(); TruncationResult truncationResult = smartTruncate(toolArgs, config.maxContentLength()); String toolOrigin = getToolOrigin(tool); diff --git a/core/src/main/java/com/google/adk/plugins/agentanalytics/BigQueryLoggerConfig.java b/core/src/main/java/com/google/adk/plugins/agentanalytics/BigQueryLoggerConfig.java index ccce8c3bc..b35e7c51d 100644 --- a/core/src/main/java/com/google/adk/plugins/agentanalytics/BigQueryLoggerConfig.java +++ b/core/src/main/java/com/google/adk/plugins/agentanalytics/BigQueryLoggerConfig.java @@ -77,20 +77,32 @@ public abstract class BigQueryLoggerConfig { // Max size of the batch processor queue. public abstract int queueMaxSize(); - // Optional custom formatter for content. - // TODO(b/491852782): Implement content formatter. - @Nullable - public abstract BiFunction contentFormatter(); + /** + * Optional custom formatter for content. + * + *

Allow plugins to modify the content before logging. This is useful for masking sensitive + * data, formatting content, etc. + * + *

The contentFormatter must be thread-safe as it may be called concurrently across + * different agent invocations and fast/non-blocking to avoid adding latency to the agent's + * event processing pipeline. + * + *

Important: To avoid corruption of the logs, the incoming content object should + * not be mutated. Modifying code should return a new copy of the object with + * desired changes. + */ + public abstract @Nullable BiFunction contentFormatter(); + + // GCS bucket name to store multi-modal content. + public abstract String gcsBucketName(); // TODO(b/491852782): Implement connection id. public abstract Optional connectionId(); // Toggle for session metadata (e.g. gchat thread-id). - // TODO(b/491852782): Implement logging of session metadata. public abstract boolean logSessionMetadata(); // Static custom tags (e.g. {"agent_role": "sales"}). - // TODO(b/491852782): Implement custom tags. public abstract ImmutableMap customTags(); // Automatically add new columns to existing tables when the plugin @@ -120,6 +132,7 @@ public static Builder builder() { .tableName("events") .clusteringFields(ImmutableList.of("event_type", "agent", "user_id")) .logMultiModalContent(true) + .gcsBucketName("") .retryConfig(RetryConfig.builder().build()) .batchSize(1) .batchFlushInterval(Duration.ofSeconds(1)) @@ -205,6 +218,9 @@ public abstract Builder contentFormatter( @CanIgnoreReturnValue public abstract Builder viewPrefix(String viewPrefix); + @CanIgnoreReturnValue + public abstract Builder gcsBucketName(String gcsBucketName); + @CanIgnoreReturnValue public abstract Builder credentials(Credentials credentials); diff --git a/core/src/main/java/com/google/adk/plugins/agentanalytics/BigQueryUtils.java b/core/src/main/java/com/google/adk/plugins/agentanalytics/BigQueryUtils.java index 60306799d..f0db45e12 100644 --- a/core/src/main/java/com/google/adk/plugins/agentanalytics/BigQueryUtils.java +++ b/core/src/main/java/com/google/adk/plugins/agentanalytics/BigQueryUtils.java @@ -22,6 +22,7 @@ import static com.google.common.collect.ImmutableMap.toImmutableMap; import static java.util.stream.Collectors.toCollection; +import com.google.adk.Version; import com.google.cloud.bigquery.BigQuery; import com.google.cloud.bigquery.BigQueryException; import com.google.cloud.bigquery.Field; @@ -136,6 +137,13 @@ final class BigQueryUtils { "JSON_QUERY(content, '$.args') AS tool_args")) .buildOrThrow(); + private static final String FRAMEWORK_PREFIX = "google-adk-bq-logger-java"; + + /** Returns the telemetry header value. */ + static String getVersionHeaderValue() { + return FRAMEWORK_PREFIX + "/" + Version.JAVA_ADK_VERSION; + } + /** Creates and/or replaces the analytics views in BigQuery. */ static void createAnalyticsViews(BigQuery bigQuery, BigQueryLoggerConfig config) { for (Map.Entry> entry : EVENT_VIEW_DEFS.entrySet()) { diff --git a/core/src/main/java/com/google/adk/plugins/agentanalytics/PluginState.java b/core/src/main/java/com/google/adk/plugins/agentanalytics/PluginState.java new file mode 100644 index 000000000..63c60c491 --- /dev/null +++ b/core/src/main/java/com/google/adk/plugins/agentanalytics/PluginState.java @@ -0,0 +1,186 @@ +package com.google.adk.plugins.agentanalytics; + +import static com.google.adk.plugins.agentanalytics.BigQueryUtils.getVersionHeaderValue; +import static java.util.concurrent.TimeUnit.MILLISECONDS; + +import com.google.api.gax.core.FixedCredentialsProvider; +import com.google.api.gax.retrying.RetrySettings; +import com.google.api.gax.rpc.FixedHeaderProvider; +import com.google.cloud.bigquery.storage.v1.BigQueryWriteClient; +import com.google.cloud.bigquery.storage.v1.BigQueryWriteSettings; +import com.google.cloud.bigquery.storage.v1.StreamWriter; +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.VerifyException; +import com.google.common.cache.Cache; +import com.google.common.cache.CacheBuilder; +import com.google.common.collect.ImmutableMap; +import java.io.IOException; +import java.util.Collection; +import java.util.UUID; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.Executors; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.ThreadFactory; +import java.util.concurrent.atomic.AtomicLong; +import java.util.logging.Logger; +import org.threeten.bp.Duration; + +/** Manages state for the BigQueryAgentAnalyticsPlugin. */ +class PluginState { + private static final Logger logger = Logger.getLogger(PluginState.class.getName()); + private final BigQueryLoggerConfig config; + private final ScheduledExecutorService executor; + private final BigQueryWriteClient writeClient; + private static final AtomicLong threadCounter = new AtomicLong(0); + // Map of invocation ID to BatchProcessor. + private final ConcurrentHashMap batchProcessors = + new ConcurrentHashMap<>(); + // Map of invocation ID to TraceManager. + private final ConcurrentHashMap traceManagers = new ConcurrentHashMap<>(); + // Cache of invocation ID to Boolean indicating invocation ID has been processed. + private final Cache processedInvocations; + + PluginState(BigQueryLoggerConfig config) throws IOException { + this.config = config; + ThreadFactory threadFactory = + r -> new Thread(r, "bq-analytics-plugin-" + threadCounter.getAndIncrement()); + this.executor = Executors.newScheduledThreadPool(1, threadFactory); + // One write client per plugin instance, shared by all invocations. + this.writeClient = createWriteClient(config); + this.processedInvocations = + CacheBuilder.newBuilder() + .maximumSize(10000) + .expireAfterWrite(java.time.Duration.ofMinutes(10)) + .build(); + } + + ScheduledExecutorService getExecutor() { + return executor; + } + + boolean isProcessed(String invocationId) { + boolean isProcessed = processedInvocations.getIfPresent(invocationId) != null; + if (isProcessed) { + logger.info("Invocation ID: " + invocationId + " already processed"); + } + return isProcessed; + } + + void markProcessed(String invocationId) { + processedInvocations.put(invocationId, true); + } + + protected BigQueryWriteClient createWriteClient(BigQueryLoggerConfig config) throws IOException { + BigQueryWriteSettings.Builder settingsBuilder = + BigQueryWriteSettings.newBuilder() + .setHeaderProvider( + FixedHeaderProvider.create(ImmutableMap.of("user-agent", getVersionHeaderValue()))); + if (config.credentials() != null) { + settingsBuilder.setCredentialsProvider(FixedCredentialsProvider.create(config.credentials())); + } + return BigQueryWriteClient.create(settingsBuilder.build()); + } + + protected StreamWriter createWriter() { + BigQueryLoggerConfig.RetryConfig retryConfig = config.retryConfig(); + RetrySettings retrySettings = + RetrySettings.newBuilder() + .setMaxAttempts(retryConfig.maxRetries()) + .setInitialRetryDelay(Duration.ofMillis(retryConfig.initialDelay().toMillis())) + .setRetryDelayMultiplier(retryConfig.multiplier()) + .setMaxRetryDelay(Duration.ofMillis(retryConfig.maxDelay().toMillis())) + .build(); + + String streamName = getStreamName(config); + try { + return StreamWriter.newBuilder(streamName, writeClient) + .setTraceId(BigQueryUtils.getVersionHeaderValue() + ":" + UUID.randomUUID()) + .setRetrySettings(retrySettings) + .setWriterSchema(BigQuerySchema.getArrowSchema()) + .build(); + } catch (Exception e) { + throw new VerifyException("Failed to create StreamWriter for " + streamName, e); + } + } + + @VisibleForTesting + String getStreamName(BigQueryLoggerConfig config) { + return String.format( + "projects/%s/datasets/%s/tables/%s/streams/_default", + config.projectId(), config.datasetId(), config.tableName()); + } + + @VisibleForTesting + TraceManager getTraceManager(String invocationId) { + return traceManagers.computeIfAbsent(invocationId, id -> new TraceManager()); + } + + @VisibleForTesting + BatchProcessor getBatchProcessor(String invocationId) { + return batchProcessors.computeIfAbsent( + invocationId, + id -> { + BatchProcessor p = + new BatchProcessor( + createWriter(), + config.batchSize(), + config.batchFlushInterval(), + config.queueMaxSize(), + executor); + p.start(); + return p; + }); + } + + @VisibleForTesting + Collection getTraceManagers() { + return traceManagers.values(); + } + + @VisibleForTesting + Collection getBatchProcessors() { + return batchProcessors.values(); + } + + @VisibleForTesting + TraceManager removeTraceManager(String invocationId) { + return traceManagers.remove(invocationId); + } + + @VisibleForTesting + protected BatchProcessor removeProcessor(String invocationId) { + return batchProcessors.remove(invocationId); + } + + void clearTraceManagers() { + traceManagers.clear(); + } + + void clearBatchProcessors() { + batchProcessors.clear(); + } + + void close() { + for (BatchProcessor processor : getBatchProcessors()) { + processor.close(); + } + for (TraceManager traceManager : getTraceManagers()) { + traceManager.clearStack(); + } + clearBatchProcessors(); + clearTraceManagers(); + + if (writeClient != null) { + writeClient.close(); + } + try { + executor.shutdown(); + if (!executor.awaitTermination(config.shutdownTimeout().toMillis(), MILLISECONDS)) { + executor.shutdownNow(); + } + } catch (InterruptedException e) { + executor.shutdownNow(); + Thread.currentThread().interrupt(); + } + } +} diff --git a/core/src/main/java/com/google/adk/tools/AgentTool.java b/core/src/main/java/com/google/adk/tools/AgentTool.java index 956a8eb51..66d4a2700 100644 --- a/core/src/main/java/com/google/adk/tools/AgentTool.java +++ b/core/src/main/java/com/google/adk/tools/AgentTool.java @@ -26,6 +26,7 @@ import com.google.adk.agents.ConfigAgentUtils.ConfigurationException; import com.google.adk.agents.LlmAgent; import com.google.adk.events.Event; +import com.google.adk.plugins.Plugin; import com.google.adk.runner.InMemoryRunner; import com.google.adk.runner.Runner; import com.google.adk.sessions.State; @@ -46,6 +47,7 @@ public class AgentTool extends BaseTool { private final BaseAgent agent; private final boolean skipSummarization; + private final boolean includePlugins; public static BaseTool fromConfig(ToolArgsConfig args, String configAbsPath) throws ConfigurationException { @@ -62,21 +64,34 @@ public static BaseTool fromConfig(ToolArgsConfig args, String configAbsPath) } BaseAgent agent = resolvedAgents.get(0); - return AgentTool.create(agent, args.getOrDefault("skipSummarization", false).booleanValue()); + return AgentTool.create( + agent, + args.getOrDefault("skipSummarization", false).booleanValue(), + args.getOrDefault("includePlugins", false).booleanValue()); + } + + public static AgentTool create( + BaseAgent agent, boolean skipSummarization, boolean includePlugins) { + return new AgentTool(agent, skipSummarization, includePlugins); } public static AgentTool create(BaseAgent agent, boolean skipSummarization) { - return new AgentTool(agent, skipSummarization); + return new AgentTool(agent, skipSummarization, /* includePlugins= */ false); } public static AgentTool create(BaseAgent agent) { - return new AgentTool(agent, false); + return new AgentTool(agent, /* skipSummarization= */ false, /* includePlugins= */ false); } protected AgentTool(BaseAgent agent, boolean skipSummarization) { + this(agent, skipSummarization, /* includePlugins= */ false); + } + + protected AgentTool(BaseAgent agent, boolean skipSummarization, boolean includePlugins) { super(agent.name(), agent.description()); this.agent = agent; this.skipSummarization = skipSummarization; + this.includePlugins = includePlugins; } @VisibleForTesting @@ -159,13 +174,23 @@ public Single> runAsync(Map args, ToolContex content = Content.fromParts(Part.fromText(input.toString())); } - Runner runner = new InMemoryRunner(this.agent, toolContext.agentName()); - // Session state is final, can't update to toolContext state - // session.toBuilder().setState(toolContext.getState()); + ImmutableList plugins = + this.includePlugins + ? ImmutableList.of(toolContext.invocationContext().pluginManager()) + : ImmutableList.of(); + Runner runner = new InMemoryRunner(this.agent, toolContext.agentName(), plugins); return runner .sessionService() .createSession(toolContext.agentName(), "tmp-user", toolContext.state(), null) .flatMapPublisher(session -> runner.runAsync(session.userId(), session.id(), content)) + .doOnNext( + event -> { + if (event.actions() != null + && event.actions().stateDelta() != null + && !event.actions().stateDelta().isEmpty()) { + updateState(event.actions().stateDelta(), toolContext.state()); + } + }) .lastElement() .map(Optional::of) .defaultIfEmpty(Optional.empty()) @@ -177,13 +202,6 @@ public Single> runAsync(Map args, ToolContex Event lastEvent = optionalLastEvent.get(); Optional outputText = lastEvent.content().map(Content::text); - // Forward state delta to parent session. - if (lastEvent.actions() != null - && lastEvent.actions().stateDelta() != null - && !lastEvent.actions().stateDelta().isEmpty()) { - updateState(lastEvent.actions().stateDelta(), toolContext.state()); - } - if (outputText.isEmpty()) { return ImmutableMap.of(); } diff --git a/core/src/main/java/com/google/adk/tools/mcp/McpAsyncToolset.java b/core/src/main/java/com/google/adk/tools/mcp/McpAsyncToolset.java index 5f4c2164b..cb541eccf 100644 --- a/core/src/main/java/com/google/adk/tools/mcp/McpAsyncToolset.java +++ b/core/src/main/java/com/google/adk/tools/mcp/McpAsyncToolset.java @@ -16,6 +16,8 @@ package com.google.adk.tools.mcp; +import static com.google.common.base.Preconditions.checkNotNull; + import com.fasterxml.jackson.databind.ObjectMapper; import com.google.adk.JsonBaseModel; import com.google.adk.agents.ReadonlyContext; @@ -23,7 +25,6 @@ import com.google.adk.tools.BaseToolset; import com.google.adk.tools.NamedToolPredicate; import com.google.adk.tools.ToolPredicate; -import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; import com.google.errorprone.annotations.CanIgnoreReturnValue; import io.modelcontextprotocol.client.McpAsyncClient; @@ -64,21 +65,31 @@ public class McpAsyncToolset implements BaseToolset { private final @Nullable Object toolFilter; private final AtomicReference>> mcpTools = new AtomicReference<>(); + public static Builder builder() { + return new Builder(); + } + /** Builder for McpAsyncToolset */ public static class Builder { - private Object connectionParams = null; + private McpSessionManager mcpSessionManager = null; private ObjectMapper objectMapper = null; private @Nullable Object toolFilter = null; @CanIgnoreReturnValue public Builder connectionParams(ServerParameters connectionParams) { - this.connectionParams = connectionParams; + this.mcpSessionManager = new McpSessionManager(connectionParams); return this; } @CanIgnoreReturnValue public Builder connectionParams(SseServerParameters connectionParams) { - this.connectionParams = connectionParams; + this.mcpSessionManager = new McpSessionManager(connectionParams); + return this; + } + + @CanIgnoreReturnValue + public Builder mcpSessionManager(McpSessionManager mcpSessionManager) { + this.mcpSessionManager = mcpSessionManager; return this; } @@ -90,7 +101,7 @@ public Builder objectMapper(ObjectMapper objectMapper) { @CanIgnoreReturnValue public Builder toolFilter(List toolNames) { - this.toolFilter = new NamedToolPredicate(Preconditions.checkNotNull(toolNames)); + this.toolFilter = new NamedToolPredicate(checkNotNull(toolNames)); return this; } @@ -104,14 +115,8 @@ public McpAsyncToolset build() { if (objectMapper == null) { objectMapper = JsonBaseModel.getMapper(); } - if (connectionParams instanceof ServerParameters setSelectedParams) { - return new McpAsyncToolset(setSelectedParams, objectMapper, toolFilter); - } else if (connectionParams instanceof SseServerParameters sseServerParameters) { - return new McpAsyncToolset(sseServerParameters, objectMapper, toolFilter); - } else { - throw new IllegalArgumentException( - "connectionParams must be either ServerParameters or SseServerParameters"); - } + checkNotNull(mcpSessionManager, "Connection params must be set"); + return new McpAsyncToolset(mcpSessionManager, objectMapper, toolFilter); } } @@ -123,29 +128,11 @@ public McpAsyncToolset build() { * @param toolFilter Either a ToolPredicate or a List of tool names. */ McpAsyncToolset( - SseServerParameters connectionParams, - ObjectMapper objectMapper, - @Nullable Object toolFilter) { - Objects.requireNonNull(connectionParams); - Objects.requireNonNull(objectMapper); - this.objectMapper = objectMapper; - this.mcpSessionManager = new McpSessionManager(connectionParams); - this.toolFilter = toolFilter; - } - - /** - * Initializes the McpAsyncToolset with local server parameters. - * - * @param connectionParams The local server connection parameters to the MCP server. - * @param objectMapper An ObjectMapper instance for parsing schemas. - * @param toolFilter Either a ToolPredicate or a List of tool names or null. - */ - McpAsyncToolset( - ServerParameters connectionParams, ObjectMapper objectMapper, @Nullable Object toolFilter) { - Objects.requireNonNull(connectionParams); + McpSessionManager mcpSessionManager, ObjectMapper objectMapper, @Nullable Object toolFilter) { + Objects.requireNonNull(mcpSessionManager); Objects.requireNonNull(objectMapper); this.objectMapper = objectMapper; - this.mcpSessionManager = new McpSessionManager(connectionParams); + this.mcpSessionManager = mcpSessionManager; this.toolFilter = toolFilter; } diff --git a/core/src/main/java/com/google/adk/utils/ModelNameUtils.java b/core/src/main/java/com/google/adk/utils/ModelNameUtils.java index cf0f2221e..56fd6dd95 100644 --- a/core/src/main/java/com/google/adk/utils/ModelNameUtils.java +++ b/core/src/main/java/com/google/adk/utils/ModelNameUtils.java @@ -20,11 +20,14 @@ import java.util.Objects; import java.util.regex.Matcher; import java.util.regex.Pattern; +import org.jspecify.annotations.Nullable; /** Utility class for model names. */ public final class ModelNameUtils { private static final String GEMINI_PREFIX = "gemini-"; private static final Pattern GEMINI_2_PATTERN = Pattern.compile("^gemini-2\\..*"); + private static final Pattern GEMINI_VERSION_PATTERN = + Pattern.compile("^gemini-(\\d+)(?:\\.(\\d+))?.*"); private static final String GEMINI_CLASS = "com.google.adk.models.Gemini"; private static final Pattern PATH_PATTERN = Pattern.compile("^projects/[^/]+/locations/[^/]+/publishers/[^/]+/models/(.+)$"); @@ -39,6 +42,28 @@ public static boolean isGemini2Model(String modelString) { return matchesModelPattern(modelString, GEMINI_2_PATTERN); } + public static boolean isGemini2OrAbove(@Nullable String modelString) { + return isGeminiVersionOrAbove(modelString, 2, 0); + } + + private static boolean isGeminiVersionOrAbove( + @Nullable String modelString, int minMajor, int minMinor) { + if (modelString == null) { + return false; + } + String modelName = extractModelName(modelString); + Matcher matcher = GEMINI_VERSION_PATTERN.matcher(modelName); + if (matcher.matches()) { + int major = Integer.parseInt(matcher.group(1)); + int minor = matcher.group(2) != null ? Integer.parseInt(matcher.group(2)) : 0; + if (major > minMajor) { + return true; + } + return major == minMajor && minor >= minMinor; + } + return false; + } + private static boolean matchesModelPattern(String modelString, Pattern pattern) { if (modelString == null) { return false; diff --git a/core/src/test/java/com/google/adk/codeexecutors/BuiltInCodeExecutorTest.java b/core/src/test/java/com/google/adk/codeexecutors/BuiltInCodeExecutorTest.java new file mode 100644 index 000000000..e3e4c660c --- /dev/null +++ b/core/src/test/java/com/google/adk/codeexecutors/BuiltInCodeExecutorTest.java @@ -0,0 +1,73 @@ +package com.google.adk.codeexecutors; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; + +import com.google.adk.models.LlmRequest; +import com.google.genai.types.Tool; +import java.util.List; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class BuiltInCodeExecutorTest { + + @Test + public void executeCode_throwsUnsupportedOperationException() { + BuiltInCodeExecutor executor = new BuiltInCodeExecutor(); + assertThrows(UnsupportedOperationException.class, () -> executor.executeCode(null, null)); + } + + @Test + public void processLlmRequest_withGemini2_addsCodeExecutionTool() { + BuiltInCodeExecutor executor = new BuiltInCodeExecutor(); + LlmRequest.Builder requestBuilder = LlmRequest.builder().model("gemini-2.5-flash"); + + executor.processLlmRequest(requestBuilder); + + List tools = requestBuilder.build().config().get().tools().get(); + assertThat(tools).hasSize(1); + assertThat(tools.get(0).codeExecution()).isPresent(); + } + + @Test + public void processLlmRequest_withGemini3_addsCodeExecutionTool() { + BuiltInCodeExecutor executor = new BuiltInCodeExecutor(); + LlmRequest.Builder requestBuilder = LlmRequest.builder().model("gemini-3.0-pro"); + + executor.processLlmRequest(requestBuilder); + + List tools = requestBuilder.build().config().get().tools().get(); + assertThat(tools).hasSize(1); + assertThat(tools.get(0).codeExecution()).isPresent(); + } + + @Test + public void processLlmRequest_withGemini1_throwsException() { + BuiltInCodeExecutor executor = new BuiltInCodeExecutor(); + LlmRequest.Builder requestBuilder = LlmRequest.builder().model("gemini-1.5-pro"); + + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, () -> executor.processLlmRequest(requestBuilder)); + + assertThat(exception) + .hasMessageThat() + .contains("Gemini code execution tool is not supported for model gemini-1.5-pro"); + } + + @Test + public void processLlmRequest_withoutModel_throwsException() { + BuiltInCodeExecutor executor = new BuiltInCodeExecutor(); + LlmRequest.Builder requestBuilder = LlmRequest.builder(); + + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, () -> executor.processLlmRequest(requestBuilder)); + + assertThat(exception) + .hasMessageThat() + .contains("Gemini code execution tool is not supported for model"); + } +} diff --git a/core/src/test/java/com/google/adk/flows/llmflows/FunctionsTest.java b/core/src/test/java/com/google/adk/flows/llmflows/FunctionsTest.java index d5db4d4b3..1b8de4e4f 100644 --- a/core/src/test/java/com/google/adk/flows/llmflows/FunctionsTest.java +++ b/core/src/test/java/com/google/adk/flows/llmflows/FunctionsTest.java @@ -20,7 +20,6 @@ import static com.google.adk.testing.TestUtils.createInvocationContext; import static com.google.adk.testing.TestUtils.createRootAgent; import static com.google.common.truth.Truth.assertThat; -import static org.junit.Assert.assertThrows; import com.google.adk.agents.InvocationContext; import com.google.adk.agents.RunConfig; @@ -90,11 +89,11 @@ public void handleFunctionCalls_missingTool() { Part.fromText("..."), Part.fromFunctionCall("missing_tool", ImmutableMap.of()))) .build(); - assertThrows( - RuntimeException.class, - () -> - Functions.handleFunctionCalls( - invocationContext, event, /* tools= */ ImmutableMap.of())); + Event functionResponseEvent = + Functions.handleFunctionCalls(invocationContext, event, /* tools= */ ImmutableMap.of()) + .blockingGet(); + + assertThat(functionResponseEvent).isNull(); } @Test diff --git a/core/src/test/java/com/google/adk/models/chat/ChatCompletionsResponseTest.java b/core/src/test/java/com/google/adk/models/chat/ChatCompletionsResponseTest.java index dd1a5d85a..ad1839019 100644 --- a/core/src/test/java/com/google/adk/models/chat/ChatCompletionsResponseTest.java +++ b/core/src/test/java/com/google/adk/models/chat/ChatCompletionsResponseTest.java @@ -22,10 +22,15 @@ import com.google.adk.models.LlmResponse; import com.google.adk.models.chat.ChatCompletionsResponse.ChatCompletion; import com.google.adk.models.chat.ChatCompletionsResponse.ChatCompletionChunk; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.genai.types.Content; import com.google.genai.types.CustomMetadata; +import com.google.genai.types.FinishReason; import com.google.genai.types.FinishReason.Known; import com.google.genai.types.FunctionCall; import com.google.genai.types.Part; +import java.util.Arrays; import java.util.Base64; import java.util.List; import java.util.Map; @@ -482,7 +487,6 @@ public void testToLlmResponse_thoughtSignature() throws Exception { objectMapper.readValue(json, ChatCompletion.class); LlmResponse response = completion.toLlmResponse(); - assertThat(response.content().get().parts().get().get(0).thoughtSignature().get()) .isEqualTo(Base64.getDecoder().decode("c2ln")); } @@ -646,7 +650,7 @@ public void testToolCallToPart_withThoughtSignature() throws Exception { Part part = toolCall.toPart(); assertThat(part).isNotNull(); - assertThat(part.thoughtSignature().get()).isEqualTo(Base64.getDecoder().decode("c2ln")); + assertThat(part.thoughtSignature()).hasValue(Base64.getDecoder().decode("c2ln")); } @Test @@ -687,4 +691,185 @@ public void testToLlmResponse_noChoices() throws Exception { assertThat(response.content()).isPresent(); assertThat(response.content().get().parts()).isEmpty(); } + + @Test + public void testChunkCollection_accumulatesMultipleToolCalls() throws Exception { + ChatCompletionsResponse.ChatCompletionChunkCollection collection = + new ChatCompletionsResponse.ChatCompletionChunkCollection(); + + String chunk1Json = + """ + {"choices":[{"delta":{"tool_calls":[{"index":0,"id":"call_id_1","type":"function","function":{"name":"roll_die","arguments":""}}]}}]} + """; + String chunk2Json = + """ + {"choices":[{"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\\\"sides\\\":8}"}}]}}]} + """; + String chunk3Json = + """ + {"choices":[{"delta":{"tool_calls":[{"index":1,"id":"call_id_2","type":"function","function":{"name":"roll_die","arguments":""}}]}}]} + """; + String chunk4Json = + """ + {"choices":[{"delta":{"tool_calls":[{"index":1,"function":{"arguments":"{\\\"sides\\\":8}"}}]}}]} + """; + String chunk5Json = + """ + {"choices":[{"finish_reason":"tool_calls"}]} + """; + + ImmutableList unused1 = + collection.processChunk( + objectMapper.readValue(chunk1Json, ChatCompletionsResponse.ChatCompletionChunk.class)); + ImmutableList unused2 = + collection.processChunk( + objectMapper.readValue(chunk2Json, ChatCompletionsResponse.ChatCompletionChunk.class)); + ImmutableList unused3 = + collection.processChunk( + objectMapper.readValue(chunk3Json, ChatCompletionsResponse.ChatCompletionChunk.class)); + ImmutableList unused4 = + collection.processChunk( + objectMapper.readValue(chunk4Json, ChatCompletionsResponse.ChatCompletionChunk.class)); + ImmutableList responses = + collection.processChunk( + objectMapper.readValue(chunk5Json, ChatCompletionsResponse.ChatCompletionChunk.class)); + + LlmResponse expectedFinalResponse = + LlmResponse.builder() + .content( + Content.builder() + .role("") + .parts( + Arrays.asList( + Part.builder() + .functionCall( + FunctionCall.builder() + .id("call_id_1") + .name("roll_die") + .args(ImmutableMap.of("sides", 8)) + .build()) + .build(), + Part.builder() + .functionCall( + FunctionCall.builder() + .id("call_id_2") + .name("roll_die") + .args(ImmutableMap.of("sides", 8)) + .build()) + .build())) + .build()) + .finishReason(new FinishReason(Known.STOP.toString())) + .customMetadata(ImmutableList.of()) + .modelVersion("") + .build(); + + LlmResponse finalResponse = responses.get(1); + + assertThat(finalResponse).isEqualTo(expectedFinalResponse); + } + + @Test + public void testChunkCollection_simpleText() throws Exception { + ChatCompletionsResponse.ChatCompletionChunkCollection collection = + new ChatCompletionsResponse.ChatCompletionChunkCollection(); + + String chunk1Json = + """ + {"choices":[{"delta":{"content":"Hello "}}]} + """; + String chunk2Json = + """ + {"choices":[{"delta":{"content":"World!"}}]} + """; + String chunk3Json = + """ + {"choices":[{"finish_reason":"stop"}]} + """; + + ImmutableList unused1 = + collection.processChunk( + objectMapper.readValue(chunk1Json, ChatCompletionsResponse.ChatCompletionChunk.class)); + ImmutableList unused2 = + collection.processChunk( + objectMapper.readValue(chunk2Json, ChatCompletionsResponse.ChatCompletionChunk.class)); + ImmutableList responses = + collection.processChunk( + objectMapper.readValue(chunk3Json, ChatCompletionsResponse.ChatCompletionChunk.class)); + + LlmResponse expectedFinalResponse = + LlmResponse.builder() + .content( + Content.builder() + .role("") + .parts(ImmutableList.of(Part.fromText("Hello World!"))) + .build()) + .finishReason(new FinishReason(Known.STOP.toString())) + .customMetadata(ImmutableList.of()) + .modelVersion("") + .build(); + + LlmResponse finalResponse = responses.get(1); + + assertThat(finalResponse).isEqualTo(expectedFinalResponse); + } + + @Test + public void testChunkCollection_withRefusal() throws Exception { + ChatCompletionsResponse.ChatCompletionChunkCollection collection = + new ChatCompletionsResponse.ChatCompletionChunkCollection(); + + String chunk1Json = + """ + {"choices":[{"delta":{"refusal":"I cannot do that."}}]} + """; + String chunk2Json = + """ + {"choices":[{"finish_reason":"stop"}]} + """; + + ImmutableList unused1 = + collection.processChunk( + objectMapper.readValue(chunk1Json, ChatCompletionsResponse.ChatCompletionChunk.class)); + ImmutableList responses = + collection.processChunk( + objectMapper.readValue(chunk2Json, ChatCompletionsResponse.ChatCompletionChunk.class)); + + LlmResponse expectedFinalResponse = + LlmResponse.builder() + .content( + Content.builder() + .role("") + .parts(ImmutableList.of(Part.fromText("I cannot do that."))) + .build()) + .finishReason(new FinishReason(Known.STOP.toString())) + .customMetadata(ImmutableList.of()) + .modelVersion("") + .build(); + + LlmResponse finalResponse = responses.get(1); + + assertThat(finalResponse).isEqualTo(expectedFinalResponse); + } + + @Test + public void testChunkCollection_noChoices() throws Exception { + String json = + """ + { + "id": "chatcmpl-123", + "object": "chat.completion", + "created": 1677652288, + "model": "gpt-4" + } + """; + + ChatCompletionsResponse.ChatCompletion completion = + objectMapper.readValue(json, ChatCompletionsResponse.ChatCompletion.class); + + LlmResponse response = completion.toLlmResponse(); + + assertThat(response.modelVersion()).hasValue("gpt-4"); + assertThat(response.content()).isPresent(); + assertThat(response.content().get().parts()).isEmpty(); + } } diff --git a/core/src/test/java/com/google/adk/plugins/agentanalytics/BigQueryAgentAnalyticsPluginE2ETest.java b/core/src/test/java/com/google/adk/plugins/agentanalytics/BigQueryAgentAnalyticsPluginE2ETest.java index 53faf3329..04d98bf0f 100644 --- a/core/src/test/java/com/google/adk/plugins/agentanalytics/BigQueryAgentAnalyticsPluginE2ETest.java +++ b/core/src/test/java/com/google/adk/plugins/agentanalytics/BigQueryAgentAnalyticsPluginE2ETest.java @@ -19,6 +19,7 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; import static org.junit.Assert.assertTrue; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; @@ -63,6 +64,7 @@ public final class BigQueryAgentAnalyticsPluginE2ETest { private StreamWriter mockWriter; private BigQueryWriteClient mockWriteClient; private BigQueryLoggerConfig config; + private PluginState state; private BigQueryAgentAnalyticsPlugin plugin; private Runner runner; private BaseAgent fakeAgent; @@ -92,26 +94,34 @@ public void setUp() throws Exception { when(mockWriter.append(any(ArrowRecordBatch.class))) .thenReturn(ApiFutures.immediateFuture(AppendRowsResponse.getDefaultInstance())); - plugin = - new BigQueryAgentAnalyticsPlugin(config, mockBigQuery) { + state = + new PluginState(config) { @Override protected BigQueryWriteClient createWriteClient(BigQueryLoggerConfig config) { return mockWriteClient; } @Override - protected StreamWriter createWriter(BigQueryLoggerConfig config) { + protected StreamWriter createWriter() { return mockWriter; } + + @Override + protected BatchProcessor removeProcessor(String invocationId) { + return null; + } }; + plugin = new BigQueryAgentAnalyticsPlugin(config, mockBigQuery, state); + when(mockWriter.append(any(ArrowRecordBatch.class))) .thenAnswer( invocation -> { ArrowRecordBatch recordedBatch = invocation.getArgument(0); + BatchProcessor batchProcessor = state.getBatchProcessors().iterator().next(); try (VectorSchemaRoot root = VectorSchemaRoot.create( - BigQuerySchema.getArrowSchema(), plugin.batchProcessor.allocator)) { + BigQuerySchema.getArrowSchema(), batchProcessor.allocator)) { VectorLoader loader = new VectorLoader(root); loader.load(recordedBatch); for (int i = 0; i < root.getRowCount(); i++) { @@ -150,8 +160,9 @@ public void runAgent_logsAgentStartingAndCompleted() throws Exception { // Ensure everything is flushed. The BatchProcessor flushes asynchronously sometimes, // but the direct flush() call should help. We wait up to 2 seconds for all 5 expected events. + BatchProcessor batchProcessor = state.getBatchProcessors().iterator().next(); for (int i = 0; i < 20 && capturedRows.size() < 5; i++) { - plugin.batchProcessor.flush(); + batchProcessor.flush(); if (capturedRows.size() < 5) { Thread.sleep(100); } @@ -190,7 +201,8 @@ public void runAgent_logsAgentStartingAndCompleted() throws Exception { assertEquals("user", agentStartingRow.get("user_id")); assertNotNull("invocation_id should be populated", agentStartingRow.get("invocation_id")); assertTrue("timestamp should be positive", (Long) agentStartingRow.get("timestamp") > 0); - assertEquals(false, agentStartingRow.get("is_truncated")); + // AGENT_STARTING is not a content-bearing event, so is_truncated is not set and should be null. + assertNull(agentStartingRow.get("is_truncated")); // Verify content for USER_MESSAGE_RECEIVED Map userMessageRow = diff --git a/core/src/test/java/com/google/adk/plugins/agentanalytics/BigQueryAgentAnalyticsPluginTest.java b/core/src/test/java/com/google/adk/plugins/agentanalytics/BigQueryAgentAnalyticsPluginTest.java index c7e35e3d6..5a149d3e2 100644 --- a/core/src/test/java/com/google/adk/plugins/agentanalytics/BigQueryAgentAnalyticsPluginTest.java +++ b/core/src/test/java/com/google/adk/plugins/agentanalytics/BigQueryAgentAnalyticsPluginTest.java @@ -62,7 +62,6 @@ import com.google.genai.types.GenerateContentResponse; import com.google.genai.types.GenerateContentResponseUsageMetadata; import com.google.genai.types.Part; -import io.opentelemetry.api.GlobalOpenTelemetry; import io.opentelemetry.api.trace.Span; import io.opentelemetry.api.trace.SpanContext; import io.opentelemetry.api.trace.Tracer; @@ -75,7 +74,12 @@ import java.util.Map; import java.util.Objects; import java.util.Optional; +import java.util.Set; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.function.BiFunction; import java.util.logging.Handler; import java.util.logging.Level; import java.util.logging.LogRecord; @@ -113,6 +117,7 @@ public class BigQueryAgentAnalyticsPluginTest { private BaseAgent fakeAgent; private BigQueryLoggerConfig config; + private PluginState state; private BigQueryAgentAnalyticsPlugin plugin; private Handler mockHandler; private Tracer tracer; @@ -140,24 +145,21 @@ public void setUp() throws Exception { when(mockWriter.append(any(ArrowRecordBatch.class))) .thenReturn(ApiFutures.immediateFuture(AppendRowsResponse.getDefaultInstance())); - plugin = - new BigQueryAgentAnalyticsPlugin(config, mockBigQuery) { + state = + new PluginState(config) { @Override protected BigQueryWriteClient createWriteClient(BigQueryLoggerConfig config) { return mockWriteClient; } @Override - protected StreamWriter createWriter(BigQueryLoggerConfig config) { + protected StreamWriter createWriter() { return mockWriter; } - - @Override - protected TraceManager createTraceManager() { - return new TraceManager(tracer); - } }; + plugin = new BigQueryAgentAnalyticsPlugin(config, mockBigQuery, state); + Session session = Session.builder("session_id").appName("test_app").userId("test_user").build(); when(mockInvocationContext.session()).thenReturn(session); when(mockInvocationContext.invocationId()).thenReturn("invocation_id"); @@ -183,7 +185,7 @@ public void onUserMessageCallback_appendsToWriter() throws Exception { Content content = Content.builder().build(); plugin.onUserMessageCallback(mockInvocationContext, content).blockingSubscribe(); - plugin.batchProcessor.flush(); + state.getBatchProcessor("invocation_id").flush(); verify(mockWriter, atLeastOnce()).append(any(ArrowRecordBatch.class)); } @@ -191,15 +193,15 @@ public void onUserMessageCallback_appendsToWriter() throws Exception { @Test public void beforeRunCallback_appendsToWriter() throws Exception { plugin.beforeRunCallback(mockInvocationContext).blockingSubscribe(); - plugin.batchProcessor.flush(); + state.getBatchProcessor("invocation_id").flush(); verify(mockWriter, atLeastOnce()).append(any(ArrowRecordBatch.class)); } @Test public void afterRunCallback_flushesAndAppends() throws Exception { + plugin.beforeRunCallback(mockInvocationContext).blockingSubscribe(); plugin.afterRunCallback(mockInvocationContext).blockingSubscribe(); - plugin.batchProcessor.flush(); verify(mockWriter, atLeastOnce()).append(any(ArrowRecordBatch.class)); } @@ -213,7 +215,7 @@ public void getStreamName_returnsCorrectFormat() { .tableName("test-table") .build(); - String streamName = plugin.getStreamName(config); + String streamName = state.getStreamName(config); assertEquals( "projects/test-project/datasets/test-dataset/tables/test-table/streams/_default", @@ -253,7 +255,7 @@ public void onUserMessageCallback_handlesTableCreationFailure() throws Exception // Should not throw exception plugin.onUserMessageCallback(mockInvocationContext, content).blockingSubscribe(); - plugin.batchProcessor.flush(); + state.getBatchProcessor("invocation_id").flush(); ArgumentCaptor captor = ArgumentCaptor.forClass(LogRecord.class); verify(mockHandler, atLeastOnce()).publish(captor.capture()); @@ -280,7 +282,7 @@ public void onUserMessageCallback_handlesAppendFailure() throws Exception { plugin.onUserMessageCallback(mockInvocationContext, content).blockingSubscribe(); // Flush should handle the failed future from writer.append() - plugin.batchProcessor.flush(); + state.getBatchProcessor("invocation_id").flush(); verify(mockWriter, atLeastOnce()).append(any(ArrowRecordBatch.class)); ArgumentCaptor captor = ArgumentCaptor.forClass(LogRecord.class); @@ -350,7 +352,8 @@ public void logEvent_populatesCommonFields() throws Exception { ArrowRecordBatch recordedBatch = invocation.getArgument(0); Schema schema = BigQuerySchema.getArrowSchema(); try (VectorSchemaRoot root = - VectorSchemaRoot.create(schema, plugin.batchProcessor.allocator)) { + VectorSchemaRoot.create( + schema, state.getBatchProcessor("invocation_id").allocator)) { VectorLoader loader = new VectorLoader(root); loader.load(recordedBatch); @@ -411,7 +414,7 @@ public void logEvent_populatesCommonFields() throws Exception { Content content = Content.fromParts(Part.fromText("test message")); plugin.onUserMessageCallback(mockInvocationContext, content).blockingSubscribe(); - plugin.batchProcessor.flush(); + state.getBatchProcessor("invocation_id").flush(); assertTrue(failureMessage[0], checksPassed[0]); } @@ -429,12 +432,12 @@ public void logEvent_populatesTraceDetails() throws Exception { Span mockSpan = Span.wrap(mockSpanContext); try (Scope scope = mockSpan.makeCurrent()) { - plugin.traceManager.attachCurrentSpan(); + state.getTraceManager("invocation_id").attachCurrentSpan(); Content content = Content.builder().build(); plugin.onUserMessageCallback(mockInvocationContext, content).blockingSubscribe(); - Map row = plugin.batchProcessor.queue.poll(); + Map row = state.getBatchProcessor("invocation_id").queue.poll(); assertNotNull("Row not found in queue", row); assertEquals(traceId, row.get("trace_id")); assertEquals(spanId, row.get("span_id")); @@ -447,7 +450,7 @@ public void complexType_appendsToWriter() throws Exception { Content content = Content.fromParts(part); plugin.onUserMessageCallback(mockInvocationContext, content).blockingSubscribe(); - plugin.batchProcessor.flush(); + state.getBatchProcessor("invocation_id").flush(); verify(mockWriter, atLeastOnce()).append(any(ArrowRecordBatch.class)); } @@ -462,7 +465,7 @@ public void onEventCallback_populatesCorrectFields() throws Exception { plugin.onEventCallback(mockInvocationContext, event).blockingSubscribe(); - Map row = plugin.batchProcessor.queue.poll(); + Map row = state.getBatchProcessor("invocation_id").queue.poll(); assertNotNull("Row not found in queue", row); assertEquals("STATE_DELTA", row.get("event_type")); assertEquals("agent_name", row.get("agent")); @@ -479,19 +482,24 @@ public void onModelErrorCallback_populatesCorrectFields() throws Exception { LlmRequest.Builder mockLlmRequestBuilder = mock(LlmRequest.Builder.class); Throwable error = new RuntimeException("model error message"); - plugin.traceManager.pushSpan("llm_request"); + state.getTraceManager("invocation_id").pushSpan("llm_request"); plugin .onModelErrorCallback(mockCallbackContext, mockLlmRequestBuilder, error) .blockingSubscribe(); - Map row = plugin.batchProcessor.queue.poll(); + Map row = plugin.getState().getBatchProcessor("invocation_id").queue.poll(); assertNotNull("Row not found in queue", row); assertEquals("LLM_ERROR", row.get("event_type")); assertEquals("agent_name", row.get("agent")); assertEquals("ERROR", row.get("status")); assertEquals("model error message", row.get("error_message")); assertNotNull(row.get("latency_ms")); - assertEquals(false, row.get("is_truncated")); + assertFalse("Row should not contain content when it is null", row.containsKey("content")); + assertFalse( + "Row should not contain content_parts when it is null", row.containsKey("content_parts")); + assertFalse( + "Row should not contain is_truncated when content is null", + row.containsKey("is_truncated")); } @Test @@ -524,13 +532,13 @@ public void afterModelCallback_populatesCorrectFields() throws Exception { tracer.spanBuilder("ambient").setParent(Context.current().with(parentSpan)).startSpan(); // Set valid ambient span context try (Scope scope = ambientSpan.makeCurrent()) { - plugin.traceManager.pushSpan("parent_request"); - plugin.traceManager.pushSpan("llm_request"); + state.getTraceManager("invocation_id").pushSpan("parent_request"); + state.getTraceManager("invocation_id").pushSpan("llm_request"); plugin.afterModelCallback(mockCallbackContext, adkResponse).blockingSubscribe(); } finally { ambientSpan.end(); } - Map row = plugin.batchProcessor.queue.poll(); + Map row = state.getBatchProcessor("invocation_id").queue.poll(); assertNotNull("Row not found in queue", row); assertEquals("LLM_RESPONSE", row.get("event_type")); ObjectNode contentMap = (ObjectNode) row.get("content"); @@ -562,10 +570,10 @@ public void afterToolCallback_populatesCorrectFields() throws Exception { ImmutableMap toolArgs = ImmutableMap.of("arg1", "value1"); ImmutableMap result = ImmutableMap.of("res1", "value2"); - plugin.traceManager.pushSpan("tool_request"); + state.getTraceManager("invocation_id").pushSpan("tool_request"); plugin.afterToolCallback(mockTool, toolArgs, mockToolContext, result).blockingSubscribe(); - Map row = plugin.batchProcessor.queue.poll(); + Map row = state.getBatchProcessor("invocation_id").queue.poll(); assertNotNull("Row not found in queue", row); assertEquals("TOOL_COMPLETED", row.get("event_type")); assertEquals("agent_name", row.get("agent")); @@ -592,12 +600,12 @@ public AgentOrigin toolOrigin() { AgentTool a2aTool = AgentTool.create(a2aAgent); - plugin.traceManager.pushSpan("tool_request"); + state.getTraceManager("invocation_id").pushSpan("tool_request"); plugin .afterToolCallback(a2aTool, ImmutableMap.of(), mockToolContext, ImmutableMap.of()) .blockingSubscribe(); - Map row = plugin.batchProcessor.queue.poll(); + Map row = state.getBatchProcessor("invocation_id").queue.poll(); assertNotNull(row); ObjectNode contentMap = (ObjectNode) row.get("content"); assertEquals("A2A", contentMap.get("tool_origin").asText()); @@ -609,7 +617,7 @@ public void logEvent_includesSessionMetadata_whenEnabled() throws Exception { Content content = Content.fromParts(Part.fromText("test message")); plugin.onUserMessageCallback(mockInvocationContext, content).blockingSubscribe(); - Map row = plugin.batchProcessor.queue.poll(); + Map row = state.getBatchProcessor("invocation_id").queue.poll(); assertNotNull(row); ObjectNode attributes = (ObjectNode) row.get("attributes"); assertTrue("attributes should contain session_metadata", attributes.has("session_metadata")); @@ -622,32 +630,131 @@ public void logEvent_includesSessionMetadata_whenEnabled() throws Exception { @Test public void logEvent_excludesSessionMetadata_whenDisabled() throws Exception { BigQueryLoggerConfig disabledConfig = config.toBuilder().logSessionMetadata(false).build(); + PluginState disabledState = + new PluginState(disabledConfig) { + @Override + protected BigQueryWriteClient createWriteClient(BigQueryLoggerConfig config) { + return mockWriteClient; + } + + @Override + protected StreamWriter createWriter() { + return mockWriter; + } + }; BigQueryAgentAnalyticsPlugin disabledPlugin = - new BigQueryAgentAnalyticsPlugin(disabledConfig, mockBigQuery) { + new BigQueryAgentAnalyticsPlugin(disabledConfig, mockBigQuery, disabledState); + + Content content = Content.fromParts(Part.fromText("test message")); + disabledPlugin.onUserMessageCallback(mockInvocationContext, content).blockingSubscribe(); + + Map row = disabledState.getBatchProcessor("invocation_id").queue.poll(); + assertNotNull(row); + ObjectNode attributes = (ObjectNode) row.get("attributes"); + assertFalse( + "attributes should not contain session_metadata", attributes.has("session_metadata")); + } + + @Test + public void logEvent_usesContentFormatter_whenConfigured() throws Exception { + BiFunction formatter = + (content, eventType) -> { + if (Objects.equals(eventType, "USER_MESSAGE_RECEIVED") && content instanceof Content) { + return "Formatted: " + content; + } + return content; + }; + + BigQueryLoggerConfig formattedConfig = config.toBuilder().contentFormatter(formatter).build(); + PluginState formattedState = + new PluginState(formattedConfig) { @Override protected BigQueryWriteClient createWriteClient(BigQueryLoggerConfig config) { return mockWriteClient; } @Override - protected StreamWriter createWriter(BigQueryLoggerConfig config) { + protected StreamWriter createWriter() { return mockWriter; } + }; + BigQueryAgentAnalyticsPlugin formattedPlugin = + new BigQueryAgentAnalyticsPlugin(formattedConfig, mockBigQuery, formattedState); + Content content = Content.fromParts(Part.fromText("test message")); + formattedPlugin.onUserMessageCallback(mockInvocationContext, content).blockingSubscribe(); + + Map row = formattedState.getBatchProcessor("invocation_id").queue.poll(); + assertNotNull(row); + assertTrue(row.get("content").toString().contains("Formatted: ")); + } + + @Test + public void logEvent_handlesNullContentFromFormatter() throws Exception { + BiFunction formatter = (content, eventType) -> null; + + BigQueryLoggerConfig formattedConfig = config.toBuilder().contentFormatter(formatter).build(); + PluginState formattedState = + new PluginState(formattedConfig) { @Override - protected TraceManager createTraceManager() { - return new TraceManager(GlobalOpenTelemetry.getTracer("test-plugin-disabled")); + protected BigQueryWriteClient createWriteClient(BigQueryLoggerConfig config) { + return mockWriteClient; + } + + @Override + protected StreamWriter createWriter() { + return mockWriter; } }; + BigQueryAgentAnalyticsPlugin formattedPlugin = + new BigQueryAgentAnalyticsPlugin(formattedConfig, mockBigQuery, formattedState); Content content = Content.fromParts(Part.fromText("test message")); - disabledPlugin.onUserMessageCallback(mockInvocationContext, content).blockingSubscribe(); + formattedPlugin.onUserMessageCallback(mockInvocationContext, content).blockingSubscribe(); - Map row = disabledPlugin.batchProcessor.queue.poll(); + Map row = formattedState.getBatchProcessor("invocation_id").queue.poll(); assertNotNull(row); - ObjectNode attributes = (ObjectNode) row.get("attributes"); assertFalse( - "attributes should not contain session_metadata", attributes.has("session_metadata")); + "Row should not contain content when formatter returns null", row.containsKey("content")); + assertFalse( + "Row should not contain content_parts when formatter returns null", + row.containsKey("content_parts")); + } + + @Test + public void logEvent_handlesExceptionFromFormatter() throws Exception { + BiFunction formatter = + (content, eventType) -> { + throw new RuntimeException("Formatter error"); + }; + + BigQueryLoggerConfig formattedConfig = config.toBuilder().contentFormatter(formatter).build(); + PluginState formattedState = + new PluginState(formattedConfig) { + @Override + protected BigQueryWriteClient createWriteClient(BigQueryLoggerConfig config) { + return mockWriteClient; + } + + @Override + protected StreamWriter createWriter() { + return mockWriter; + } + }; + BigQueryAgentAnalyticsPlugin formattedPlugin = + new BigQueryAgentAnalyticsPlugin(formattedConfig, mockBigQuery, formattedState); + + Content content = Content.fromParts(Part.fromText("test message")); + formattedPlugin.onUserMessageCallback(mockInvocationContext, content).blockingSubscribe(); + + Map row = formattedState.getBatchProcessor("invocation_id").queue.poll(); + assertNotNull(row); + assertFalse( + "Row should not contain content when formatter throws exception", + row.containsKey("content")); + assertFalse( + "Row should not contain content_parts when formatter throws exception", + row.containsKey("content_parts")); } @Test @@ -767,6 +874,100 @@ public void createAnalyticsViews_executesQueries() throws Exception { .anyMatch(q -> q.contains("CREATE OR REPLACE VIEW `project.dataset.v_llm_response`"))); } + @Test + public void multipleInvocations_logsCorrectly() throws Exception { + BigQueryLoggerConfig testConfig = config.toBuilder().batchSize(10).build(); + PluginState testState = + new PluginState(testConfig) { + @Override + protected BigQueryWriteClient createWriteClient(BigQueryLoggerConfig config) { + return mockWriteClient; + } + + @Override + protected StreamWriter createWriter() { + return mockWriter; + } + }; + BigQueryAgentAnalyticsPlugin testPlugin = + new BigQueryAgentAnalyticsPlugin(testConfig, mockBigQuery, testState); + + InvocationContext context1 = mock(InvocationContext.class); + when(context1.invocationId()).thenReturn("inv-1"); + when(context1.agent()).thenReturn(fakeAgent); + when(context1.session()).thenReturn(Session.builder("s1").build()); + + InvocationContext context2 = mock(InvocationContext.class); + when(context2.invocationId()).thenReturn("inv-2"); + when(context2.agent()).thenReturn(fakeAgent); + when(context2.session()).thenReturn(Session.builder("s2").build()); + + var unused1 = testPlugin.beforeRunCallback(context1).blockingGet(); + var unused2 = + testPlugin + .onUserMessageCallback(context1, Content.fromParts(Part.fromText("msg1"))) + .blockingGet(); + + var unused3 = testPlugin.beforeRunCallback(context2).blockingGet(); + var unused4 = + testPlugin + .onUserMessageCallback(context2, Content.fromParts(Part.fromText("msg2"))) + .blockingGet(); + + // Verify processors are created and have correct data in their queues + BatchProcessor p1 = testState.getBatchProcessor("inv-1"); + BatchProcessor p2 = testState.getBatchProcessor("inv-2"); + + assertNotNull("Processor for inv-1 should exist", p1); + assertNotNull("Processor for inv-2 should exist", p2); + assertFalse("Queue for inv-1 should not be empty", p1.queue.isEmpty()); + assertFalse("Queue for inv-2 should not be empty", p2.queue.isEmpty()); + + assertTrue( + "All logs for inv-1 should have correct invocation_id", + p1.queue.stream().allMatch(row -> row.get("invocation_id").equals("inv-1"))); + assertTrue( + "All logs for inv-2 should have correct invocation_id", + p2.queue.stream().allMatch(row -> row.get("invocation_id").equals("inv-2"))); + + // Now flush and verify writer was called + testPlugin.afterRunCallback(context1).blockingAwait(); + testPlugin.afterRunCallback(context2).blockingAwait(); + + verify(mockWriter, atLeastOnce()).append(any(ArrowRecordBatch.class)); + } + + @Test + public void logEvent_createsUniqueProcessorPerInvocation() throws Exception { + int numInvocations = 5; + ExecutorService testExecutor = Executors.newFixedThreadPool(numInvocations); + Set processors = ConcurrentHashMap.newKeySet(); + CountDownLatch latch = new CountDownLatch(numInvocations); + + for (int i = 0; i < numInvocations; i++) { + final String invocationId = "inv-" + i; + testExecutor.execute( + () -> { + try { + InvocationContext context = mock(InvocationContext.class); + when(context.invocationId()).thenReturn(invocationId); + when(context.agent()).thenReturn(fakeAgent); + Session session = Session.builder("s").build(); + when(context.session()).thenReturn(session); + + plugin.beforeRunCallback(context).blockingSubscribe(); + processors.add(state.getBatchProcessor(invocationId)); + } finally { + latch.countDown(); + } + }); + } + + latch.await(); + assertEquals(numInvocations, processors.size()); + testExecutor.shutdown(); + } + private static class FakeAgent extends BaseAgent { FakeAgent(String name) { super(name, "description", null, null, null); diff --git a/core/src/test/java/com/google/adk/tools/AgentToolTest.java b/core/src/test/java/com/google/adk/tools/AgentToolTest.java index 0f168c5df..b37db6611 100644 --- a/core/src/test/java/com/google/adk/tools/AgentToolTest.java +++ b/core/src/test/java/com/google/adk/tools/AgentToolTest.java @@ -28,6 +28,8 @@ import com.google.adk.agents.LlmAgent; import com.google.adk.agents.SequentialAgent; import com.google.adk.models.LlmResponse; +import com.google.adk.plugins.Plugin; +import com.google.adk.plugins.PluginManager; import com.google.adk.sessions.InMemorySessionService; import com.google.adk.sessions.Session; import com.google.adk.testing.TestLlm; @@ -41,6 +43,7 @@ import io.reactivex.rxjava3.core.Flowable; import io.reactivex.rxjava3.core.Maybe; import java.util.Map; +import java.util.concurrent.atomic.AtomicBoolean; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @@ -495,6 +498,46 @@ public void call_withSkipSummarizationAndStateDelta_propagatesStateAndSetsSkipSu assertThat(toolContext.actions().skipSummarization()).hasValue(true); } + @Test + public void call_withMultipleStateDeltasInResponse_propagatesAllStateDeltas() throws Exception { + AfterAgentCallback firstCallback = + (callbackContext) -> { + callbackContext.state().put("key1", "val1"); + return Maybe.empty(); + }; + AfterAgentCallback secondCallback = + (callbackContext) -> { + callbackContext.state().put("key2", "val2"); + return Maybe.empty(); + }; + LlmAgent firstAgent = + createTestAgentBuilder(createTestLlm(LlmResponse.builder().build())) + .name("first_agent") + .afterAgentCallback(firstCallback) + .build(); + LlmAgent secondAgent = + createTestAgentBuilder(createTestLlm(LlmResponse.builder().build())) + .name("second_agent") + .afterAgentCallback(secondCallback) + .build(); + SequentialAgent sequentialAgent = + SequentialAgent.builder() + .name("sequence") + .description("Process the query through multiple steps") + .subAgents(ImmutableList.of(firstAgent, secondAgent)) + .build(); + ToolContext toolContext = createToolContext(sequentialAgent); + assertThat(toolContext.state()).isEmpty(); + + Map unused = + AgentTool.create(sequentialAgent) + .runAsync(ImmutableMap.of("request", "test"), toolContext) + .blockingGet(); + + assertThat(toolContext.state()).containsEntry("key1", "val1"); + assertThat(toolContext.state()).containsEntry("key2", "val2"); + } + @Test public void declaration_sequentialAgentWithFirstSubAgentInputSchema_returnsDeclarationWithSchema() { @@ -664,6 +707,169 @@ public void declaration_emptySequentialAgent_fallsBackToRequest() { .build()); } + @Test + public void call_withIncludePluginsTrue_propagatesPlugins() throws Exception { + AtomicBoolean callbackCalled = new AtomicBoolean(false); + Plugin mockPlugin = + new Plugin() { + @Override + public String getName() { + return "mock_plugin"; + } + + @Override + public Maybe beforeRunCallback(InvocationContext invocationContext) { + callbackCalled.set(true); + return Maybe.empty(); + } + }; + LlmAgent testAgent = + createTestAgentBuilder(createTestLlm(LlmResponse.builder().build())) + .name("agent_name") + .description("agent description") + .build(); + AgentTool agentTool = + AgentTool.create(testAgent, /* skipSummarization= */ false, /* includePlugins= */ true); + Session session = + sessionService.createSession("test-app", "test-user", null, "test-session").blockingGet(); + InvocationContext invocationContext = + InvocationContext.builder() + .invocationId(InvocationContext.newInvocationContextId()) + .agent(testAgent) + .session(session) + .sessionService(sessionService) + .pluginManager(new PluginManager(ImmutableList.of(mockPlugin))) + .build(); + ToolContext toolContext = ToolContext.builder(invocationContext).build(); + + Map unused = + agentTool.runAsync(ImmutableMap.of("request", "magic"), toolContext).blockingGet(); + + assertThat(callbackCalled.get()).isTrue(); + } + + @Test + public void call_withIncludePluginsFalse_doesNotPropagatePlugins() throws Exception { + AtomicBoolean callbackCalled = new AtomicBoolean(false); + Plugin mockPlugin = + new Plugin() { + @Override + public String getName() { + return "mock_plugin"; + } + + @Override + public Maybe beforeRunCallback(InvocationContext invocationContext) { + callbackCalled.set(true); + return Maybe.empty(); + } + }; + LlmAgent testAgent = + createTestAgentBuilder(createTestLlm(LlmResponse.builder().build())) + .name("agent_name") + .description("agent description") + .build(); + AgentTool agentTool = + AgentTool.create(testAgent, /* skipSummarization= */ false, /* includePlugins= */ false); + Session session = + sessionService.createSession("test-app", "test-user", null, "test-session").blockingGet(); + InvocationContext invocationContext = + InvocationContext.builder() + .invocationId(InvocationContext.newInvocationContextId()) + .agent(testAgent) + .session(session) + .sessionService(sessionService) + .pluginManager(new PluginManager(ImmutableList.of(mockPlugin))) + .build(); + ToolContext toolContext = ToolContext.builder(invocationContext).build(); + + Map unused = + agentTool.runAsync(ImmutableMap.of("request", "magic"), toolContext).blockingGet(); + + assertThat(callbackCalled.get()).isFalse(); + } + + @Test + public void call_createWithAgentOnly_defaultsIncludePluginsToFalse() throws Exception { + AtomicBoolean callbackCalled = new AtomicBoolean(false); + Plugin mockPlugin = + new Plugin() { + @Override + public String getName() { + return "mock_plugin"; + } + + @Override + public Maybe beforeRunCallback(InvocationContext invocationContext) { + callbackCalled.set(true); + return Maybe.empty(); + } + }; + LlmAgent testAgent = + createTestAgentBuilder(createTestLlm(LlmResponse.builder().build())) + .name("agent_name") + .description("agent description") + .build(); + AgentTool agentTool = AgentTool.create(testAgent); + Session session = + sessionService.createSession("test-app", "test-user", null, "test-session").blockingGet(); + InvocationContext invocationContext = + InvocationContext.builder() + .invocationId(InvocationContext.newInvocationContextId()) + .agent(testAgent) + .session(session) + .sessionService(sessionService) + .pluginManager(new PluginManager(ImmutableList.of(mockPlugin))) + .build(); + ToolContext toolContext = ToolContext.builder(invocationContext).build(); + + Map unused = + agentTool.runAsync(ImmutableMap.of("request", "magic"), toolContext).blockingGet(); + + assertThat(callbackCalled.get()).isFalse(); + } + + @Test + public void call_createWithAgentAndSkipSummarization_defaultsIncludePluginsToFalse() + throws Exception { + AtomicBoolean callbackCalled = new AtomicBoolean(false); + Plugin mockPlugin = + new Plugin() { + @Override + public String getName() { + return "mock_plugin"; + } + + @Override + public Maybe beforeRunCallback(InvocationContext invocationContext) { + callbackCalled.set(true); + return Maybe.empty(); + } + }; + LlmAgent testAgent = + createTestAgentBuilder(createTestLlm(LlmResponse.builder().build())) + .name("agent_name") + .description("agent description") + .build(); + AgentTool agentTool = AgentTool.create(testAgent, /* skipSummarization= */ true); + Session session = + sessionService.createSession("test-app", "test-user", null, "test-session").blockingGet(); + InvocationContext invocationContext = + InvocationContext.builder() + .invocationId(InvocationContext.newInvocationContextId()) + .agent(testAgent) + .session(session) + .sessionService(sessionService) + .pluginManager(new PluginManager(ImmutableList.of(mockPlugin))) + .build(); + ToolContext toolContext = ToolContext.builder(invocationContext).build(); + + Map unused = + agentTool.runAsync(ImmutableMap.of("request", "magic"), toolContext).blockingGet(); + + assertThat(callbackCalled.get()).isFalse(); + } + private ToolContext createToolContext(BaseAgent agent) { Session session = sessionService.createSession("test-app", "test-user", null, "test-session").blockingGet(); diff --git a/core/src/test/java/com/google/adk/utils/ModelNameUtilsTest.java b/core/src/test/java/com/google/adk/utils/ModelNameUtilsTest.java index 20dda7034..86bf126f6 100644 --- a/core/src/test/java/com/google/adk/utils/ModelNameUtilsTest.java +++ b/core/src/test/java/com/google/adk/utils/ModelNameUtilsTest.java @@ -71,6 +71,68 @@ public void isGemini2Model_withNullModel_returnsFalse() { assertThat(ModelNameUtils.isGemini2Model(null)).isFalse(); } + @Test + public void isGemini2OrAbove_withGemini3Model_returnsTrue() { + assertThat(ModelNameUtils.isGemini2OrAbove("gemini-3.0-pro")).isTrue(); + } + + @Test + public void isGemini2OrAbove_withGemini2Model_returnsTrue() { + assertThat(ModelNameUtils.isGemini2OrAbove("gemini-2.0-pro")).isTrue(); + } + + @Test + public void isGemini2OrAbove_withGemini25Model_returnsTrue() { + assertThat(ModelNameUtils.isGemini2OrAbove("gemini-2.5-flash")).isTrue(); + } + + @Test + public void isGemini2OrAbove_withGemini1Model_returnsFalse() { + assertThat(ModelNameUtils.isGemini2OrAbove("gemini-1.5-pro")).isFalse(); + } + + @Test + public void isGemini2OrAbove_withInvalid_returnsFalse() { + assertThat(ModelNameUtils.isGemini2OrAbove("???")).isFalse(); + } + + @Test + public void isGemini2OrAbove_withInvalidGemini1Version_returnsFalse() { + assertThat(ModelNameUtils.isGemini2OrAbove("gemini-01")).isFalse(); + } + + @Test + public void isGemini2OrAbove_withPathBasedGemini3Model_returnsTrue() { + assertThat( + ModelNameUtils.isGemini2OrAbove( + "projects/test-project/locations/us-central1/publishers/google/models/gemini-3.0-flash")) + .isTrue(); + } + + @Test + public void isGemini2OrAbove_withPathBasedGemini1Model_returnsFalse() { + assertThat( + ModelNameUtils.isGemini2OrAbove( + "projects/test-project/locations/us-central1/publishers/google/models/gemini-1.5-pro")) + .isFalse(); + } + + @Test + public void isGemini2OrAbove_withApigeeGemini3Model_returnsTrue() { + assertThat(ModelNameUtils.isGemini2OrAbove("apigee/gemini-3.0-flash")).isTrue(); + } + + @Test + public void isGemini2OrAbove_withApigeeProviderV1BetaGemini3Model_returnsTrue() { + assertThat(ModelNameUtils.isGemini2OrAbove("apigee/vertex_ai/v1beta/gemini-3.0-flash")) + .isTrue(); + } + + @Test + public void isGemini2OrAbove_withNullModel_returnsFalse() { + assertThat(ModelNameUtils.isGemini2OrAbove(null)).isFalse(); + } + @Test public void isGeminiModel_withGeminiModel_returnsTrue() { assertThat(ModelNameUtils.isGeminiModel("gemini-1.5-flash")).isTrue(); diff --git a/dev/pom.xml b/dev/pom.xml index c094c2561..bf89e7ca6 100644 --- a/dev/pom.xml +++ b/dev/pom.xml @@ -18,7 +18,7 @@ com.google.adk google-adk-parent - 1.1.0 + 1.2.1-SNAPSHOT google-adk-dev diff --git a/maven_plugin/examples/custom_tools/pom.xml b/maven_plugin/examples/custom_tools/pom.xml index 68978c6be..38bc9b561 100644 --- a/maven_plugin/examples/custom_tools/pom.xml +++ b/maven_plugin/examples/custom_tools/pom.xml @@ -4,7 +4,7 @@ com.example custom-tools-example - 1.1.0 + 1.2.1-SNAPSHOT jar ADK Custom Tools Example diff --git a/maven_plugin/examples/simple-agent/pom.xml b/maven_plugin/examples/simple-agent/pom.xml index f3c7bfd97..c713f525d 100644 --- a/maven_plugin/examples/simple-agent/pom.xml +++ b/maven_plugin/examples/simple-agent/pom.xml @@ -4,7 +4,7 @@ com.example simple-adk-agent - 1.1.0 + 1.2.1-SNAPSHOT jar Simple ADK Agent Example diff --git a/maven_plugin/pom.xml b/maven_plugin/pom.xml index 071959597..f87df835d 100644 --- a/maven_plugin/pom.xml +++ b/maven_plugin/pom.xml @@ -5,7 +5,7 @@ com.google.adk google-adk-parent - 1.1.0 + 1.2.1-SNAPSHOT ../pom.xml diff --git a/pom.xml b/pom.xml index ec2a00ae6..6f6837df5 100644 --- a/pom.xml +++ b/pom.xml @@ -17,7 +17,7 @@ com.google.adk google-adk-parent - 1.1.0 + 1.2.1-SNAPSHOT pom Google Agent Development Kit Maven Parent POM diff --git a/tutorials/city-time-weather/pom.xml b/tutorials/city-time-weather/pom.xml index 7b668bb71..f63dc96a8 100644 --- a/tutorials/city-time-weather/pom.xml +++ b/tutorials/city-time-weather/pom.xml @@ -20,7 +20,7 @@ com.google.adk google-adk-parent - 1.1.0 + 1.2.1-SNAPSHOT ../../pom.xml diff --git a/tutorials/live-audio-single-agent/pom.xml b/tutorials/live-audio-single-agent/pom.xml index c7c46bbd9..3c4475b6a 100644 --- a/tutorials/live-audio-single-agent/pom.xml +++ b/tutorials/live-audio-single-agent/pom.xml @@ -20,7 +20,7 @@ com.google.adk google-adk-parent - 1.1.0 + 1.2.1-SNAPSHOT ../../pom.xml