From 87524252efd363bf82fdae2f7e6880f8c48e1ada Mon Sep 17 00:00:00 2001 From: adilburaksen Date: Mon, 1 Jun 2026 14:37:08 +0300 Subject: [PATCH] fix(skills): prevent path traversal in LocalSkillSource Add input validation to LocalSkillSource to ensure skill names and resource paths cannot escape the skills base directory via path traversal sequences (e.g. "../../../etc/passwd") or absolute paths (e.g. "/etc/passwd"). The new validatePathWithinBase() helper normalizes and resolves each caller-supplied path component against its parent directory, then checks that the result still starts with that parent. This mirrors the boundary check already present in the Go implementation (filesystem_source.go). Affected methods: findResourcePath, listResources, findSkillMdPath. Corresponding tests added for all traversal and absolute-path cases. --- .../java/com/google/adk/agents/BaseAgent.java | 27 +- .../java/com/google/adk/agents/LoopAgent.java | 24 -- .../com/google/adk/agents/ParallelAgent.java | 32 -- .../google/adk/agents/SequentialAgent.java | 27 +- .../adk/flows/llmflows/BaseLlmFlow.java | 90 ++--- .../google/adk/flows/llmflows/Functions.java | 36 +- .../BigQueryAgentAnalyticsPlugin.java | 5 +- .../plugins/agentanalytics/GcsOffloader.java | 94 ----- .../plugins/agentanalytics/JsonFormatter.java | 86 ++--- .../agentanalytics/MimeTypeMapper.java | 60 ---- .../adk/plugins/agentanalytics/Parser.java | 142 +------- .../plugins/agentanalytics/PluginState.java | 71 +--- .../java/com/google/adk/runner/Runner.java | 57 +-- .../adk/skills/AbstractSkillSource.java | 24 +- .../adk/skills/InMemorySkillSource.java | 15 +- .../google/adk/skills/LocalSkillSource.java | 63 ++-- .../adk/skills/SkillSourceException.java | 36 +- .../google/adk/telemetry/Instrumentation.java | 31 +- .../com/google/adk/tools/BaseToolset.java | 10 +- .../tools/mcp/DefaultMcpTransportBuilder.java | 62 +--- .../adk/tools/skills/ListSkillsTool.java | 59 ---- .../tools/skills/LoadSkillResourceTool.java | 239 ------------- .../adk/tools/skills/LoadSkillTool.java | 97 ----- .../google/adk/tools/skills/SkillToolset.java | 130 ------- .../com/google/adk/agents/BaseAgentTest.java | 84 +---- .../com/google/adk/agents/LlmAgentTest.java | 4 +- .../adk/flows/llmflows/BaseLlmFlowTest.java | 113 ------ .../BigQueryAgentAnalyticsPluginTest.java | 133 ------- .../agentanalytics/GcsOffloaderTest.java | 122 ------- .../agentanalytics/JsonFormatterTest.java | 90 +---- .../agentanalytics/MimeTypeMapperTest.java | 52 --- .../plugins/agentanalytics/ParserTest.java | 8 +- .../agentanalytics/PluginStateTest.java | 64 +--- .../com/google/adk/runner/RunnerTest.java | 142 +------- .../adk/skills/LocalSkillSourceTest.java | 85 +---- .../adk/telemetry/ContextPropagationTest.java | 19 +- .../com/google/adk/testing/TestCallback.java | 4 +- .../java/com/google/adk/testing/TestLlm.java | 2 +- .../mcp/DefaultMcpTransportBuilderTest.java | 274 --------------- .../adk/tools/skills/ListSkillsToolTest.java | 151 -------- .../skills/LoadSkillResourceToolTest.java | 330 ------------------ .../adk/tools/skills/LoadSkillToolTest.java | 162 --------- .../adk/tools/skills/SkillToolsetTest.java | 143 -------- 43 files changed, 228 insertions(+), 3271 deletions(-) delete mode 100644 core/src/main/java/com/google/adk/plugins/agentanalytics/GcsOffloader.java delete mode 100644 core/src/main/java/com/google/adk/plugins/agentanalytics/MimeTypeMapper.java delete mode 100644 core/src/main/java/com/google/adk/tools/skills/ListSkillsTool.java delete mode 100644 core/src/main/java/com/google/adk/tools/skills/LoadSkillResourceTool.java delete mode 100644 core/src/main/java/com/google/adk/tools/skills/LoadSkillTool.java delete mode 100644 core/src/main/java/com/google/adk/tools/skills/SkillToolset.java delete mode 100644 core/src/test/java/com/google/adk/plugins/agentanalytics/GcsOffloaderTest.java delete mode 100644 core/src/test/java/com/google/adk/plugins/agentanalytics/MimeTypeMapperTest.java delete mode 100644 core/src/test/java/com/google/adk/tools/mcp/DefaultMcpTransportBuilderTest.java delete mode 100644 core/src/test/java/com/google/adk/tools/skills/ListSkillsToolTest.java delete mode 100644 core/src/test/java/com/google/adk/tools/skills/LoadSkillResourceToolTest.java delete mode 100644 core/src/test/java/com/google/adk/tools/skills/LoadSkillToolTest.java delete mode 100644 core/src/test/java/com/google/adk/tools/skills/SkillToolsetTest.java diff --git a/core/src/main/java/com/google/adk/agents/BaseAgent.java b/core/src/main/java/com/google/adk/agents/BaseAgent.java index 5b154862e..cbceceed2 100644 --- a/core/src/main/java/com/google/adk/agents/BaseAgent.java +++ b/core/src/main/java/com/google/adk/agents/BaseAgent.java @@ -24,8 +24,7 @@ import com.google.adk.agents.Callbacks.BeforeAgentCallback; import com.google.adk.events.Event; import com.google.adk.plugins.Plugin; -import com.google.adk.telemetry.Instrumentation; -import com.google.adk.telemetry.Instrumentation.AgentInvocation; +import com.google.adk.telemetry.Tracing; import com.google.adk.utils.AgentEnums.AgentOrigin; import com.google.common.collect.ImmutableList; import com.google.errorprone.annotations.CanIgnoreReturnValue; @@ -323,13 +322,11 @@ public Flowable runAsync(InvocationContext parentContext) { private Flowable run( InvocationContext parentContext, Function> runImplementation) { - Context otelContext = Context.current(); - return Flowable.using( - () -> - Instrumentation.recordAgentInvocation( - createInvocationContext(parentContext), this, otelContext), - agentInvocation -> { - InvocationContext invocationContext = agentInvocation.getCtx(); + Context parentSpanContext = Context.current(); + return Flowable.defer( + () -> { + InvocationContext invocationContext = createInvocationContext(parentContext); + Flowable mainAndAfterEvents = Flowable.defer(() -> runImplementation.apply(invocationContext)) .concatWith( @@ -353,10 +350,14 @@ private Flowable run( return Flowable.just(beforeEvent).concatWith(mainAndAfterEvents); }) .switchIfEmpty(mainAndAfterEvents) - .doOnNext(agentInvocation::addEvent) - .doOnError(agentInvocation::setError); - }, - AgentInvocation::close); + .compose( + Tracing.trace("invoke_agent " + name()) + .setParent(parentSpanContext) + .configure( + span -> + Tracing.traceAgentInvocation( + span, name(), description(), invocationContext))); + }); } /** diff --git a/core/src/main/java/com/google/adk/agents/LoopAgent.java b/core/src/main/java/com/google/adk/agents/LoopAgent.java index c12387231..743d569b9 100644 --- a/core/src/main/java/com/google/adk/agents/LoopAgent.java +++ b/core/src/main/java/com/google/adk/agents/LoopAgent.java @@ -30,30 +30,6 @@ * *

The loop continues until a sub-agent escalates, or until the maximum number of iterations is * reached (if specified). - * - *

Composition with {@link LlmAgent}s: a {@code LoopAgent} does not transfer control back - * to a parent {@link LlmAgent}. To react to loop results, place the {@code LoopAgent} and the - * follow-up {@link LlmAgent} as siblings inside a {@link SequentialAgent}. Loop sub-agents publish - * via {@code outputKey} and the follow-up reads via {@code {key}} placeholders in its instruction: - * - *

{@code
- * var refiner =
- *     LlmAgent.builder()
- *         .name("refiner")
- *         .model("gemini-flash-latest")
- *         .instruction("Refine: {draft?}")
- *         .outputKey("draft")
- *         .build();
- * var publisher =
- *     LlmAgent.builder()
- *         .name("publisher")
- *         .model("gemini-flash-latest")
- *         .instruction("Publish: {draft}")
- *         .build();
- * var loop =
- *     LoopAgent.builder().name("loop").subAgents(refiner).maxIterations(3).build();
- * var root = SequentialAgent.builder().name("root").subAgents(loop, publisher).build();
- * }
*/ public class LoopAgent extends BaseAgent { private static final Logger logger = LoggerFactory.getLogger(LoopAgent.class); diff --git a/core/src/main/java/com/google/adk/agents/ParallelAgent.java b/core/src/main/java/com/google/adk/agents/ParallelAgent.java index e1382a317..1e98dbf50 100644 --- a/core/src/main/java/com/google/adk/agents/ParallelAgent.java +++ b/core/src/main/java/com/google/adk/agents/ParallelAgent.java @@ -34,38 +34,6 @@ *

This approach is beneficial for scenarios requiring multiple perspectives or attempts on a * single task, such as running different algorithms simultaneously or generating multiple responses * for review by a subsequent evaluation agent. - * - *

Composition with {@link LlmAgent}s: a {@code ParallelAgent} does not transfer control - * back to a parent {@link LlmAgent}. To follow a fan-out with an aggregation step, wrap both in a - * {@link SequentialAgent} (used as the root or transferred-to agent). Each parallel sub-agent - * publishes via {@code outputKey} and the aggregator reads via {@code {key}} placeholders in its - * instruction: - * - *

{@code
- * var contacts =
- *     LlmAgent.builder()
- *         .name("contacts")
- *         .model("gemini-flash-latest")
- *         .instruction("List contacts.")
- *         .outputKey("contacts")
- *         .build();
- * var schedule =
- *     LlmAgent.builder()
- *         .name("schedule")
- *         .model("gemini-flash-latest")
- *         .instruction("List schedule.")
- *         .outputKey("schedule")
- *         .build();
- * var writer =
- *     LlmAgent.builder()
- *         .name("writer")
- *         .model("gemini-flash-latest")
- *         .instruction("Write: contacts={contacts}, schedule={schedule}")
- *         .build();
- * var gather =
- *     ParallelAgent.builder().name("gather").subAgents(contacts, schedule).build();
- * var root = SequentialAgent.builder().name("root").subAgents(gather, writer).build();
- * }
*/ public class ParallelAgent extends BaseAgent { diff --git a/core/src/main/java/com/google/adk/agents/SequentialAgent.java b/core/src/main/java/com/google/adk/agents/SequentialAgent.java index 95ca50d16..b0b45a0ec 100644 --- a/core/src/main/java/com/google/adk/agents/SequentialAgent.java +++ b/core/src/main/java/com/google/adk/agents/SequentialAgent.java @@ -22,32 +22,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -/** - * An agent that runs its sub-agents sequentially. - * - *

Composition with {@link LlmAgent}s: a {@code SequentialAgent} does not transfer control - * back to a parent {@link LlmAgent}. Use it as the root or transferred-to agent and place any - * follow-up {@link LlmAgent} as the next sibling. Upstream publishes via {@code outputKey} and - * downstream reads via {@code {key}} placeholders in its instruction: - * - *

{@code
- * var draft =
- *     LlmAgent.builder()
- *         .name("draft")
- *         .model("gemini-flash-latest")
- *         .instruction("Draft a summary.")
- *         .outputKey("draft")
- *         .build();
- * var reviewer =
- *     LlmAgent.builder()
- *         .name("reviewer")
- *         .model("gemini-flash-latest")
- *         .instruction("Polish the draft: {draft}")
- *         .build();
- * var pipeline =
- *     SequentialAgent.builder().name("pipeline").subAgents(draft, reviewer).build();
- * }
- */ +/** An agent that runs its sub-agents sequentially. */ public class SequentialAgent extends BaseAgent { private static final Logger logger = LoggerFactory.getLogger(SequentialAgent.class); diff --git a/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java b/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java index dffba0e80..fffeab698 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java @@ -38,10 +38,7 @@ import com.google.adk.models.LlmRequest; import com.google.adk.models.LlmResponse; import com.google.adk.telemetry.Tracing; -import com.google.adk.tools.BaseTool; -import com.google.adk.tools.BaseToolset; import com.google.adk.tools.ToolContext; -import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableList; import com.google.common.collect.Iterables; import com.google.genai.types.FunctionResponse; @@ -61,7 +58,6 @@ import java.util.Optional; import java.util.Set; import java.util.concurrent.atomic.AtomicReference; -import java.util.function.BiFunction; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -100,8 +96,20 @@ private Flowable preprocess( Context currentContext = Context.current(); LlmAgent agent = (LlmAgent) context.agent(); + RequestProcessor toolsProcessor = + (ctx, req) -> { + LlmRequest.Builder builder = req.toBuilder(); + return agent + .canonicalTools(new ReadonlyContext(ctx)) + .concatMapCompletable( + tool -> tool.processLlmRequest(builder, ToolContext.builder(ctx).build())) + .andThen( + Single.fromCallable( + () -> RequestProcessingResult.create(builder.build(), ImmutableList.of()))); + }; + Iterable allProcessors = - Iterables.concat(requestProcessors, ImmutableList.of(getRequestProcessorFromTools(agent))); + Iterables.concat(requestProcessors, ImmutableList.of(toolsProcessor)); return Flowable.fromIterable(allProcessors) .concatMap( @@ -113,58 +121,6 @@ private Flowable preprocess( result -> result.events() != null ? result.events() : ImmutableList.of())); } - /** - * Constructs a {@link RequestProcessor} that sequentially applies the {@code processLlmRequest} - * methods of all tools and toolsets associated with this agent to the incoming {@link - * LlmRequest}. - * - * @return A {@link RequestProcessor} that applies tool-specific modifications to LLM requests. - */ - @VisibleForTesting - RequestProcessor getRequestProcessorFromTools(LlmAgent agent) { - return (context, request) -> { - ReadonlyContext readonlyContext = new ReadonlyContext(context); - List> processors = new ArrayList<>(); - - for (Object toolOrToolset : agent.toolsUnion()) { - if (toolOrToolset instanceof BaseTool baseTool) { - processors.add( - (builder, ctx) -> { - Completable c = baseTool.processLlmRequest(builder, ctx); - return c == null ? Completable.complete() : c; - }); - } else if (toolOrToolset instanceof BaseToolset baseToolset) { - // First apply the toolset's own request processor, then unwrap all tools from the toolset - // and apply each individual tool's request processor sequentially. - processors.add( - (builder, ctx) -> { - Completable c = baseToolset.processLlmRequest(builder, ctx); - Completable toolsetProcessor = c == null ? Completable.complete() : c; - return toolsetProcessor - .andThen(baseToolset.getTools(readonlyContext)) - .concatMapCompletable( - b -> { - Completable tc = b.processLlmRequest(builder, ctx); - return tc == null ? Completable.complete() : tc; - }); - }); - } else { - throw new IllegalArgumentException( - "Object in tools list is not of a supported type: " - + toolOrToolset.getClass().getName()); - } - } - - LlmRequest.Builder builder = request.toBuilder(); - ToolContext toolContext = ToolContext.builder(context).build(); - return Flowable.fromIterable(processors) - .concatMapCompletable(f -> f.apply(builder, toolContext)) - .andThen( - Single.fromCallable( - () -> RequestProcessingResult.create(builder.build(), ImmutableList.of()))); - }; - } - /** * Post-processes the LLM response after receiving it from the LLM. Executes all registered {@link * ResponseProcessor} instances. Emits events for the model response and any subsequent function @@ -479,10 +435,12 @@ private Flowable runOneStep(Context spanContext, InvocationContext contex "Agent not found: " + agentToTransfer))); } return postProcessedEvents.concatWith( - nextAgent - .get() - .runAsync(context) - .compose(Tracing.withContext(spanContext))); + Flowable.defer( + () -> { + try (Scope s = spanContext.makeCurrent()) { + return nextAgent.get().runAsync(context); + } + })); } return postProcessedEvents; }); @@ -664,10 +622,12 @@ public void onError(Throwable e) { "Agent not found: " + event.actions().transferToAgent().get()); } Flowable nextAgentEvents = - nextAgent - .get() - .runLive(invocationContext) - .compose(Tracing.withContext(spanContext)); + Flowable.defer( + () -> { + try (Scope s = spanContext.makeCurrent()) { + return nextAgent.get().runLive(invocationContext); + } + }); events = Flowable.concat(events, nextAgentEvents); } return events; diff --git a/core/src/main/java/com/google/adk/flows/llmflows/Functions.java b/core/src/main/java/com/google/adk/flows/llmflows/Functions.java index 8c60ebf76..4aa20798d 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/Functions.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/Functions.java @@ -29,8 +29,6 @@ import com.google.adk.events.Event; import com.google.adk.events.EventActions; import com.google.adk.events.ToolConfirmation; -import com.google.adk.telemetry.Instrumentation; -import com.google.adk.telemetry.Instrumentation.ToolExecution; import com.google.adk.telemetry.Tracing; import com.google.adk.tools.BaseTool; import com.google.adk.tools.FunctionTool; @@ -432,25 +430,6 @@ private static Maybe postProcessFunctionResult( ToolContext toolContext, boolean isLive, Context parentContext) { - return Maybe.using( - () -> - Instrumentation.recordToolExecution( - tool, invocationContext.agent(), functionArgs, parentContext), - toolExecution -> - processFunctionResult( - maybeFunctionResult, invocationContext, tool, functionArgs, toolContext, isLive) - .doOnSuccess(event -> toolExecution.context().setFunctionResponseEvent(event)) - .doOnError(toolExecution::setError), - ToolExecution::close); - } - - private static Maybe processFunctionResult( - Maybe> maybeFunctionResult, - InvocationContext invocationContext, - BaseTool tool, - Map functionArgs, - ToolContext toolContext, - boolean isLive) { return maybeFunctionResult .map(Optional::of) .defaultIfEmpty(Optional.empty()) @@ -488,7 +467,20 @@ private static Maybe processFunctionResult( tool, finalFunctionResult, toolContext, invocationContext); return Maybe.just(event); }); - }); + }) + .compose( + Tracing.trace("execute_tool [" + tool.name() + "]") + .setParent(parentContext) + .onSuccess( + (span, event) -> + Tracing.traceToolExecution( + span, + tool.name(), + tool.description(), + tool.getClass().getSimpleName(), + functionArgs, + event, + null))); } private static Optional mergeParallelFunctionResponseEvents( diff --git a/core/src/main/java/com/google/adk/plugins/agentanalytics/BigQueryAgentAnalyticsPlugin.java b/core/src/main/java/com/google/adk/plugins/agentanalytics/BigQueryAgentAnalyticsPlugin.java index 566dbd5a4..59e09c8a7 100644 --- a/core/src/main/java/com/google/adk/plugins/agentanalytics/BigQueryAgentAnalyticsPlugin.java +++ b/core/src/main/java/com/google/adk/plugins/agentanalytics/BigQueryAgentAnalyticsPlugin.java @@ -253,10 +253,7 @@ private Completable logEvent( parseFuture = state .getParser() - .parse( - content, - traceIds.traceId(), - traceIds.spanId() != null ? traceIds.spanId() : "no_span") + .parse(content) .thenAccept( parsedContent -> { row.put( diff --git a/core/src/main/java/com/google/adk/plugins/agentanalytics/GcsOffloader.java b/core/src/main/java/com/google/adk/plugins/agentanalytics/GcsOffloader.java deleted file mode 100644 index 17993bb8e..000000000 --- a/core/src/main/java/com/google/adk/plugins/agentanalytics/GcsOffloader.java +++ /dev/null @@ -1,94 +0,0 @@ -/* - * Copyright 2026 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.google.adk.plugins.agentanalytics; - -import static java.nio.charset.StandardCharsets.UTF_8; - -import com.google.auth.Credentials; -import com.google.cloud.storage.BlobId; -import com.google.cloud.storage.BlobInfo; -import com.google.cloud.storage.Storage; -import com.google.cloud.storage.StorageOptions; -import java.util.concurrent.CompletableFuture; -import java.util.concurrent.Executor; -import java.util.concurrent.RejectedExecutionException; -import org.jspecify.annotations.Nullable; - -/** Offloads content to GCS. */ -class GcsOffloader { - private final Storage storage; - private final String bucketName; - private final Executor executor; - private final boolean isStorageOverride; - - GcsOffloader( - String projectId, - String bucketName, - Executor executor, - @Nullable Credentials credentials, - @Nullable Storage storageOverride) { - if (storageOverride != null) { - this.isStorageOverride = true; - this.storage = storageOverride; - } else { - this.isStorageOverride = false; - StorageOptions.Builder builder = StorageOptions.newBuilder().setProjectId(projectId); - if (credentials != null) { - builder.setCredentials(credentials); - } - this.storage = builder.build().getService(); - } - this.bucketName = bucketName; - this.executor = executor; - } - - /** Async wrapper around blocking GCS upload for binary data. */ - CompletableFuture uploadContent(byte[] data, String contentType, String path) { - try { - return CompletableFuture.supplyAsync( - () -> { - BlobId blobId = BlobId.of(bucketName, path); - BlobInfo blobInfo = BlobInfo.newBuilder(blobId).setContentType(contentType).build(); - storage.create(blobInfo, data); - return String.format("gs://%s/%s", bucketName, path); - }, - executor); - } catch (RejectedExecutionException e) { - return CompletableFuture.failedFuture(e); - } - } - - /** Async wrapper around blocking GCS upload for text data. */ - CompletableFuture uploadContent(String data, String contentType, String path) { - try { - return CompletableFuture.supplyAsync(() -> data.getBytes(UTF_8), executor) - .thenCompose(bytes -> uploadContent(bytes, contentType, path)); - } catch (RejectedExecutionException e) { - return CompletableFuture.failedFuture(e); - } - } - - String getBucketName() { - return bucketName; - } - - void close() throws Exception { - if (storage != null && !isStorageOverride) { - storage.close(); - } - } -} diff --git a/core/src/main/java/com/google/adk/plugins/agentanalytics/JsonFormatter.java b/core/src/main/java/com/google/adk/plugins/agentanalytics/JsonFormatter.java index 58049a9ef..34430b861 100644 --- a/core/src/main/java/com/google/adk/plugins/agentanalytics/JsonFormatter.java +++ b/core/src/main/java/com/google/adk/plugins/agentanalytics/JsonFormatter.java @@ -16,26 +16,20 @@ package com.google.adk.plugins.agentanalytics; -import static java.util.Collections.newSetFromMap; - import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.node.ArrayNode; import com.fasterxml.jackson.databind.node.ObjectNode; import com.google.auto.value.AutoValue; import com.google.common.base.Utf8; -import java.util.IdentityHashMap; import java.util.Map; import java.util.Set; -import java.util.logging.Logger; import org.jspecify.annotations.Nullable; /** Utility for parsing, formatting and truncating content for BigQuery logging. */ final class JsonFormatter { - private static final Logger logger = Logger.getLogger(JsonFormatter.class.getName()); static final ObjectMapper mapper = new ObjectMapper().findAndRegisterModules(); static final String TRUNCATION_SUFFIX = "...[truncated]"; - static final String CYCLE_DETECTED_MESSAGE = "[cycle detected]"; @AutoValue abstract static class TruncationResult { @@ -54,14 +48,10 @@ static TruncationResult smartTruncate(Object obj, int maxLength) { return TruncationResult.create(mapper.nullNode(), false); } try { - if (obj instanceof JsonNode jsonNode) { - return recursiveSmartTruncate(jsonNode, maxLength, newSetFromMap(new IdentityHashMap<>())); - } - return recursiveSmartTruncate( - mapper.valueToTree(obj), maxLength, newSetFromMap(new IdentityHashMap<>())); + return recursiveSmartTruncate(mapper.valueToTree(obj), maxLength); } catch (IllegalArgumentException e) { // Fallback for types that mapper can't handle directly as a tree - return truncateWithStatus(safeToString(obj), maxLength); + return truncateWithStatus(String.valueOf(obj), maxLength); } } @@ -72,60 +62,38 @@ static JsonNode convertToJsonNode(Object obj) { try { return mapper.valueToTree(obj); } catch (IllegalArgumentException e) { - // Fallback for types that mapper can't handle directly as a tree. - return mapper.valueToTree(safeToString(obj)); - } - } - - static String safeToString(Object obj) { - try { - return String.valueOf(obj); - } catch (RuntimeException e) { - logger.warning("RuntimeException when converting object to string"); - return "[ERROR CONVERTING TO STRING]"; + // Fallback for types that mapper can't handle directly as a tree + return mapper.valueToTree(String.valueOf(obj)); } } - private static TruncationResult recursiveSmartTruncate( - JsonNode node, int maxLength, Set visited) { - if (node.isContainerNode()) { - if (visited.contains(node)) { - return TruncationResult.create(mapper.valueToTree(CYCLE_DETECTED_MESSAGE), true); - } - visited.add(node); - } - try { - boolean isTruncated = false; - if (node.isTextual()) { - String text = node.asText(); - if (Utf8.encodedLength(text) > maxLength) { - return TruncationResult.create(mapper.valueToTree(truncate(text, maxLength)), true); - } - return TruncationResult.create(node, false); - } else if (node.isObject()) { - ObjectNode newNode = mapper.createObjectNode(); - Set> properties = node.properties(); - for (Map.Entry entry : properties) { - TruncationResult res = recursiveSmartTruncate(entry.getValue(), maxLength, visited); - newNode.set(entry.getKey(), res.node()); - isTruncated = isTruncated || res.isTruncated(); - } - return TruncationResult.create(newNode, isTruncated); - } else if (node.isArray()) { - ArrayNode newNode = mapper.createArrayNode(); - for (JsonNode element : node) { - TruncationResult res = recursiveSmartTruncate(element, maxLength, visited); - newNode.add(res.node()); - isTruncated = isTruncated || res.isTruncated(); - } - return TruncationResult.create(newNode, isTruncated); + private static TruncationResult recursiveSmartTruncate(JsonNode node, int maxLength) { + boolean isTruncated = false; + if (node.isTextual()) { + String text = node.asText(); + if (Utf8.encodedLength(text) > maxLength) { + return TruncationResult.create(mapper.valueToTree(truncate(text, maxLength)), true); } return TruncationResult.create(node, false); - } finally { - if (node.isContainerNode()) { - visited.remove(node); + } else if (node.isObject()) { + ObjectNode newNode = mapper.createObjectNode(); + Set> properties = node.properties(); + for (Map.Entry entry : properties) { + TruncationResult res = recursiveSmartTruncate(entry.getValue(), maxLength); + newNode.set(entry.getKey(), res.node()); + isTruncated = isTruncated || res.isTruncated(); + } + return TruncationResult.create(newNode, isTruncated); + } else if (node.isArray()) { + ArrayNode newNode = mapper.createArrayNode(); + for (JsonNode element : node) { + TruncationResult res = recursiveSmartTruncate(element, maxLength); + newNode.add(res.node()); + isTruncated = isTruncated || res.isTruncated(); } + return TruncationResult.create(newNode, isTruncated); } + return TruncationResult.create(node, false); } static TruncationResult truncateWithStatus(String s, int maxLength) { diff --git a/core/src/main/java/com/google/adk/plugins/agentanalytics/MimeTypeMapper.java b/core/src/main/java/com/google/adk/plugins/agentanalytics/MimeTypeMapper.java deleted file mode 100644 index 8505e2d1a..000000000 --- a/core/src/main/java/com/google/adk/plugins/agentanalytics/MimeTypeMapper.java +++ /dev/null @@ -1,60 +0,0 @@ -/* - * Copyright 2026 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.google.adk.plugins.agentanalytics; - -import com.google.common.collect.ImmutableMap; - -/** Utility to map MIME types to file extensions. */ -final class MimeTypeMapper { - private static final ImmutableMap MIME_TO_EXT = - ImmutableMap.builder() - // Images - .put("image/jpeg", ".jpg") - .put("image/png", ".png") - .put("image/gif", ".gif") - .put("image/webp", ".webp") - .put("image/bmp", ".bmp") - .put("image/tiff", ".tiff") - // Audio - .put("audio/mpeg", ".mp3") - .put("audio/ogg", ".ogg") - .put("audio/wav", ".wav") - .put("audio/x-wav", ".wav") - .put("audio/webm", ".webm") - .put("audio/aac", ".aac") - .put("audio/midi", ".mid") - .put("audio/x-m4a", ".m4a") - // Video - .put("video/mp4", ".mp4") - .put("video/mpeg", ".mpeg") - .put("video/ogg", ".ogv") - .put("video/webm", ".webm") - .put("video/avi", ".avi") - .put("video/x-msvideo", ".avi") - .put("video/quicktime", ".mov") - .buildOrThrow(); - - private MimeTypeMapper() {} - - /** - * Returns the file extension (including the dot) for the given MIME type. Returns an empty string - * if the MIME type is unknown. - */ - static String getExtension(String mimeType) { - return MIME_TO_EXT.getOrDefault(mimeType, ""); - } -} diff --git a/core/src/main/java/com/google/adk/plugins/agentanalytics/Parser.java b/core/src/main/java/com/google/adk/plugins/agentanalytics/Parser.java index c489273e1..5db8be46c 100644 --- a/core/src/main/java/com/google/adk/plugins/agentanalytics/Parser.java +++ b/core/src/main/java/com/google/adk/plugins/agentanalytics/Parser.java @@ -19,7 +19,6 @@ import static com.google.adk.plugins.agentanalytics.JsonFormatter.mapper; import static com.google.adk.plugins.agentanalytics.JsonFormatter.smartTruncate; import static com.google.adk.plugins.agentanalytics.JsonFormatter.truncate; -import static com.google.adk.plugins.agentanalytics.JsonFormatter.truncateAndAddSuffix; import static com.google.adk.plugins.agentanalytics.JsonFormatter.truncateWithStatus; import com.fasterxml.jackson.annotation.JsonProperty; @@ -40,40 +39,16 @@ import java.util.ArrayList; import java.util.List; import java.util.Optional; -import java.util.UUID; import java.util.concurrent.CompletableFuture; -import java.util.logging.Level; -import java.util.logging.Logger; import org.jspecify.annotations.Nullable; -import org.threeten.bp.Instant; -import org.threeten.bp.LocalDate; -import org.threeten.bp.ZoneOffset; /** Utility for parsing content for BigQuery logging. */ final class Parser { - private static final String DEFAULT_EXTENSION = ".bin"; - private static final int MAX_OFFLOADED_TEXT_LENGTH = 200; - private static final Logger logger = Logger.getLogger(Parser.class.getName()); - private static final int INLINE_TEXT_LIMIT = 32 * 1024; // 32KB limit - private static final String UPLOAD_FAILED_MESSAGE = "[UPLOAD FAILED]"; - private static final String MEDIA_OFFLOADED_MESSAGE = "[MEDIA OFFLOADED]"; private static final String BINARY_DATA_MESSAGE = "[BINARY DATA]"; - private static final String TEXT_OFFLOADED_SUFFIX = "... [OFFLOADED]"; - - private final @Nullable GcsOffloader offloader; private final int maxLength; - private final @Nullable String connectionId; - private final boolean logMultiModalContent; - - Parser( - @Nullable GcsOffloader offloader, - int maxLength, - @Nullable String connectionId, - boolean logMultiModalContent) { - this.offloader = offloader; + + Parser(int maxLength) { this.maxLength = maxLength; - this.connectionId = connectionId; - this.logMultiModalContent = logMultiModalContent; } @AutoValue @@ -177,11 +152,9 @@ static ObjectRef create( * Parses content into JSON payload and content parts, matching Python implementation. * * @param content the content to parse - * @param traceId the trace ID for GCS path - * @param spanId the span ID for GCS path * @return a CompletableFuture of ParsedContent object */ - CompletableFuture parse(Object content, String traceId, String spanId) { + CompletableFuture parse(Object content) { if (content instanceof LlmRequest llmRequest) { ObjectNode jsonPayload = mapper.createObjectNode(); ArrayNode messages = mapper.createArrayNode(); @@ -189,15 +162,13 @@ CompletableFuture parse(Object content, String traceId, String sp List contents = llmRequest.contents(); for (Content c : contents) { - futures.add(parseContentObject(c, traceId, spanId)); + futures.add(parseContentObject(c)); } CompletableFuture systemFuture = null; if (llmRequest.config().isPresent() && llmRequest.config().get().systemInstruction().isPresent()) { - systemFuture = - parseContentObject( - llmRequest.config().get().systemInstruction().get(), traceId, spanId); + systemFuture = parseContentObject(llmRequest.config().get().systemInstruction().get()); futures.add(systemFuture); } CompletableFuture finalSystemFuture = systemFuture; @@ -231,7 +202,7 @@ CompletableFuture parse(Object content, String traceId, String sp } if (content instanceof LlmResponse llmResponse) { ObjectNode jsonPayload = mapper.createObjectNode(); - return parseContentObject(llmResponse.content().orElse(null), traceId, spanId) + return parseContentObject(llmResponse.content().orElse(null)) .thenApply( parsed -> { ObjectNode summaryNode = mapper.createObjectNode(); @@ -254,7 +225,7 @@ CompletableFuture parse(Object content, String traceId, String sp }); } if (content instanceof Content || content instanceof Part) { - return parseContentObject(content, traceId, spanId) + return parseContentObject(content) .thenApply( parsed -> { ObjectNode summaryNode = mapper.createObjectNode(); @@ -278,13 +249,10 @@ CompletableFuture parse(Object content, String traceId, String sp * Parses a Content or Part object into summary text and content parts. * * @param content the Content or Part object to parse - * @param traceId the trace ID for GCS path - * @param spanId the span ID for GCS path * @return a CompletableFuture of ParsedContentObject containing parts, summary, and truncation * flag */ - private CompletableFuture parseContentObject( - Object content, String traceId, String spanId) { + private CompletableFuture parseContentObject(Object content) { List parts; if (content instanceof Content c) { parts = c.parts().orElse(ImmutableList.of()); @@ -297,7 +265,7 @@ private CompletableFuture parseContentObject( List> partFutures = new ArrayList<>(); for (int i = 0; i < parts.size(); i++) { - partFutures.add(processPart(parts.get(i), i, traceId, spanId)); + partFutures.add(processPart(parts.get(i), i)); } return CompletableFuture.allOf(partFutures.toArray(new CompletableFuture[0])) @@ -327,8 +295,7 @@ private CompletableFuture parseContentObject( }); } - private CompletableFuture processPart( - Part part, int index, String traceId, String spanId) { + private CompletableFuture processPart(Part part, int index) { ContentPart.Builder partBuilder = ContentPart.builder() .setPartIndex(index) @@ -353,88 +320,17 @@ private CompletableFuture processPart( if (part.inlineData().isPresent()) { Blob blob = part.inlineData().get(); String mimeType = blob.mimeType().orElse("application/octet-stream"); - if (logMultiModalContent && offloader != null) { - String ext = MimeTypeMapper.getExtension(mimeType); - if (ext.isEmpty()) { - ext = DEFAULT_EXTENSION; - } - - String path = - String.format( - "%s/%s/%s_p%d_%s%s", - getLocalDate(), traceId, spanId, index, UUID.randomUUID(), ext); - return offloader - .uploadContent(blob.data().orElse(new byte[0]), mimeType, path) - .handle( - (uri, ex) -> { - if (ex != null) { - logger.log(Level.WARNING, "Failed to offload content to GCS", ex); - partBuilder.setText(UPLOAD_FAILED_MESSAGE); - } else { - ObjectNode details = mapper.createObjectNode(); - ObjectNode gcsMetadata = details.putObject("gcs_metadata"); - gcsMetadata.put("content_type", mimeType); - - partBuilder - .setStorageMode("GCS_REFERENCE") - .setUri(uri) - .setMimeType(mimeType) - .setText(MEDIA_OFFLOADED_MESSAGE) - .setObjectRef( - mapper.valueToTree(ObjectRef.create(uri, null, connectionId, details))); - } - return TruncationResult.create(mapper.valueToTree(partBuilder.build()), false); - }); - } else { - partBuilder.setText(BINARY_DATA_MESSAGE).setMimeType(mimeType); - return CompletableFuture.completedFuture( - TruncationResult.create(mapper.valueToTree(partBuilder.build()), false)); - } + partBuilder.setText(BINARY_DATA_MESSAGE).setMimeType(mimeType); + return CompletableFuture.completedFuture( + TruncationResult.create(mapper.valueToTree(partBuilder.build()), false)); } // CASE C: Text if (part.text().isPresent()) { String text = part.text().get(); - int textLen = Utf8.encodedLength(text); - int offloadThreshold = Math.min(INLINE_TEXT_LIMIT, maxLength); - - if (offloader != null && textLen > offloadThreshold) { - - String path = - String.format( - "%s/%s/%s_p%d_%s.txt", getLocalDate(), traceId, spanId, index, UUID.randomUUID()); - return offloader - .uploadContent(text, "text/plain", path) - .handle( - (uri, ex) -> { - if (ex != null) { - logger.log(Level.WARNING, "Failed to offload text to GCS", ex); - TruncationResult res = truncateWithStatus(text, maxLength); - partBuilder.setText(res.node().asText()); - return TruncationResult.create( - mapper.valueToTree(partBuilder.build()), res.isTruncated()); - } else { - ObjectNode details = mapper.createObjectNode(); - ObjectNode gcsMetadata = details.putObject("gcs_metadata"); - gcsMetadata.put("content_type", "text/plain"); - - partBuilder - .setStorageMode("GCS_REFERENCE") - .setUri(uri) - .setMimeType("text/plain") - .setText( - truncateAndAddSuffix( - text, MAX_OFFLOADED_TEXT_LENGTH, TEXT_OFFLOADED_SUFFIX)) - .setObjectRef( - mapper.valueToTree(ObjectRef.create(uri, null, connectionId, details))); - return TruncationResult.create(mapper.valueToTree(partBuilder.build()), true); - } - }); - } else { - TruncationResult res = truncateWithStatus(text, maxLength); - partBuilder.setText(res.node().asText()); - return CompletableFuture.completedFuture( - TruncationResult.create(mapper.valueToTree(partBuilder.build()), res.isTruncated())); - } + TruncationResult res = truncateWithStatus(text, maxLength); + partBuilder.setText(res.node().asText()); + return CompletableFuture.completedFuture( + TruncationResult.create(mapper.valueToTree(partBuilder.build()), res.isTruncated())); } if (part.functionCall().isPresent()) { FunctionCall fc = part.functionCall().get(); @@ -483,8 +379,4 @@ ArrayNode formatContentParts(Optional content) { } return partsArray; } - - private LocalDate getLocalDate() { - return Instant.now().atZone(ZoneOffset.UTC).toLocalDate(); - } } diff --git a/core/src/main/java/com/google/adk/plugins/agentanalytics/PluginState.java b/core/src/main/java/com/google/adk/plugins/agentanalytics/PluginState.java index d1826ec5e..0654fab5d 100644 --- a/core/src/main/java/com/google/adk/plugins/agentanalytics/PluginState.java +++ b/core/src/main/java/com/google/adk/plugins/agentanalytics/PluginState.java @@ -3,7 +3,6 @@ import static com.google.adk.plugins.agentanalytics.BigQueryUtils.getVersionHeaderValue; import static com.google.common.collect.ImmutableList.toImmutableList; import static java.util.concurrent.TimeUnit.MILLISECONDS; -import static java.util.concurrent.TimeUnit.SECONDS; import com.google.api.gax.core.FixedCredentialsProvider; import com.google.api.gax.retrying.RetrySettings; @@ -22,35 +21,21 @@ import java.util.Collection; import java.util.Set; import java.util.UUID; -import java.util.concurrent.ArrayBlockingQueue; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.ScheduledExecutorService; -import java.util.concurrent.ThreadPoolExecutor; import java.util.concurrent.TimeoutException; import java.util.concurrent.atomic.AtomicLong; import java.util.logging.Level; import java.util.logging.Logger; -import org.jspecify.annotations.Nullable; import org.threeten.bp.Duration; -import org.threeten.bp.Instant; /** Manages state for the BigQueryAgentAnalyticsPlugin. */ class PluginState { private static final Logger logger = Logger.getLogger(PluginState.class.getName()); - private static final int GCS_OFFLOAD_CORE_POOL_SIZE = 2; - private static final int GCS_OFFLOAD_MAX_THREADS = 10; - // Max number of tasks in the queue before we start rejecting tasks and executing them in the - // caller thread. - private static final int GCS_OFFLOAD_QUEUE_SIZE = 100; - // Idle time before threads are terminated. - private static final int GCS_OFFLOAD_IDLE_TIME_SECONDS = 30; - private final BigQueryLoggerConfig config; private final ScheduledExecutorService executor; - private final ExecutorService offloadExecutor; private final BigQueryWriteClient writeClient; private static final AtomicLong threadCounter = new AtomicLong(0); // Map of invocation ID to BatchProcessor. @@ -60,7 +45,6 @@ class PluginState { private final ConcurrentHashMap traceManagers = new ConcurrentHashMap<>(); // Cache of invocation ID to Boolean indicating invocation ID has been processed. private final Cache processedInvocations; - private final GcsOffloader offloader; private final Parser parser; private final ConcurrentHashMap>> pendingTasks = new ConcurrentHashMap<>(); @@ -70,7 +54,6 @@ class PluginState { this.executor = Executors.newScheduledThreadPool( 2, r -> new Thread(r, "bq-analytics-plugin-" + threadCounter.getAndIncrement())); - this.offloadExecutor = createGcsOffloadThreadPool(); // One write client per plugin instance, shared by all invocations. this.writeClient = createWriteClient(config); this.processedInvocations = @@ -78,25 +61,7 @@ class PluginState { .maximumSize(10000) .expireAfterWrite(java.time.Duration.ofMinutes(10)) .build(); - this.offloader = getGcsOffloader(config); - this.parser = - new Parser( - offloader, - config.maxContentLength(), - config.connectionId().orElse(null), - config.logMultiModalContent()); - } - - private static ExecutorService createGcsOffloadThreadPool() { - return new ThreadPoolExecutor( - GCS_OFFLOAD_CORE_POOL_SIZE, // The lower limit of threads. - GCS_OFFLOAD_MAX_THREADS, // The upper limit of threads. - GCS_OFFLOAD_IDLE_TIME_SECONDS, // Time to keep idle threads alive. - SECONDS, - new ArrayBlockingQueue<>(GCS_OFFLOAD_QUEUE_SIZE), // workQueue: Hand off tasks directly. - r -> new Thread(r, "bq-analytics-plugin-offload-" + threadCounter.getAndIncrement()), - // Reject tasks if the queue is full. - new ThreadPoolExecutor.AbortPolicy()); + this.parser = new Parser(config.maxContentLength()); } ScheduledExecutorService getExecutor() { @@ -177,14 +142,6 @@ BatchProcessor getBatchProcessor(String invocationId) { }); } - protected @Nullable GcsOffloader getGcsOffloader(BigQueryLoggerConfig config) { - if (config.gcsBucketName().isEmpty()) { - return null; - } - return new GcsOffloader( - config.projectId(), config.gcsBucketName(), offloadExecutor, config.credentials(), null); - } - Parser getParser() { return parser; } @@ -217,8 +174,7 @@ void clearBatchProcessors() { batchProcessors.clear(); } - @VisibleForTesting - protected Set> getPendingTasksForInvocation(String invocationId) { + private Set> getPendingTasksForInvocation(String invocationId) { return pendingTasks.computeIfAbsent(invocationId, k -> ConcurrentHashMap.newKeySet()); } @@ -307,34 +263,13 @@ Completable close() { } try { executor.shutdown(); - offloadExecutor.shutdown(); - long totalTimeoutMillis = config.shutdownTimeout().toMillis(); - Instant startTime = Instant.now(); - if (!executor.awaitTermination(totalTimeoutMillis, MILLISECONDS)) { + if (!executor.awaitTermination(config.shutdownTimeout().toMillis(), MILLISECONDS)) { executor.shutdownNow(); } - long elapsedTimeMillis = Duration.between(startTime, Instant.now()).toMillis(); - long remainingMillis = totalTimeoutMillis - elapsedTimeMillis; - if (remainingMillis > 0) { - if (!offloadExecutor.awaitTermination(remainingMillis, MILLISECONDS)) { - offloadExecutor.shutdownNow(); - } - } else { - offloadExecutor.shutdownNow(); - } } catch (InterruptedException e) { executor.shutdownNow(); - offloadExecutor.shutdownNow(); Thread.currentThread().interrupt(); } - - try { - if (offloader != null) { - offloader.close(); - } - } catch (Exception e) { - logger.log(Level.WARNING, "Failed to close GCS offloader", e); - } }); } } diff --git a/core/src/main/java/com/google/adk/runner/Runner.java b/core/src/main/java/com/google/adk/runner/Runner.java index 1ab101398..44a281f72 100644 --- a/core/src/main/java/com/google/adk/runner/Runner.java +++ b/core/src/main/java/com/google/adk/runner/Runner.java @@ -16,8 +16,6 @@ package com.google.adk.runner; -import static com.google.common.collect.ImmutableSet.toImmutableSet; - import com.google.adk.agents.ActiveStreamingTool; import com.google.adk.agents.BaseAgent; import com.google.adk.agents.ContextCacheConfig; @@ -47,8 +45,6 @@ import com.google.adk.utils.CollectionUtils; import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableSet; -import com.google.common.collect.Iterables; import com.google.common.collect.MapMaker; import com.google.errorprone.annotations.CanIgnoreReturnValue; import com.google.genai.types.AudioTranscriptionConfig; @@ -68,7 +64,6 @@ import java.util.Collections; import java.util.List; import java.util.Map; -import java.util.Objects; import java.util.Optional; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; @@ -490,12 +485,6 @@ protected Flowable runAsyncImpl( BaseAgent rootAgent = this.agent; String invocationId = InvocationContext.newInvocationContextId(); - // Pre-merge stateDelta so onUserMessageCallback can access it. - // Safe: session is a copy; persistence still happens via appendNewMessageToSession. - if (stateDelta != null && !stateDelta.isEmpty()) { - stateDelta.forEach((key, value) -> session.state().put(key, value)); - } - // Create initial context InvocationContext initialContext = newInvocationContextBuilder(session) @@ -777,15 +766,12 @@ private boolean isTransferableAcrossAgentTree(BaseAgent agentToRun) { return true; } - /** Returns the agent that should handle the next request based on session history. */ + /** + * Returns the agent that should handle the next request based on session history. + * + * @return agent to run. + */ private BaseAgent findAgentToRun(Session session, BaseAgent rootAgent) { - // Route function responses back to the originating function-call author so HITL tool - // confirmations resume the sub-agent even through non-LlmAgent ancestors. - Optional functionCallAuthor = findFunctionCallAuthor(session, rootAgent); - if (functionCallAuthor.isPresent()) { - return functionCallAuthor.get(); - } - List events = new ArrayList<>(session.events()); Collections.reverse(events); @@ -816,39 +802,6 @@ private BaseAgent findAgentToRun(Session session, BaseAgent rootAgent) { return rootAgent; } - /** - * If the last event is a function response, returns the agent that emitted the matching function - * call (by id), or empty if no match is found in the agent tree. - */ - private static Optional findFunctionCallAuthor(Session session, BaseAgent rootAgent) { - List events = session.events(); - if (events.isEmpty()) { - return Optional.empty(); - } - ImmutableSet functionResponseIds = - Iterables.getLast(events).functionResponses().stream() - .map(fr -> fr.id().orElse(null)) - .filter(Objects::nonNull) - .collect(toImmutableSet()); - - // Iterate in reverse to prefer the most recent matching call, mirroring Python ADK's - // find_event_by_function_call_id. Function call IDs are unique in normal flows, so this - // is defense-in-depth and not covered by mutation testing. - List precedingEvents = new ArrayList<>(events.subList(0, events.size() - 1)); - Collections.reverse(precedingEvents); - for (Event event : precedingEvents) { - boolean matches = - event.functionCalls().stream() - .map(fc -> fc.id().orElse(null)) - .filter(Objects::nonNull) - .anyMatch(functionResponseIds::contains); - if (matches && event.author() != null) { - return rootAgent.findAgent(event.author()); - } - } - return Optional.empty(); - } - private void addActiveStreamingTools(InvocationContext invocationContext, List tools) { tools.stream() .filter(FunctionTool.class::isInstance) diff --git a/core/src/main/java/com/google/adk/skills/AbstractSkillSource.java b/core/src/main/java/com/google/adk/skills/AbstractSkillSource.java index 31f17a18e..aca399f92 100644 --- a/core/src/main/java/com/google/adk/skills/AbstractSkillSource.java +++ b/core/src/main/java/com/google/adk/skills/AbstractSkillSource.java @@ -16,8 +16,6 @@ package com.google.adk.skills; -import static com.google.adk.skills.SkillSourceException.SKILL_FORMAT_ERROR; -import static com.google.adk.skills.SkillSourceException.SKILL_LOAD_ERROR; import static java.nio.channels.Channels.newReader; import static java.nio.charset.StandardCharsets.UTF_8; @@ -84,14 +82,12 @@ private Frontmatter loadFrontmatter(String skillName, PathT skillMdPath) Frontmatter frontmatter = yamlMapper.readValue(yaml, Frontmatter.class); if (!frontmatter.name().equals(skillName)) { throw new SkillSourceException( - "Skill name in the frontmatter '%s' does not match skill name '%s'." - .formatted(frontmatter.name(), skillName), - SKILL_LOAD_ERROR); + "Skill name '%s' does not match directory name '%s'." + .formatted(frontmatter.name(), skillName)); } return frontmatter; } catch (IOException e) { - throw new SkillSourceException( - "Cannot load frontmatter for skill '" + skillName + "'", SKILL_LOAD_ERROR, e); + throw new SkillSourceException("Cannot load frontmatter for skill '" + skillName + "'", e); } } @@ -104,9 +100,7 @@ public Single loadInstructions(String skillName) { return readInstructions(reader); } catch (IOException e) { throw new SkillSourceException( - "Failed to load instruction for skill '" + skillName + "'", - SKILL_LOAD_ERROR, - e); + "Failed to load instruction for skill '" + skillName + "'", e); } }); } @@ -146,8 +140,7 @@ private String readFrontmatterYaml(BufferedReader reader) throws IOException, SkillSourceException { String line = reader.readLine(); if (line == null || !line.trim().equals(THREE_DASHES)) { - throw new SkillSourceException( - "Skill file must start with " + THREE_DASHES, SKILL_FORMAT_ERROR); + throw new SkillSourceException("Skill file must start with " + THREE_DASHES); } StringBuilder sb = new StringBuilder(); @@ -158,15 +151,14 @@ private String readFrontmatterYaml(BufferedReader reader) sb.append(line).append("\n"); } throw new SkillSourceException( - "Skill file frontmatter not properly closed with " + THREE_DASHES, SKILL_FORMAT_ERROR); + "Skill file frontmatter not properly closed with " + THREE_DASHES); } private String readInstructions(BufferedReader reader) throws IOException, SkillSourceException { // Skip the frontmatter block String line = reader.readLine(); if (line == null || !line.trim().equals(THREE_DASHES)) { - throw new SkillSourceException( - "Skill file must start with " + THREE_DASHES, SKILL_FORMAT_ERROR); + throw new SkillSourceException("Skill file must start with " + THREE_DASHES); } boolean dashClosed = false; while ((line = reader.readLine()) != null) { @@ -177,7 +169,7 @@ private String readInstructions(BufferedReader reader) throws IOException, Skill } if (!dashClosed) { throw new SkillSourceException( - "Skill file frontmatter not properly closed with " + THREE_DASHES, SKILL_FORMAT_ERROR); + "Skill file frontmatter not properly closed with " + THREE_DASHES); } // Read the instructions till the end of the file StringBuilder sb = new StringBuilder(); diff --git a/core/src/main/java/com/google/adk/skills/InMemorySkillSource.java b/core/src/main/java/com/google/adk/skills/InMemorySkillSource.java index d299dfb21..42916e36a 100644 --- a/core/src/main/java/com/google/adk/skills/InMemorySkillSource.java +++ b/core/src/main/java/com/google/adk/skills/InMemorySkillSource.java @@ -16,8 +16,6 @@ package com.google.adk.skills; -import static com.google.adk.skills.SkillSourceException.RESOURCE_NOT_FOUND; -import static com.google.adk.skills.SkillSourceException.SKILL_NOT_FOUND; import static com.google.common.base.Preconditions.checkState; import static com.google.common.collect.ImmutableList.toImmutableList; import static java.nio.charset.StandardCharsets.UTF_8; @@ -58,8 +56,7 @@ public Single> listFrontmatters() { public Single> listResources(String skillName, String resourceDirectory) { SkillData data = skills.get(skillName); if (data == null) { - return Single.error( - new SkillSourceException("Skill not found: " + skillName, SKILL_NOT_FOUND)); + return Single.error(new SkillSourceException("Skill not found: " + skillName)); } String prefix = resourceDirectory.isEmpty() @@ -70,8 +67,7 @@ public Single> listResources(String skillName, String reso && data.resources().keySet().stream().noneMatch(path -> path.startsWith(prefix))) { return Single.error( new SkillSourceException( - "Resource directory not found: " + resourceDirectory + " for skill: " + skillName, - RESOURCE_NOT_FOUND)); + "Resource directory not found: " + resourceDirectory + " for skill: " + skillName)); } return Single.just( @@ -96,16 +92,13 @@ public Single loadResource(String skillName, String resourcePath) { .map(SkillData::resources) .mapOptional(m -> Optional.ofNullable(m.get(resourcePath))) .switchIfEmpty( - Single.error( - new SkillSourceException( - "Resource not found: " + resourcePath, RESOURCE_NOT_FOUND))); + Single.error(new SkillSourceException("Resource not found: " + resourcePath))); } private Single getSkillData(String skillName) { SkillData data = skills.get(skillName); if (data == null) { - return Single.error( - new SkillSourceException("Skill not found: " + skillName, SKILL_NOT_FOUND)); + return Single.error(new SkillSourceException("Skill not found: " + skillName)); } return Single.just(data); } diff --git a/core/src/main/java/com/google/adk/skills/LocalSkillSource.java b/core/src/main/java/com/google/adk/skills/LocalSkillSource.java index b4b7d4876..8c891e90f 100644 --- a/core/src/main/java/com/google/adk/skills/LocalSkillSource.java +++ b/core/src/main/java/com/google/adk/skills/LocalSkillSource.java @@ -16,10 +16,6 @@ package com.google.adk.skills; -import static com.google.adk.skills.SkillSourceException.RESOURCE_LOAD_ERROR; -import static com.google.adk.skills.SkillSourceException.RESOURCE_NOT_FOUND; -import static com.google.adk.skills.SkillSourceException.SKILL_LOAD_ERROR; -import static com.google.adk.skills.SkillSourceException.SKILL_NOT_FOUND; import static com.google.common.collect.ImmutableList.toImmutableList; import static java.nio.file.Files.isDirectory; @@ -45,18 +41,22 @@ public LocalSkillSource(Path skillsBasePath) { @Override public Single> listResources(String skillName, String resourceDirectory) { + try { + validatePathWithinBase(skillsBasePath, skillName); + validatePathWithinBase(skillsBasePath.resolve(skillName), resourceDirectory); + } catch (SkillSourceException e) { + return Single.error(e); + } Path skillDir = skillsBasePath.resolve(skillName); if (!isDirectory(skillDir)) { - return Single.error( - new SkillSourceException("Skill not found: " + skillName, SKILL_NOT_FOUND)); + return Single.error(new SkillSourceException("Skill not found: " + skillName)); } Path resourceDir = skillDir.resolve(resourceDirectory); if (!isDirectory(resourceDir)) { return Single.error( new SkillSourceException( "Resource directory '%s' not found for skill '%s'" - .formatted(resourceDirectory, skillName), - RESOURCE_NOT_FOUND)); + .formatted(resourceDirectory, skillName))); } return Single.fromCallable( @@ -73,9 +73,7 @@ public Single> listResources(String skillName, String reso t -> Single.error( new SkillSourceException( - "Failed to traverse resource directory: " + resourceDirectory, - RESOURCE_LOAD_ERROR, - t))); + "Failed to traverse resource directory: " + resourceDirectory, t))); } @Override @@ -86,9 +84,7 @@ protected Flowable listSkills() { t -> Flowable.error( new SkillSourceException( - "Failed to list skills in directory: " + skillsBasePath, - SKILL_LOAD_ERROR, - t))) + "Failed to list skills in directory: " + skillsBasePath, t))) .filter(Files::isDirectory) .mapOptional(this::findSkillMd) .map(skillMd -> new SkillMdPath(skillMd.getParent().getFileName().toString(), skillMd)); @@ -96,25 +92,33 @@ protected Flowable listSkills() { @Override protected Single findResourcePath(String skillName, String resourcePath) { + try { + validatePathWithinBase(skillsBasePath, skillName); + validatePathWithinBase(skillsBasePath.resolve(skillName), resourcePath); + } catch (SkillSourceException e) { + return Single.error(e); + } Path file = skillsBasePath.resolve(skillName).resolve(resourcePath); if (!Files.exists(file)) { - return Single.error( - new SkillSourceException("Resource not found: " + file, RESOURCE_NOT_FOUND)); + return Single.error(new SkillSourceException("Resource not found: " + file)); } return Single.just(file); } @Override protected Single findSkillMdPath(String skillName) { + try { + validatePathWithinBase(skillsBasePath, skillName); + } catch (SkillSourceException e) { + return Single.error(e); + } Path skillDir = skillsBasePath.resolve(skillName); if (!isDirectory(skillDir)) { - return Single.error( - new SkillSourceException("Skill directory not found: " + skillName, SKILL_NOT_FOUND)); + return Single.error(new SkillSourceException("Skill directory not found: " + skillName)); } return Maybe.fromOptional(findSkillMd(skillDir)) .switchIfEmpty( - Single.error( - new SkillSourceException("SKILL.md not found in " + skillName, SKILL_NOT_FOUND))); + Single.error(new SkillSourceException("SKILL.md not found in " + skillName))); } @Override @@ -128,4 +132,23 @@ private Optional findSkillMd(Path dir) { .or(() -> Optional.of(dir.resolve("skill.md"))) .filter(Files::exists); } + + /** + * Validates that {@code component} does not escape {@code base} when resolved against it. + * + * @throws SkillSourceException if the resolved path would be outside {@code base} + */ + private static void validatePathWithinBase(Path base, String component) + throws SkillSourceException { + if (Path.of(component).isAbsolute()) { + throw new SkillSourceException("Absolute paths are not allowed: " + component); + } + Path normalizedBase = base.normalize().toAbsolutePath(); + Path resolved = base.resolve(component).normalize().toAbsolutePath(); + if (!resolved.startsWith(normalizedBase)) { + throw new SkillSourceException( + "Path traversal detected; component must remain within its parent directory: " + + component); + } + } } diff --git a/core/src/main/java/com/google/adk/skills/SkillSourceException.java b/core/src/main/java/com/google/adk/skills/SkillSourceException.java index 273428897..be23291da 100644 --- a/core/src/main/java/com/google/adk/skills/SkillSourceException.java +++ b/core/src/main/java/com/google/adk/skills/SkillSourceException.java @@ -22,43 +22,11 @@ */ public final class SkillSourceException extends Exception { - public static final String SKILL_LOAD_ERROR = "SKILL_LOAD_ERROR"; - public static final String SKILL_NOT_FOUND = "SKILL_NOT_FOUND"; - public static final String SKILL_FORMAT_ERROR = "SKILL_FORMAT_ERROR"; - public static final String RESOURCE_LOAD_ERROR = "RESOURCE_LOAD_ERROR"; - public static final String RESOURCE_NOT_FOUND = "RESOURCE_NOT_FOUND"; - - private final String errorCode; - - /** - * Constructs a new exception with the specified detail message and error code. - * - * @param message The detail message. - * @param errorCode The specific error code categorizing the failure. - */ - public SkillSourceException(String message, String errorCode) { + public SkillSourceException(String message) { super(message); - this.errorCode = errorCode; } - /** - * Constructs a new exception with the specified detail message, error code, and cause. - * - * @param message The detail message. - * @param errorCode The specific error code categorizing the failure. - * @param cause The cause. - */ - public SkillSourceException(String message, String errorCode, Throwable cause) { + public SkillSourceException(String message, Throwable cause) { super(message, cause); - this.errorCode = errorCode; - } - - /** - * Returns the error code categorizing the failure. - * - * @return The error code string. - */ - public String getErrorCode() { - return errorCode; } } diff --git a/core/src/main/java/com/google/adk/telemetry/Instrumentation.java b/core/src/main/java/com/google/adk/telemetry/Instrumentation.java index fd27878c9..a2c62ba12 100644 --- a/core/src/main/java/com/google/adk/telemetry/Instrumentation.java +++ b/core/src/main/java/com/google/adk/telemetry/Instrumentation.java @@ -125,12 +125,8 @@ public static final class AgentInvocation extends ClosableTelemetryScope { private final InvocationContext ctx; private final List events = Collections.synchronizedList(new ArrayList<>()); - public AgentInvocation(InvocationContext ctx, BaseAgent agent, Context parentContext) { - super( - Tracing.getTracer() - .spanBuilder("invoke_agent " + agent.name()) - .setParent(parentContext) - .startSpan()); + public AgentInvocation(InvocationContext ctx, BaseAgent agent) { + super(Tracing.getTracer().spanBuilder("invoke_agent " + agent.name()).startSpan()); this.agent = agent; this.ctx = ctx; Tracing.traceAgentInvocation(span, agent.name(), agent.description(), ctx); @@ -164,13 +160,8 @@ public static final class ToolExecution extends ClosableTelemetryScope { private final BaseAgent agent; private final Map functionArgs; - public ToolExecution( - BaseTool tool, BaseAgent agent, Map functionArgs, Context parentContext) { - super( - Tracing.getTracer() - .spanBuilder("execute_tool " + tool.name()) - .setParent(parentContext) - .startSpan()); + public ToolExecution(BaseTool tool, BaseAgent agent, Map functionArgs) { + super(Tracing.getTracer().spanBuilder("execute_tool " + tool.name()).startSpan()); this.tool = tool; this.agent = agent; this.functionArgs = functionArgs; @@ -205,22 +196,12 @@ protected void handleMetricsError(RuntimeException e) { /** Creates an AgentInvocation context to record agent invocation telemetry. */ public static AgentInvocation recordAgentInvocation(InvocationContext ctx, BaseAgent agent) { - return recordAgentInvocation(ctx, agent, Context.current()); - } - - public static AgentInvocation recordAgentInvocation( - InvocationContext ctx, BaseAgent agent, Context parentContext) { - return new AgentInvocation(ctx, agent, parentContext); + return new AgentInvocation(ctx, agent); } /** Creates a ToolExecution context to record tool execution telemetry. */ public static ToolExecution recordToolExecution( BaseTool tool, BaseAgent agent, Map functionArgs) { - return recordToolExecution(tool, agent, functionArgs, Context.current()); - } - - public static ToolExecution recordToolExecution( - BaseTool tool, BaseAgent agent, Map functionArgs, Context parentContext) { - return new ToolExecution(tool, agent, functionArgs, parentContext); + return new ToolExecution(tool, agent, functionArgs); } } diff --git a/core/src/main/java/com/google/adk/tools/BaseToolset.java b/core/src/main/java/com/google/adk/tools/BaseToolset.java index 84a5d8fc2..76369e5b9 100644 --- a/core/src/main/java/com/google/adk/tools/BaseToolset.java +++ b/core/src/main/java/com/google/adk/tools/BaseToolset.java @@ -17,8 +17,6 @@ package com.google.adk.tools; import com.google.adk.agents.ReadonlyContext; -import com.google.adk.models.LlmRequest; -import io.reactivex.rxjava3.core.Completable; import io.reactivex.rxjava3.core.Flowable; import java.util.List; import org.jspecify.annotations.Nullable; @@ -26,17 +24,11 @@ /** Base interface for toolsets. */ public interface BaseToolset extends AutoCloseable { - /** Processes the outgoing {@link LlmRequest.Builder}. */ - default Completable processLlmRequest( - LlmRequest.Builder llmRequestBuilder, ToolContext toolContext) { - return Completable.complete(); - } - /** * Return all tools in the toolset based on the provided context. * * @param readonlyContext Context used to filter tools available to the agent. - * @return A Flowable emitting tools available under the specified context. + * @return A Single emitting a list of tools available under the specified context. */ Flowable getTools(ReadonlyContext readonlyContext); diff --git a/core/src/main/java/com/google/adk/tools/mcp/DefaultMcpTransportBuilder.java b/core/src/main/java/com/google/adk/tools/mcp/DefaultMcpTransportBuilder.java index 6321da813..84c882c4e 100644 --- a/core/src/main/java/com/google/adk/tools/mcp/DefaultMcpTransportBuilder.java +++ b/core/src/main/java/com/google/adk/tools/mcp/DefaultMcpTransportBuilder.java @@ -1,7 +1,5 @@ package com.google.adk.tools.mcp; -import static com.google.common.base.Strings.isNullOrEmpty; - import com.google.common.collect.ImmutableMap; import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; import io.modelcontextprotocol.client.transport.HttpClientStreamableHttpTransport; @@ -10,8 +8,6 @@ import io.modelcontextprotocol.json.McpJsonDefaults; import io.modelcontextprotocol.json.McpJsonMapper; import io.modelcontextprotocol.spec.McpClientTransport; -import java.net.URI; -import java.net.URISyntaxException; import java.util.Collection; import java.util.Optional; import reactor.core.publisher.Mono; @@ -48,23 +44,15 @@ public McpClientTransport build(Object connectionParams) { .orElse("")))) .build(); } else if (connectionParams instanceof StreamableHttpServerParameters streamableParams) { - // Split the URL so the transport's URI.resolve does not drop a custom path (b/513186321). - SplitUri split = splitBaseAndEndpoint(streamableParams.url()); - HttpClientStreamableHttpTransport.Builder builder = - HttpClientStreamableHttpTransport.builder(split.baseUri()) - .connectTimeout(streamableParams.timeout()) - .jsonMapper(jsonMapper) - .asyncHttpRequestCustomizer( - (requestBuilder, method, uri, body, context) -> { - streamableParams - .headers() - .forEach((key, value) -> requestBuilder.header(key, value)); - return Mono.just(requestBuilder); - }); - if (split.endpoint() != null) { - builder.endpoint(split.endpoint()); - } - return builder.build(); + return HttpClientStreamableHttpTransport.builder(streamableParams.url()) + .connectTimeout(streamableParams.timeout()) + .jsonMapper(jsonMapper) + .asyncHttpRequestCustomizer( + (builder, method, uri, body, context) -> { + streamableParams.headers().forEach((key, value) -> builder.header(key, value)); + return Mono.just(builder); + }) + .build(); } else { throw new IllegalArgumentException( "DefaultMcpTransportBuilder supports only ServerParameters, SseServerParameters, or" @@ -72,36 +60,4 @@ public McpClientTransport build(Object connectionParams) { + connectionParams.getClass().getName()); } } - - /** - * Splits the URL into a base URI (scheme + authority) and endpoint (path + query + fragment). - * Returns a null endpoint when the URL has no meaningful path or cannot be split, so the - * transport falls back to its default endpoint. - */ - private static SplitUri splitBaseAndEndpoint(String url) { - URI uri; - try { - uri = new URI(url); - } catch (URISyntaxException e) { - return new SplitUri(url, null); - } - if (uri.getScheme() == null || uri.getAuthority() == null) { - return new SplitUri(url, null); - } - String path = uri.getRawPath(); - if (isNullOrEmpty(path) || path.equals("/")) { - return new SplitUri(url, null); - } - String baseUri = uri.getScheme() + "://" + uri.getAuthority(); - StringBuilder endpoint = new StringBuilder(path); - if (uri.getRawQuery() != null) { - endpoint.append('?').append(uri.getRawQuery()); - } - if (uri.getRawFragment() != null) { - endpoint.append('#').append(uri.getRawFragment()); - } - return new SplitUri(baseUri, endpoint.toString()); - } - - private record SplitUri(String baseUri, String endpoint) {} } diff --git a/core/src/main/java/com/google/adk/tools/skills/ListSkillsTool.java b/core/src/main/java/com/google/adk/tools/skills/ListSkillsTool.java deleted file mode 100644 index bc669632a..000000000 --- a/core/src/main/java/com/google/adk/tools/skills/ListSkillsTool.java +++ /dev/null @@ -1,59 +0,0 @@ -/* - * Copyright 2026 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.google.adk.tools.skills; - -import com.google.adk.skills.SkillSource; -import com.google.adk.tools.BaseTool; -import com.google.adk.tools.ToolContext; -import com.google.common.collect.ImmutableMap; -import com.google.genai.types.FunctionDeclaration; -import com.google.genai.types.Schema; -import com.google.genai.types.Type; -import io.reactivex.rxjava3.core.Single; -import java.util.Map; -import java.util.Optional; - -/** Tool to list all available skills. */ -final class ListSkillsTool extends BaseTool { - private final SkillSource skillSource; - - ListSkillsTool(SkillSource skillSource) { - super("list_skills", "Lists all available skills with their names and descriptions."); - this.skillSource = skillSource; - } - - @Override - public Optional declaration() { - return Optional.of( - FunctionDeclaration.builder() - .name(name()) - .description(description()) - .parameters( - Schema.builder().type(Type.Known.OBJECT).properties(ImmutableMap.of()).build()) - .build()); - } - - @Override - public Single> runAsync(Map args, ToolContext toolContext) { - return skillSource - .listFrontmatters() - .map(ImmutableMap::values) - .map(SkillToolset::getSkillsPrompt) - .>map(skills -> ImmutableMap.of("skills_xml", skills)) - .onErrorResumeNext(SkillToolset::createErrorResponse); - } -} diff --git a/core/src/main/java/com/google/adk/tools/skills/LoadSkillResourceTool.java b/core/src/main/java/com/google/adk/tools/skills/LoadSkillResourceTool.java deleted file mode 100644 index b3d0858e0..000000000 --- a/core/src/main/java/com/google/adk/tools/skills/LoadSkillResourceTool.java +++ /dev/null @@ -1,239 +0,0 @@ -/* - * Copyright 2026 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.google.adk.tools.skills; - -import static com.google.adk.tools.skills.SkillToolset.createErrorResponse; -import static com.google.common.collect.ImmutableList.toImmutableList; -import static java.net.URLConnection.guessContentTypeFromName; -import static java.net.URLConnection.guessContentTypeFromStream; -import static java.nio.charset.StandardCharsets.UTF_8; - -import com.google.adk.models.LlmRequest; -import com.google.adk.skills.SkillSource; -import com.google.adk.tools.BaseTool; -import com.google.adk.tools.ToolContext; -import com.google.common.base.Strings; -import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableMap; -import com.google.common.collect.ImmutableSet; -import com.google.common.collect.Iterables; -import com.google.common.io.ByteSource; -import com.google.genai.types.Content; -import com.google.genai.types.FunctionDeclaration; -import com.google.genai.types.FunctionResponse; -import com.google.genai.types.Part; -import com.google.genai.types.Schema; -import com.google.genai.types.Type; -import io.reactivex.rxjava3.core.Completable; -import io.reactivex.rxjava3.core.Single; -import java.io.ByteArrayInputStream; -import java.io.IOException; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Optional; -import java.util.stream.Stream; - -/** Tool to load resources (references, assets, or scripts) from a skill. */ -final class LoadSkillResourceTool extends BaseTool { - - private static final ImmutableSet EXTRA_TEXT_MIME_TYPES = - ImmutableSet.of( - // go/keep-sorted start - "application/json", - "application/x-python", - "application/x-sh", - "application/x-shar", - "application/x-shellscript", - "application/xml", - "application/yaml" - // go/keep-sorted end - ); - private static final String BINARY_FILE_DETECTED_MSG = - "Binary file detected. The content has been included in the next part of the function" - + " response for you to analyze."; - private static final String SKILL_NAME = "skill_name"; - private static final String FILE_PATH = "file_path"; - private static final String CONTENT = "content"; - private static final String MIME_TYPE = "mime_type"; - - private final SkillSource skillSource; - - LoadSkillResourceTool(SkillSource skillSource) { - super( - "load_skill_resource", - "Loads a resource file (from references/, assets/, or scripts/) from within a skill."); - this.skillSource = skillSource; - } - - @Override - public Optional declaration() { - return Optional.of( - FunctionDeclaration.builder() - .name(name()) - .description(description()) - .parameters( - Schema.builder() - .type(Type.Known.OBJECT) - .properties( - ImmutableMap.of( - SKILL_NAME, - Schema.builder() - .type(Type.Known.STRING) - .description("The name of the skill.") - .build(), - FILE_PATH, - Schema.builder() - .type(Type.Known.STRING) - .description( - "The relative path to the resource (e.g.," - + " 'references/my_doc.md', 'assets/template.txt'," - + " or 'scripts/setup.sh').") - .build())) - .required(ImmutableList.of(SKILL_NAME, FILE_PATH)) - .build()) - .build()); - } - - @Override - public Single> runAsync(Map args, ToolContext toolContext) { - String skillName = (String) args.get(SKILL_NAME); - String resourcePath = (String) args.get(FILE_PATH); - - if (Strings.isNullOrEmpty(skillName)) { - return createErrorResponse("Skill name is required.", "MISSING_SKILL_NAME"); - } - if (Strings.isNullOrEmpty(resourcePath)) { - return createErrorResponse("Resource path is required.", "MISSING_RESOURCE_PATH"); - } - if (!resourcePath.startsWith("references/") - && !resourcePath.startsWith("assets/") - && !resourcePath.startsWith("scripts/")) { - return createErrorResponse( - "Path must start with 'references/', 'assets/', or 'scripts/'.", "INVALID_RESOURCE_PATH"); - } - - return skillSource - .loadResource(skillName, resourcePath) - .>map( - contentSource -> createResult(skillName, resourcePath, contentSource)) - .onErrorResumeNext(SkillToolset::createErrorResponse); - } - - private boolean hasBinaryContentResponse(FunctionResponse functionResponse) { - return functionResponse - .response() - .filter( - resp -> - resp.containsKey(SKILL_NAME) - && resp.containsKey(MIME_TYPE) - && resp.get(CONTENT) instanceof byte[]) - .isPresent(); - } - - @Override - public Completable processLlmRequest( - LlmRequest.Builder llmRequestBuilder, ToolContext toolContext) { - return super.processLlmRequest(llmRequestBuilder, toolContext) - .andThen( - Completable.fromRunnable( - () -> { - List contents = new ArrayList<>(llmRequestBuilder.build().contents()); - if (contents.isEmpty()) { - return; - } - - Content lastContent = Iterables.getLast(contents); - List parts = lastContent.parts().orElse(ImmutableList.of()); - - // Extract raw binary content into a dedicated binary Part - ImmutableList updatedParts = - parts.stream().flatMap(this::processPart).collect(toImmutableList()); - - if (!updatedParts.isEmpty()) { - contents.set( - contents.size() - 1, lastContent.toBuilder().parts(updatedParts).build()); - llmRequestBuilder.contents(contents); - } - })); - } - - /** - * Processes a {@link Part} to extract raw binary content from a function response. - * - *

If the part is a function response from this tool containing binary data, it returns a - * stream containing the updated function response part (with a placeholder message) and a new - * part containing the raw binary data. Otherwise, it returns an empty stream. - * - * @param part the {@link Part} to process - * @return a stream containing the processed parts, or an empty stream if the part does not - * contain a binary function response from this tool - */ - private Stream processPart(Part part) { - return part - .functionResponse() - .filter(funcResp -> funcResp.name().orElse("").equals(name())) - .filter(this::hasBinaryContentResponse) - .stream() - .flatMap( - funcResp -> - funcResp.response().stream() - .flatMap( - response -> { - Map newResponse = new HashMap<>(response); - - String mimeType = newResponse.remove(MIME_TYPE).toString(); - byte[] binaryContent = - (byte[]) newResponse.replace(CONTENT, BINARY_FILE_DETECTED_MSG); - - Part updatedPart = - part.toBuilder() - .functionResponse(funcResp.toBuilder().response(newResponse)) - .build(); - Part binaryPart = Part.fromBytes(binaryContent, mimeType); - - return Stream.of(updatedPart, binaryPart); - })); - } - - private ImmutableMap createResult( - String skillName, String resourcePath, ByteSource contentSource) throws IOException { - byte[] bytes = contentSource.read(); - // Special handling of shell script as the guessContentTypeFromName would return - // application/x-shar - String contentType = - resourcePath.endsWith(".sh") || resourcePath.endsWith(".bash") - ? "application/x-sh" - : guessContentTypeFromName(resourcePath); - if (contentType == null) { - contentType = guessContentTypeFromStream(new ByteArrayInputStream(bytes)); - } - if (contentType == null) { - contentType = "application/octet-stream"; - } - ImmutableMap.Builder builder = ImmutableMap.builder(); - builder.put(SKILL_NAME, skillName).put(FILE_PATH, resourcePath).put(MIME_TYPE, contentType); - - if (contentType.startsWith("text/") || EXTRA_TEXT_MIME_TYPES.contains(contentType)) { - builder.put(CONTENT, new String(bytes, UTF_8)); - } else { - builder.put(CONTENT, bytes); - } - return builder.buildOrThrow(); - } -} diff --git a/core/src/main/java/com/google/adk/tools/skills/LoadSkillTool.java b/core/src/main/java/com/google/adk/tools/skills/LoadSkillTool.java deleted file mode 100644 index eaad3773d..000000000 --- a/core/src/main/java/com/google/adk/tools/skills/LoadSkillTool.java +++ /dev/null @@ -1,97 +0,0 @@ -/* - * Copyright 2026 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.google.adk.tools.skills; - -import static com.google.adk.tools.skills.SkillToolset.createErrorResponse; - -import com.google.adk.skills.Frontmatter; -import com.google.adk.skills.SkillSource; -import com.google.adk.tools.BaseTool; -import com.google.adk.tools.ToolContext; -import com.google.common.base.Strings; -import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableMap; -import com.google.genai.types.FunctionDeclaration; -import com.google.genai.types.Schema; -import com.google.genai.types.Type; -import io.reactivex.rxjava3.core.Single; -import java.util.Map; -import java.util.Optional; - -/** Tool to load a skill's instructions. */ -final class LoadSkillTool extends BaseTool { - - private static final String SKILL_NAME = "skill_name"; - private final SkillSource skillSource; - - LoadSkillTool(SkillSource skillSource) { - super("load_skill", "Loads the SKILL.md instructions for a given skill."); - this.skillSource = skillSource; - } - - @Override - public Optional declaration() { - return Optional.of( - FunctionDeclaration.builder() - .name(name()) - .description(description()) - .parameters( - Schema.builder() - .type(Type.Known.OBJECT) - .properties( - ImmutableMap.of( - SKILL_NAME, - Schema.builder() - .type(Type.Known.STRING) - .description("The name of the skill to load.") - .build())) - .required(ImmutableList.of(SKILL_NAME)) - .build()) - .build()); - } - - @Override - public Single> runAsync(Map args, ToolContext toolContext) { - String skillName = (String) args.get(SKILL_NAME); - if (Strings.isNullOrEmpty(skillName)) { - return createErrorResponse("Skill name is required.", "MISSING_SKILL_NAME"); - } - - return skillSource - .loadFrontmatter(skillName) - .>zipWith( - skillSource.loadInstructions(skillName), - (frontmatter, instructions) -> - ImmutableMap.of( - "skill_name", - skillName, - "frontmatter", - frontmatterToMap(frontmatter), - "instructions", - instructions)) - .onErrorResumeNext(SkillToolset::createErrorResponse); - } - - private static ImmutableMap frontmatterToMap(Frontmatter fm) { - ImmutableMap.Builder builder = ImmutableMap.builder(); - builder.put("name", fm.name()).put("description", fm.description()); - fm.license().ifPresent(l -> builder.put("license", l)); - fm.compatibility().ifPresent(c -> builder.put("compatibility", c)); - fm.allowedTools().ifPresent(a -> builder.put("allowed-tools", a)); - return builder.put("metadata", fm.metadata()).buildOrThrow(); - } -} diff --git a/core/src/main/java/com/google/adk/tools/skills/SkillToolset.java b/core/src/main/java/com/google/adk/tools/skills/SkillToolset.java deleted file mode 100644 index d159bd39b..000000000 --- a/core/src/main/java/com/google/adk/tools/skills/SkillToolset.java +++ /dev/null @@ -1,130 +0,0 @@ -/* - * Copyright 2026 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.google.adk.tools.skills; - -import static java.util.Optional.ofNullable; - -import com.google.adk.agents.ReadonlyContext; -import com.google.adk.models.LlmRequest; -import com.google.adk.skills.Frontmatter; -import com.google.adk.skills.SkillSource; -import com.google.adk.skills.SkillSourceException; -import com.google.adk.tools.BaseTool; -import com.google.adk.tools.BaseToolset; -import com.google.adk.tools.ToolContext; -import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableMap; -import io.reactivex.rxjava3.core.Completable; -import io.reactivex.rxjava3.core.Flowable; -import io.reactivex.rxjava3.core.Single; -import java.util.Collection; -import java.util.Map; -import java.util.StringJoiner; - -/** - * A toolset for managing and interacting with agent skills. Provides tools to list, load, and run - * skills. - */ -public class SkillToolset implements BaseToolset { - - private static final String DEFAULT_SKILL_SYSTEM_INSTRUCTION = - """ - You can use specialized 'skills' to help you with complex tasks. You MUST use the skill tools to interact with these skills. - - Skills are folders of instructions and resources that extend your capabilities for specialized tasks. Each skill folder contains: - - **SKILL.md** (required): The main instruction file with skill metadata and detailed markdown instructions. - - **references/** (Optional): Additional documentation or examples for skill usage. - - **assets/** (Optional): Templates, scripts or other resources used by the skill. - - **scripts/** (Optional): Executable scripts that can be run via bash. - - This is very important: - - 1. If a skill seems relevant to the current user query, you MUST use the `load_skill` tool with `skill_name=""` to read its full instructions before proceeding. - 2. Once you have read the instructions, follow them exactly as documented before replying to the user. For example, If the instruction lists multiple steps, please make sure you complete all of them in order. - 3. The `load_skill_resource` tool is for viewing files within a skill's directory (e.g., `references/*`, `assets/*`, `scripts/*`). Do NOT use other tools to access these files. - 4. Use `run_skill_script` to run scripts from a skill's `scripts/` directory. Use `load_skill_resource` to view script content first if needed. - """; - - private final SkillSource skillSource; - private final ImmutableList coreTools; - private final String systemInstruction; - - /** Initializes the SkillToolset with a SkillSource and default execution settings. */ - public SkillToolset(SkillSource skillSource) { - this(skillSource, DEFAULT_SKILL_SYSTEM_INSTRUCTION); - } - - /** Initializes the SkillToolset with a SkillSource. */ - public SkillToolset(SkillSource skillSource, String systemInstruction) { - this.skillSource = skillSource; - this.systemInstruction = systemInstruction; - this.coreTools = - ImmutableList.of( - new ListSkillsTool(skillSource), - new LoadSkillTool(skillSource), - new LoadSkillResourceTool(skillSource)); - } - - @Override - public Flowable getTools(ReadonlyContext readonlyContext) { - return Flowable.fromIterable(coreTools); - } - - @Override - public Completable processLlmRequest( - LlmRequest.Builder llmRequestBuilder, ToolContext toolContext) { - return skillSource - .listFrontmatters() - .map(ImmutableMap::values) - .map(SkillToolset::getSkillsPrompt) - .map( - skills -> - llmRequestBuilder.appendInstructions(ImmutableList.of(systemInstruction, skills))) - .ignoreElement(); - } - - @Override - public void close() throws Exception { - // No resources to release for now - } - - static Single> createErrorResponse(String errorMessage, String errorCode) { - return Single.just(ImmutableMap.of("error", errorMessage, "error_code", errorCode)); - } - - static Single> createErrorResponse(Throwable t) { - if (t instanceof SkillSourceException ex) { - return Single.just( - ImmutableMap.of( - "error", - ofNullable(ex.getMessage()).orElse(ex.toString()), - "error_code", - ex.getErrorCode())); - } - return Single.error(t); - } - - static String getSkillsPrompt(Collection frontmatters) { - return frontmatters.stream() - .map(Frontmatter::toXml) - .reduce( - new StringJoiner("\n", "", "").setEmptyValue(""), - StringJoiner::add, - StringJoiner::merge) - .toString(); - } -} diff --git a/core/src/test/java/com/google/adk/agents/BaseAgentTest.java b/core/src/test/java/com/google/adk/agents/BaseAgentTest.java index a3436e6cb..5e2fa5792 100644 --- a/core/src/test/java/com/google/adk/agents/BaseAgentTest.java +++ b/core/src/test/java/com/google/adk/agents/BaseAgentTest.java @@ -22,42 +22,26 @@ import com.google.adk.agents.Callbacks.AfterAgentCallback; import com.google.adk.agents.Callbacks.BeforeAgentCallback; import com.google.adk.events.Event; -import com.google.adk.telemetry.Metrics; import com.google.adk.testing.TestBaseAgent; import com.google.adk.testing.TestCallback; import com.google.adk.testing.TestUtils; import com.google.common.collect.ImmutableList; import com.google.genai.types.Content; import com.google.genai.types.Part; -import io.opentelemetry.api.GlobalOpenTelemetry; -import io.opentelemetry.api.common.AttributeKey; -import io.opentelemetry.api.metrics.Meter; -import io.opentelemetry.sdk.OpenTelemetrySdk; -import io.opentelemetry.sdk.metrics.SdkMeterProvider; -import io.opentelemetry.sdk.metrics.data.HistogramPointData; -import io.opentelemetry.sdk.metrics.data.MetricData; -import io.opentelemetry.sdk.testing.exporter.InMemoryMetricReader; -import io.opentelemetry.sdk.testing.time.TestClock; -import io.opentelemetry.sdk.trace.SdkTracerProvider; import io.reactivex.rxjava3.core.Completable; import io.reactivex.rxjava3.core.Maybe; import java.util.List; import java.util.concurrent.atomic.AtomicBoolean; -import org.junit.After; -import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @RunWith(JUnit4.class) public final class BaseAgentTest { + private static final String TEST_AGENT_NAME = "testAgent"; private static final String TEST_AGENT_DESCRIPTION = "A test agent"; - private InMemoryMetricReader inMemoryMetricReader; - private TestClock testClock; - private Meter originalMeter; - private static class ClosableTestAgent extends TestBaseAgent { final AtomicBoolean closed = new AtomicBoolean(false); @@ -72,35 +56,6 @@ public Completable close() { } } - @Before - public void setUp() { - GlobalOpenTelemetry.resetForTest(); - testClock = TestClock.create(); - inMemoryMetricReader = InMemoryMetricReader.create(); - SdkMeterProvider sdkMeterProvider = - SdkMeterProvider.builder() - .registerMetricReader(inMemoryMetricReader) - .setClock(testClock) - .build(); - - OpenTelemetrySdk openTelemetrySdk = - OpenTelemetrySdk.builder() - .setTracerProvider(SdkTracerProvider.builder().build()) - .setMeterProvider(sdkMeterProvider) - .build(); - - GlobalOpenTelemetry.set(openTelemetrySdk); - originalMeter = GlobalOpenTelemetry.getMeter("gcp.vertex.agent"); - Metrics.setMeterForTesting(openTelemetrySdk.getMeter("gcp.vertex.agent")); - } - - @After - public void tearDown() { - if (originalMeter != null) { - Metrics.setMeterForTesting(originalMeter); - } - } - @Test public void constructor_setsNameAndDescription() { String name = "testName"; @@ -218,36 +173,6 @@ public void runAsync_noCallbacks_invokesRunAsyncImpl() { assertThat(results).hasSize(1); assertThat(results.get(0).content()).hasValue(runAsyncImplContent); assertThat(runAsyncImpl.wasCalled()).isTrue(); - MetricData durationMetric = findMetricByName("gen_ai.agent.invocation.duration"); - assertThat(durationMetric.getUnit()).isEqualTo("ms"); - HistogramPointData durationPoint = - durationMetric.getHistogramData().getPoints().iterator().next(); - assertThat(durationPoint.getAttributes().get(AttributeKey.stringKey("gen_ai.agent.name"))) - .isEqualTo("testAgent"); - - MetricData reqSizeMetric = findMetricByName("gen_ai.agent.request.size"); - assertThat(reqSizeMetric.getUnit()).isEqualTo("By"); - HistogramPointData reqSizePoint = - reqSizeMetric.getHistogramData().getPoints().iterator().next(); - assertThat(reqSizePoint.getSum()).isEqualTo(12.0); - assertThat(reqSizePoint.getAttributes().get(AttributeKey.stringKey("gen_ai.agent.name"))) - .isEqualTo("testAgent"); - - MetricData respSizeMetric = findMetricByName("gen_ai.agent.response.size"); - assertThat(respSizeMetric.getUnit()).isEqualTo("By"); - HistogramPointData respSizePoint = - respSizeMetric.getHistogramData().getPoints().iterator().next(); - assertThat(respSizePoint.getSum()).isEqualTo(11.0); - assertThat(respSizePoint.getAttributes().get(AttributeKey.stringKey("gen_ai.agent.name"))) - .isEqualTo("testAgent"); - - MetricData workflowStepsMetric = findMetricByName("gen_ai.agent.workflow.steps"); - assertThat(workflowStepsMetric.getUnit()).isEqualTo("1"); - HistogramPointData workflowStepsPoint = - workflowStepsMetric.getHistogramData().getPoints().iterator().next(); - assertThat(workflowStepsPoint.getSum()).isEqualTo(1.0); - assertThat(workflowStepsPoint.getAttributes().get(AttributeKey.stringKey("gen_ai.agent.name"))) - .isEqualTo("testAgent"); } @Test @@ -702,11 +627,4 @@ public void close_twoLevelsSubAgents_closesAllSubAgents() { assertThat(subAgent.closed.get()).isTrue(); assertThat(subSubAgent.closed.get()).isTrue(); } - - private MetricData findMetricByName(String name) { - return inMemoryMetricReader.collectAllMetrics().stream() - .filter(m -> m.getName().equals(name)) - .findFirst() - .orElseThrow(() -> new AssertionError("Metric not found: " + name)); - } } diff --git a/core/src/test/java/com/google/adk/agents/LlmAgentTest.java b/core/src/test/java/com/google/adk/agents/LlmAgentTest.java index 26843bb56..e40a83aa0 100644 --- a/core/src/test/java/com/google/adk/agents/LlmAgentTest.java +++ b/core/src/test/java/com/google/adk/agents/LlmAgentTest.java @@ -440,7 +440,7 @@ public void close_closesToolsets() throws Exception { } @Test - public void close_closesToolsetsOnException() { + public void close_closesToolsetsOnException() throws Exception { ClosableToolset toolset1 = new ClosableToolset() { @Override @@ -494,7 +494,7 @@ public void runAsync_withTools_createsToolSpans() throws InterruptedException { List spans = openTelemetryRule.getSpans(); SpanData agentSpan = findSpanByName(spans, "invoke_agent test agent"); List llmSpans = findSpansByName(spans, "call_llm"); - List toolSpans = findSpansByName(spans, "execute_tool echo_tool"); + List toolSpans = findSpansByName(spans, "execute_tool [echo_tool]"); assertThat(llmSpans).hasSize(2); assertThat(toolSpans).hasSize(1); diff --git a/core/src/test/java/com/google/adk/flows/llmflows/BaseLlmFlowTest.java b/core/src/test/java/com/google/adk/flows/llmflows/BaseLlmFlowTest.java index a58a206d9..2a06c1f0a 100644 --- a/core/src/test/java/com/google/adk/flows/llmflows/BaseLlmFlowTest.java +++ b/core/src/test/java/com/google/adk/flows/llmflows/BaseLlmFlowTest.java @@ -25,12 +25,9 @@ import static com.google.adk.testing.TestUtils.createTestLlm; import static com.google.common.collect.Iterables.getOnlyElement; import static com.google.common.truth.Truth.assertThat; -import static org.junit.Assert.assertThrows; import com.google.adk.agents.Callbacks; import com.google.adk.agents.InvocationContext; -import com.google.adk.agents.LlmAgent; -import com.google.adk.agents.ReadonlyContext; import com.google.adk.events.Event; import com.google.adk.flows.llmflows.RequestProcessor.RequestProcessingResult; import com.google.adk.flows.llmflows.ResponseProcessor.ResponseProcessingResult; @@ -38,7 +35,6 @@ import com.google.adk.models.LlmResponse; import com.google.adk.testing.TestLlm; import com.google.adk.tools.BaseTool; -import com.google.adk.tools.BaseToolset; import com.google.adk.tools.ToolContext; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; @@ -51,7 +47,6 @@ import io.opentelemetry.context.Context; import io.opentelemetry.context.ContextKey; import io.opentelemetry.context.Scope; -import io.reactivex.rxjava3.core.Completable; import io.reactivex.rxjava3.core.Flowable; import io.reactivex.rxjava3.core.Maybe; import io.reactivex.rxjava3.core.Single; @@ -767,112 +762,4 @@ public void postprocess_noResponseProcessors_onlyUsageMetadata_returnsEvent() { assertThat(event.author()).isEqualTo(invocationContext.agent().name()); assertThat(event.invocationId()).isEqualTo(invocationContext.invocationId()); } - - @Test - public void getRequestProcessorFromTools_sequentiallyAppliesToolProcessors() { - BaseTool tool1 = - new BaseTool("tool1", "test tool 1") { - @Override - public Completable processLlmRequest( - LlmRequest.Builder builder, ToolContext toolContext) { - return Completable.fromAction( - () -> builder.appendInstructions(ImmutableList.of("instruction1"))); - } - }; - BaseTool tool2 = - new BaseTool("tool2", "test tool 2") { - @Override - public Completable processLlmRequest( - LlmRequest.Builder builder, ToolContext toolContext) { - return Completable.fromAction( - () -> builder.appendInstructions(ImmutableList.of("instruction2"))); - } - }; - - LlmAgent agent = - createTestAgentBuilder(createTestLlm(LlmResponse.builder().build())) - .tools(tool1, tool2) - .build(); - - InvocationContext invocationContext = createInvocationContext(agent); - BaseLlmFlow baseLlmFlow = createBaseLlmFlowWithoutProcessors(); - RequestProcessor requestProcessor = baseLlmFlow.getRequestProcessorFromTools(agent); - - LlmRequest processedRequest = - requestProcessor - .processRequest(invocationContext, LlmRequest.builder().build()) - .map(RequestProcessingResult::updatedRequest) - .blockingGet(); - - assertThat(processedRequest.getSystemInstructions()) - .containsExactly("instruction1\n\ninstruction2"); - } - - @Test - public void getRequestProcessorFromTools_appliesToolsetAndItsToolsProcessors() { - BaseTool tool1 = - new BaseTool("tool1", "test tool 1") { - @Override - public Completable processLlmRequest( - LlmRequest.Builder builder, ToolContext toolContext) { - return Completable.fromAction( - () -> builder.appendInstructions(ImmutableList.of("tool-instruction"))); - } - }; - - BaseToolset toolset = - new BaseToolset() { - @Override - public Flowable getTools(ReadonlyContext readonlyContext) { - return Flowable.just(tool1); - } - - @Override - public Completable processLlmRequest( - LlmRequest.Builder builder, ToolContext toolContext) { - return Completable.fromAction( - () -> builder.appendInstructions(ImmutableList.of("toolset-instruction"))); - } - - @Override - public void close() {} - }; - - LlmAgent agent = - createTestAgentBuilder(createTestLlm(LlmResponse.builder().build())).tools(toolset).build(); - - InvocationContext invocationContext = createInvocationContext(agent); - BaseLlmFlow baseLlmFlow = createBaseLlmFlowWithoutProcessors(); - RequestProcessor requestProcessor = baseLlmFlow.getRequestProcessorFromTools(agent); - - LlmRequest processedRequest = - requestProcessor - .processRequest(invocationContext, LlmRequest.builder().build()) - .map(RequestProcessingResult::updatedRequest) - .blockingGet(); - - assertThat(processedRequest.getSystemInstructions()) - .containsExactly("toolset-instruction\n\ntool-instruction"); - } - - @Test - public void getRequestProcessorFromTools_throwsOnUnsupportedType() { - LlmAgent agent = - createTestAgentBuilder(createTestLlm(LlmResponse.builder().build())) - .tools("unsupported-tool-type-string") - .build(); - - InvocationContext invocationContext = createInvocationContext(agent); - BaseLlmFlow baseLlmFlow = createBaseLlmFlowWithoutProcessors(); - RequestProcessor requestProcessor = baseLlmFlow.getRequestProcessorFromTools(agent); - - LlmRequest request = LlmRequest.builder().build(); - IllegalArgumentException thrown = - assertThrows( - IllegalArgumentException.class, - () -> requestProcessor.processRequest(invocationContext, request)); - assertThat(thrown) - .hasMessageThat() - .contains("Object in tools list is not of a supported type: java.lang.String"); - } } diff --git a/core/src/test/java/com/google/adk/plugins/agentanalytics/BigQueryAgentAnalyticsPluginTest.java b/core/src/test/java/com/google/adk/plugins/agentanalytics/BigQueryAgentAnalyticsPluginTest.java index 6d65e2f15..836442cad 100644 --- a/core/src/test/java/com/google/adk/plugins/agentanalytics/BigQueryAgentAnalyticsPluginTest.java +++ b/core/src/test/java/com/google/adk/plugins/agentanalytics/BigQueryAgentAnalyticsPluginTest.java @@ -17,7 +17,6 @@ package com.google.adk.plugins.agentanalytics; import static com.google.common.collect.ImmutableList.toImmutableList; -import static java.nio.charset.StandardCharsets.UTF_8; import static java.util.concurrent.TimeUnit.MILLISECONDS; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; @@ -25,13 +24,11 @@ import static org.junit.Assert.assertTrue; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyMap; -import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.atLeastOnce; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; -import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.node.ArrayNode; import com.fasterxml.jackson.databind.node.ObjectNode; import com.google.adk.agents.BaseAgent; @@ -59,8 +56,6 @@ import com.google.cloud.bigquery.storage.v1.AppendRowsResponse; import com.google.cloud.bigquery.storage.v1.BigQueryWriteClient; import com.google.cloud.bigquery.storage.v1.StreamWriter; -import com.google.cloud.storage.BlobInfo; -import com.google.cloud.storage.Storage; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.genai.types.Candidate; @@ -807,7 +802,6 @@ public void logEvent_handlesExceptionFromFormatter() throws Exception { (content, eventType) -> { throw new RuntimeException("Formatter error"); }; - BigQueryLoggerConfig formattedConfig = config.toBuilder().contentFormatter(formatter).build(); PluginState formattedState = new PluginState(formattedConfig) { @@ -1048,133 +1042,6 @@ public void logEvent_createsUniqueProcessorPerInvocation() throws Exception { testExecutor.shutdown(); } - @Test - public void logEvent_offloadsToGcs_whenLargeContent() throws Exception { - GcsOffloader mockOffloader = mock(GcsOffloader.class); - when(mockOffloader.uploadContent(anyString(), anyString(), anyString())) - .thenReturn(CompletableFuture.completedFuture("gs://test-bucket/large.txt")); - - BigQueryLoggerConfig gcsConfig = config.toBuilder().gcsBucketName("test-bucket").build(); - PluginState gcsState = - new PluginState(gcsConfig) { - @Override - protected BigQueryWriteClient createWriteClient(BigQueryLoggerConfig config) { - return mockWriteClient; - } - - @Override - protected StreamWriter createWriter() { - return mockWriter; - } - - @Override - protected GcsOffloader getGcsOffloader(BigQueryLoggerConfig config) { - return mockOffloader; - } - }; - BigQueryAgentAnalyticsPlugin gcsPlugin = - new BigQueryAgentAnalyticsPlugin(gcsConfig, mockBigQuery, gcsState); - - // Large text (> 32KB default threshold) - String largeText = "a".repeat(40000); - Content content = Content.fromParts(Part.fromText(largeText)); - gcsPlugin.onUserMessageCallback(mockInvocationContext, content).blockingSubscribe(); - - verify(mockOffloader, atLeastOnce()).uploadContent(anyString(), anyString(), anyString()); - - Map row = gcsState.getBatchProcessor("invocation_id").queue.poll(); - assertNotNull(row); - @SuppressWarnings("unchecked") // Test only - List contentParts = (List) row.get("content_parts"); - assertEquals("GCS_REFERENCE", contentParts.get(0).get("storage_mode").asText()); - assertEquals("gs://test-bucket/large.txt", contentParts.get(0).get("uri").asText()); - } - - @Test - public void logEvent_offloadsToGcs_whenMultimodalContent() throws Exception { - GcsOffloader mockOffloader = mock(GcsOffloader.class); - when(mockOffloader.uploadContent(any(byte[].class), anyString(), anyString())) - .thenReturn(CompletableFuture.completedFuture("gs://test-bucket/image.png")); - - BigQueryLoggerConfig gcsConfig = config.toBuilder().gcsBucketName("test-bucket").build(); - PluginState gcsState = - new PluginState(gcsConfig) { - @Override - protected BigQueryWriteClient createWriteClient(BigQueryLoggerConfig config) { - return mockWriteClient; - } - - @Override - protected StreamWriter createWriter() { - return mockWriter; - } - - @Override - protected GcsOffloader getGcsOffloader(BigQueryLoggerConfig config) { - return mockOffloader; - } - }; - BigQueryAgentAnalyticsPlugin gcsPlugin = - new BigQueryAgentAnalyticsPlugin(gcsConfig, mockBigQuery, gcsState); - - Content content = Content.fromParts(Part.fromBytes("test-data".getBytes(UTF_8), "image/png")); - gcsPlugin.onUserMessageCallback(mockInvocationContext, content).blockingSubscribe(); - - verify(mockOffloader, atLeastOnce()).uploadContent(any(byte[].class), anyString(), anyString()); - - Map row = gcsState.getBatchProcessor("invocation_id").queue.poll(); - assertNotNull(row); - @SuppressWarnings("unchecked") // Test only - List contentParts = (List) row.get("content_parts"); - assertEquals("GCS_REFERENCE", contentParts.get(0).get("storage_mode").asText()); - assertEquals("gs://test-bucket/image.png", contentParts.get(0).get("uri").asText()); - } - - @Test - public void logEvent_integrationWithRealGcsOffloader_whenLargeContent() throws Exception { - Storage mockStorage = mock(Storage.class); - - BigQueryLoggerConfig gcsConfig = config.toBuilder().gcsBucketName("test-bucket").build(); - PluginState gcsState = - new PluginState(gcsConfig) { - @Override - protected BigQueryWriteClient createWriteClient(BigQueryLoggerConfig config) { - return mockWriteClient; - } - - @Override - protected StreamWriter createWriter() { - return mockWriter; - } - - @Override - protected GcsOffloader getGcsOffloader(BigQueryLoggerConfig config) { - return new GcsOffloader( - config.projectId(), - config.gcsBucketName(), - Runnable::run, // Use direct executor for synchronous execution - config.credentials(), - mockStorage); - } - }; - BigQueryAgentAnalyticsPlugin gcsPlugin = - new BigQueryAgentAnalyticsPlugin(gcsConfig, mockBigQuery, gcsState); - - // Large text (> 32KB default threshold) - String largeText = "a".repeat(40000); - Content content = Content.fromParts(Part.fromText(largeText)); - gcsPlugin.onUserMessageCallback(mockInvocationContext, content).blockingSubscribe(); - - verify(mockStorage, atLeastOnce()).create(any(BlobInfo.class), any(byte[].class)); - - Map row = gcsState.getBatchProcessor("invocation_id").queue.poll(); - assertNotNull(row); - @SuppressWarnings("unchecked") // Test only - List contentParts = (List) row.get("content_parts"); - assertEquals("GCS_REFERENCE", contentParts.get(0).get("storage_mode").asText()); - assertTrue(contentParts.get(0).get("uri").asText().startsWith("gs://test-bucket/")); - } - private static class FakeAgent extends BaseAgent { FakeAgent(String name) { super(name, "description", null, null, null); diff --git a/core/src/test/java/com/google/adk/plugins/agentanalytics/GcsOffloaderTest.java b/core/src/test/java/com/google/adk/plugins/agentanalytics/GcsOffloaderTest.java deleted file mode 100644 index 7b8de3813..000000000 --- a/core/src/test/java/com/google/adk/plugins/agentanalytics/GcsOffloaderTest.java +++ /dev/null @@ -1,122 +0,0 @@ -/* - * Copyright 2026 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.google.adk.plugins.agentanalytics; - -import static java.nio.charset.StandardCharsets.UTF_8; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertThrows; -import static org.junit.Assert.assertTrue; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.never; -import static org.mockito.Mockito.verify; - -import com.google.cloud.storage.BlobId; -import com.google.cloud.storage.BlobInfo; -import com.google.cloud.storage.Storage; -import java.util.concurrent.CompletableFuture; -import java.util.concurrent.ExecutionException; -import java.util.concurrent.Executor; -import java.util.concurrent.ExecutorService; -import java.util.concurrent.Executors; -import java.util.concurrent.RejectedExecutionException; -import org.junit.After; -import org.junit.Before; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; -import org.mockito.ArgumentCaptor; - -@RunWith(JUnit4.class) -public final class GcsOffloaderTest { - private static final String PROJECT_ID = "test-project"; - private static final String BUCKET_NAME = "test-bucket"; - private static final String PATH = "test-path/file.txt"; - private static final String CONTENT_TYPE = "text/plain"; - - private Storage mockStorage; - private ExecutorService executor; - private GcsOffloader gcsOffloader; - - @Before - public void setUp() { - mockStorage = mock(Storage.class); - executor = Executors.newSingleThreadExecutor(); - gcsOffloader = new GcsOffloader(PROJECT_ID, BUCKET_NAME, executor, null, mockStorage); - } - - @After - public void tearDown() { - executor.shutdown(); - } - - @Test - public void uploadContent_bytes_succeeds() throws Exception { - byte[] data = "hello world".getBytes(UTF_8); - CompletableFuture future = gcsOffloader.uploadContent(data, CONTENT_TYPE, PATH); - - String result = future.get(); - - assertEquals("gs://" + BUCKET_NAME + "/" + PATH, result); - - ArgumentCaptor blobInfoCaptor = ArgumentCaptor.forClass(BlobInfo.class); - verify(mockStorage).create(blobInfoCaptor.capture(), any(byte[].class)); - - BlobInfo blobInfo = blobInfoCaptor.getValue(); - assertEquals(BlobId.of(BUCKET_NAME, PATH), blobInfo.getBlobId()); - assertEquals(CONTENT_TYPE, blobInfo.getContentType()); - } - - @Test - public void uploadContent_string_succeeds() throws Exception { - String data = "hello world string"; - CompletableFuture future = gcsOffloader.uploadContent(data, CONTENT_TYPE, PATH); - - String result = future.get(); - - assertEquals("gs://" + BUCKET_NAME + "/" + PATH, result); - - ArgumentCaptor blobInfoCaptor = ArgumentCaptor.forClass(BlobInfo.class); - verify(mockStorage).create(blobInfoCaptor.capture(), any(byte[].class)); - - BlobInfo blobInfo = blobInfoCaptor.getValue(); - assertEquals(BlobId.of(BUCKET_NAME, PATH), blobInfo.getBlobId()); - assertEquals(CONTENT_TYPE, blobInfo.getContentType()); - } - - @Test - public void uploadContent_executorRejected_returnsFailedFuture() { - Executor rejectingExecutor = - r -> { - throw new RejectedExecutionException("Rejected"); - }; - GcsOffloader offloaderWithRejectingExecutor = - new GcsOffloader(PROJECT_ID, BUCKET_NAME, rejectingExecutor, null, mockStorage); - - CompletableFuture future = - offloaderWithRejectingExecutor.uploadContent("data".getBytes(UTF_8), CONTENT_TYPE, PATH); - - assertTrue(future.isCompletedExceptionally()); - assertThrows(ExecutionException.class, future::get); - } - - @Test - public void close_doesNotCloseStorageOverride() throws Exception { - gcsOffloader.close(); - verify(mockStorage, never()).close(); - } -} diff --git a/core/src/test/java/com/google/adk/plugins/agentanalytics/JsonFormatterTest.java b/core/src/test/java/com/google/adk/plugins/agentanalytics/JsonFormatterTest.java index 3ade94093..4883438b6 100644 --- a/core/src/test/java/com/google/adk/plugins/agentanalytics/JsonFormatterTest.java +++ b/core/src/test/java/com/google/adk/plugins/agentanalytics/JsonFormatterTest.java @@ -16,24 +16,16 @@ package com.google.adk.plugins.agentanalytics; -import static java.nio.charset.StandardCharsets.UTF_8; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertTrue; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyString; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; import com.fasterxml.jackson.databind.JsonNode; -import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.node.ArrayNode; -import com.fasterxml.jackson.databind.node.ObjectNode; import com.google.adk.models.LlmRequest; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import com.google.genai.types.Blob; import com.google.genai.types.Content; import com.google.genai.types.FileData; import com.google.genai.types.FunctionCall; @@ -41,7 +33,6 @@ import com.google.genai.types.Part; import java.util.Arrays; import java.util.List; -import java.util.concurrent.CompletableFuture; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -58,8 +49,7 @@ public void parse_llmRequest_populatesPrompt() throws Exception { Content.fromParts(Part.fromText("hello")).toBuilder().role("user").build())) .build(); - Parser.ParsedContent result = - new Parser(null, 100, null, true).parse(request, "trace", "span").get(); + Parser.ParsedContent result = new Parser(100).parse(request).get(); assertTrue(result.content().has("prompt")); ArrayNode prompt = (ArrayNode) result.content().get("prompt"); @@ -79,8 +69,7 @@ public void parse_llmRequest_populatesSystemPrompt() throws Exception { .build()) .build(); - Parser.ParsedContent result = - new Parser(null, 100, null, true).parse(request, "trace", "span").get(); + Parser.ParsedContent result = new Parser(100).parse(request).get(); assertTrue(result.content().has("system_prompt")); assertEquals("be helpful", result.content().get("system_prompt").asText()); @@ -90,8 +79,7 @@ public void parse_llmRequest_populatesSystemPrompt() throws Exception { @Test public void parse_string_truncates() throws Exception { String longString = "this is a very long string that should be truncated"; - Parser.ParsedContent result = - new Parser(null, 24, null, true).parse(longString, "trace", "span").get(); + Parser.ParsedContent result = new Parser(24).parse(longString).get(); assertTrue(result.isTruncated()); assertEquals("this is a ...[truncated]", result.content().asText()); @@ -101,8 +89,7 @@ public void parse_string_truncates() throws Exception { public void parse_map_truncatesNested() throws Exception { ImmutableMap map = ImmutableMap.of("key", "this is a very long value that should definitely be truncated"); - Parser.ParsedContent result = - new Parser(null, 24, null, true).parse(map, "trace", "span").get(); + Parser.ParsedContent result = new Parser(24).parse(map).get(); assertTrue(result.isTruncated()); assertEquals("this is a ...[truncated]", result.content().get("key").asText()); @@ -111,8 +98,7 @@ public void parse_map_truncatesNested() throws Exception { @Test public void parse_content_returnsSummary() throws Exception { Content content = Content.fromParts(Part.fromText("part 1"), Part.fromText("part 2")); - Parser.ParsedContent result = - new Parser(null, 100, null, true).parse(content, "trace", "span").get(); + Parser.ParsedContent result = new Parser(100).parse(content).get(); assertEquals("part 1 | part 2", result.content().get("text_summary").asText()); assertEquals(2, result.parts().size()); @@ -123,8 +109,7 @@ public void parse_content_withFileData() throws Exception { FileData fileData = FileData.builder().fileUri("gs://bucket/file.txt").mimeType("text/plain").build(); Content content = Content.fromParts(Part.builder().fileData(fileData).build()); - Parser.ParsedContent result = - new Parser(null, 100, null, true).parse(content, "trace", "span").get(); + Parser.ParsedContent result = new Parser(100).parse(content).get(); assertEquals(1, result.parts().size()); JsonNode partData = result.parts().get(0); @@ -137,8 +122,7 @@ public void parse_content_withFileData() throws Exception { public void parse_content_withFunctionCall() throws Exception { FunctionCall fc = FunctionCall.builder().name("myFunction").build(); Content content = Content.fromParts(Part.builder().functionCall(fc).build()); - Parser.ParsedContent result = - new Parser(null, 100, null, true).parse(content, "trace", "span").get(); + Parser.ParsedContent result = new Parser(100).parse(content).get(); assertEquals(1, result.parts().size()); JsonNode partData = result.parts().get(0); @@ -151,8 +135,7 @@ public void parse_content_withFunctionCall() throws Exception { public void parse_list_truncatesElements() throws Exception { List list = Arrays.asList("short", "this is a very long string that should be truncated"); - Parser.ParsedContent result = - new Parser(null, 24, null, true).parse(list, "trace", "span").get(); + Parser.ParsedContent result = new Parser(24).parse(list).get(); assertTrue(result.isTruncated()); JsonNode arrayNode = result.content(); @@ -162,44 +145,6 @@ public void parse_list_truncatesElements() throws Exception { assertEquals("this is a ...[truncated]", arrayNode.get(1).asText()); } - @Test - public void parse_withOffloader_offloadsLargeText() throws Exception { - GcsOffloader offloader = mock(GcsOffloader.class); - when(offloader.uploadContent(anyString(), anyString(), anyString())) - .thenReturn(CompletableFuture.completedFuture("gs://mock-bucket/path")); - - Content content = - Content.fromParts(Part.fromText("this text is longer than 10 characters".repeat(100))); - Parser.ParsedContent result = - new Parser(offloader, 10, "conn", true).parse(content, "trace", "span").get(); - - assertEquals(1, result.parts().size()); - JsonNode partData = result.parts().get(0); - assertEquals("GCS_REFERENCE", partData.get("storage_mode").asText()); - assertEquals("gs://mock-bucket/path", partData.get("uri").asText()); - assertTrue(partData.get("text").asText().contains("[OFFLOADED]")); - assertEquals("conn", partData.get("object_ref").get("authorizer").asText()); - } - - @Test - public void parse_withOffloader_offloadsBinaryData() throws Exception { - GcsOffloader offloader = mock(GcsOffloader.class); - when(offloader.uploadContent(any(byte[].class), anyString(), anyString())) - .thenReturn(CompletableFuture.completedFuture("gs://mock-bucket/image.png")); - - Blob blob = Blob.builder().data("fake-image".getBytes(UTF_8)).mimeType("image/png").build(); - Content content = Content.fromParts(Part.builder().inlineData(blob).build()); - Parser.ParsedContent result = - new Parser(offloader, 100, "conn", true).parse(content, "trace", "span").get(); - - assertEquals(1, result.parts().size()); - JsonNode partData = result.parts().get(0); - assertEquals("GCS_REFERENCE", partData.get("storage_mode").asText()); - assertEquals("gs://mock-bucket/image.png", partData.get("uri").asText()); - assertEquals("image/png", partData.get("mime_type").asText()); - assertEquals("[MEDIA OFFLOADED]", partData.get("text").asText()); - } - @Test public void truncate_variousInputs() { assertNull(JsonFormatter.truncate(null, 10)); @@ -243,8 +188,7 @@ public void parse_multibyteString_truncatesBasedOnBytes() throws Exception { // "こんにちはこんにちは" is 30 bytes, but 10 characters. String nihongo = "こんにちはこんにちは"; // With budget 20, effective budget is 6, so only 2 characters (6 bytes) should be kept. - Parser.ParsedContent result = - new Parser(null, 20, null, true).parse(nihongo, "trace", "span").get(); + Parser.ParsedContent result = new Parser(20).parse(nihongo).get(); assertTrue(result.isTruncated()); assertEquals("こん...[truncated]", result.content().asText()); @@ -253,23 +197,9 @@ public void parse_multibyteString_truncatesBasedOnBytes() throws Exception { @Test public void parse_multibyteContent_truncatesBasedOnBytes() throws Exception { Content content = Content.fromParts(Part.fromText("こんにちはこんにちは")); - Parser.ParsedContent result = - new Parser(null, 20, null, true).parse(content, "trace", "span").get(); + Parser.ParsedContent result = new Parser(20).parse(content).get(); assertTrue(result.isTruncated()); assertEquals("こん...[truncated]", result.content().get("text_summary").asText()); } - - @Test - public void smartTruncate_withCycle_detectsCycle() { - ObjectMapper mapper = new ObjectMapper(); - ObjectNode node = mapper.createObjectNode(); - node.set("child", node); - - // Verify that smartTruncate handles circular JsonNode structures by detecting the cycle. - JsonFormatter.TruncationResult result = JsonFormatter.smartTruncate(node, 100); - - assertTrue(result.isTruncated()); - assertEquals("[cycle detected]", result.node().get("child").asText()); - } } diff --git a/core/src/test/java/com/google/adk/plugins/agentanalytics/MimeTypeMapperTest.java b/core/src/test/java/com/google/adk/plugins/agentanalytics/MimeTypeMapperTest.java deleted file mode 100644 index 4930b28be..000000000 --- a/core/src/test/java/com/google/adk/plugins/agentanalytics/MimeTypeMapperTest.java +++ /dev/null @@ -1,52 +0,0 @@ -/* - * Copyright 2026 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.google.adk.plugins.agentanalytics; - -import static org.junit.Assert.assertEquals; - -import org.junit.Test; -import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; - -@RunWith(JUnit4.class) -public final class MimeTypeMapperTest { - - @Test - public void getExtension_commonImages_returnsExtension() { - assertEquals(".jpg", MimeTypeMapper.getExtension("image/jpeg")); - assertEquals(".png", MimeTypeMapper.getExtension("image/png")); - assertEquals(".gif", MimeTypeMapper.getExtension("image/gif")); - } - - @Test - public void getExtension_commonAudio_returnsExtension() { - assertEquals(".mp3", MimeTypeMapper.getExtension("audio/mpeg")); - assertEquals(".wav", MimeTypeMapper.getExtension("audio/wav")); - } - - @Test - public void getExtension_commonVideo_returnsExtension() { - assertEquals(".mp4", MimeTypeMapper.getExtension("video/mp4")); - assertEquals(".mov", MimeTypeMapper.getExtension("video/quicktime")); - } - - @Test - public void getExtension_unknownType_returnsEmptyString() { - assertEquals("", MimeTypeMapper.getExtension("application/octet-stream")); - assertEquals("", MimeTypeMapper.getExtension("text/plain")); - } -} diff --git a/core/src/test/java/com/google/adk/plugins/agentanalytics/ParserTest.java b/core/src/test/java/com/google/adk/plugins/agentanalytics/ParserTest.java index 385e81082..9bae03331 100644 --- a/core/src/test/java/com/google/adk/plugins/agentanalytics/ParserTest.java +++ b/core/src/test/java/com/google/adk/plugins/agentanalytics/ParserTest.java @@ -38,13 +38,13 @@ public final class ParserTest { @Before public void setUp() { - parser = new Parser(null, 100, "connectionId", true); + parser = new Parser(100); } @Test public void parse_part_coversLine280() throws Exception { Part part = Part.fromText("test part"); - CompletableFuture future = parser.parse(part, "traceId", "spanId"); + CompletableFuture future = parser.parse(part); Parser.ParsedContent result = future.get(); assertEquals("{\"text_summary\":\"test part\"}", result.content().toString()); @@ -56,7 +56,7 @@ public void parse_part_coversLine280() throws Exception { public void parse_part_withInlineData_coversProcessPart() throws Exception { Blob blob = Blob.builder().mimeType("image/png").data(new byte[] {1, 2, 3}).build(); Part part = Part.builder().inlineData(blob).build(); - CompletableFuture future = parser.parse(part, "traceId", "spanId"); + CompletableFuture future = parser.parse(part); Parser.ParsedContent result = future.get(); assertEquals(1, result.parts().size()); @@ -104,7 +104,7 @@ public void parse_multipartContent_coversLine310() throws Exception { // Call private method using helper if necessary, but parseContentObject is private. // However, parse(Object content, ...) calls it. - CompletableFuture future = parser.parse(content, "traceId", "spanId"); + CompletableFuture future = parser.parse(content); Parser.ParsedContent result = future.get(); assertTrue(result.isTruncated()); diff --git a/core/src/test/java/com/google/adk/plugins/agentanalytics/PluginStateTest.java b/core/src/test/java/com/google/adk/plugins/agentanalytics/PluginStateTest.java index 14dcc390e..444cc8a6d 100644 --- a/core/src/test/java/com/google/adk/plugins/agentanalytics/PluginStateTest.java +++ b/core/src/test/java/com/google/adk/plugins/agentanalytics/PluginStateTest.java @@ -17,7 +17,6 @@ package com.google.adk.plugins.agentanalytics; import static java.util.concurrent.TimeUnit.SECONDS; -import static org.junit.Assert.assertNull; import static org.junit.Assert.assertTrue; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.atLeastOnce; @@ -29,14 +28,10 @@ import com.google.cloud.bigquery.storage.v1.AppendRowsResponse; import com.google.cloud.bigquery.storage.v1.BigQueryWriteClient; import com.google.cloud.bigquery.storage.v1.StreamWriter; -import com.google.common.util.concurrent.Uninterruptibles; import java.io.IOException; -import java.lang.reflect.Field; import java.time.Duration; import java.time.Instant; import java.util.concurrent.CompletableFuture; -import java.util.concurrent.ExecutorService; -import java.util.concurrent.TimeUnit; import java.util.logging.Handler; import java.util.logging.Level; import java.util.logging.LogRecord; @@ -70,6 +65,10 @@ protected BigQueryWriteClient createWriteClient(BigQueryLoggerConfig config) { return mockWriteClient; } + BigQueryWriteClient getMockWriteClient() { + return mockWriteClient; + } + @Override protected StreamWriter createWriter() { StreamWriter writer = mock(StreamWriter.class); @@ -103,11 +102,6 @@ public void tearDown() { pluginLogger.setLevel(originalLevel); } - @Test - public void getGcsOffloader_emptyBucketName_returnsNull() { - assertNull(pluginState.getGcsOffloader(config)); - } - @Test public void addPendingTask_removedTaskOnCompletion() { String invocationId = "testInvocation"; @@ -213,8 +207,7 @@ public void ensureInvocationCompleted_timeout_cleansUpState() throws IOException // Wait for cleanup side effects which run after terminal signal. long deadline = Instant.now().plusMillis(1000).toEpochMilli(); - while (!pluginState.getPendingTasksForInvocation(invocationId).isEmpty() - && Instant.now().toEpochMilli() < deadline) { + while (!pluginState.isProcessed(invocationId) && Instant.now().toEpochMilli() < deadline) { try { Thread.sleep(10); } catch (InterruptedException e) { @@ -249,51 +242,4 @@ public void close_succeedsAndCleansUp() throws Exception { assertTrue(pluginState.getTraceManagers().isEmpty()); assertTrue(pluginState.getExecutor().isShutdown()); } - - @Test - public void close_respectsRemainingTimeoutBudget() throws Exception { - config = config.toBuilder().shutdownTimeout(Duration.ofMillis(500)).build(); - pluginState = new TestPluginState(config); - - ExecutorService mockOffloadExecutor = mock(ExecutorService.class); - Field field = PluginState.class.getDeclaredField("offloadExecutor"); - field.setAccessible(true); - field.set(pluginState, mockOffloadExecutor); - - pluginState - .getExecutor() - .execute( - () -> { - Uninterruptibles.sleepUninterruptibly(Duration.ofMillis(200)); - }); - - when(mockOffloadExecutor.awaitTermination(any(Long.class), any(TimeUnit.class))) - .thenReturn(true); - - pluginState.close().test().awaitDone(2, SECONDS); - - ArgumentCaptor timeoutCaptor = ArgumentCaptor.forClass(Long.class); - verify(mockOffloadExecutor).awaitTermination(timeoutCaptor.capture(), any(TimeUnit.class)); - - long capturedTimeout = timeoutCaptor.getValue(); - assertTrue("Timeout should be less than 400", capturedTimeout < 400); - assertTrue("Timeout should be greater than 100", capturedTimeout > 100); - } - - @Test - public void close_closesGcsOffloader() throws Exception { - GcsOffloader mockOffloader = mock(GcsOffloader.class); - BigQueryLoggerConfig gcsConfig = config.toBuilder().gcsBucketName("test-bucket").build(); - PluginState gcsState = - new TestPluginState(gcsConfig) { - @Override - protected GcsOffloader getGcsOffloader(BigQueryLoggerConfig config) { - return mockOffloader; - } - }; - - gcsState.close().test().assertComplete(); - - verify(mockOffloader).close(); - } } diff --git a/core/src/test/java/com/google/adk/runner/RunnerTest.java b/core/src/test/java/com/google/adk/runner/RunnerTest.java index 95718e3e0..ff75c97b0 100644 --- a/core/src/test/java/com/google/adk/runner/RunnerTest.java +++ b/core/src/test/java/com/google/adk/runner/RunnerTest.java @@ -42,7 +42,6 @@ import com.google.adk.agents.LiveRequestQueue; import com.google.adk.agents.LlmAgent; import com.google.adk.agents.RunConfig; -import com.google.adk.agents.SequentialAgent; import com.google.adk.apps.App; import com.google.adk.artifacts.BaseArtifactService; import com.google.adk.events.Event; @@ -1123,42 +1122,6 @@ public void beforeRunCallback_withStateDelta_seesMergedState() { assertThat(sessionInCallback.state()).containsEntry("number", 123); } - @Test - public void onUserMessageCallback_withStateDelta_seesMergedState() { - // Snapshot the session state *inside* the callback, otherwise the assertion would - // observe the post-runAsync state which is mutated by appendEvent regardless of whether - // the pre-merge in Runner is applied. - AtomicReference> stateInCallback = new AtomicReference<>(); - when(plugin.onUserMessageCallback(any(), any())) - .thenAnswer( - invocation -> { - InvocationContext ctx = invocation.getArgument(0); - stateInCallback.set(new ConcurrentHashMap<>(ctx.session().state())); - return Maybe.empty(); - }); - - ImmutableMap stateDelta = - ImmutableMap.of("callback_key", "callback_value", "number", 123); - - var unused = - runner - .runAsync( - "user", - session.id(), - createContent("test with state"), - RunConfig.builder().build(), - stateDelta) - .toList() - .blockingGet(); - - // Verify onUserMessageCallback was called - verify(plugin).onUserMessageCallback(any(), any()); - - // Verify state delta was merged before onUserMessageCallback was invoked - assertThat(stateInCallback.get()).containsEntry("callback_key", "callback_value"); - assertThat(stateInCallback.get()).containsEntry("number", 123); - } - @Test public void runAsync_ensureEventsAreAppendedInOrder() throws Exception { Event event1 = TestUtils.createEvent("1"); @@ -1367,7 +1330,7 @@ public void runAsync_createsToolSpansWithCorrectParent() { List spans = openTelemetryRule.getSpans(); List llmSpans = spans.stream().filter(s -> s.getName().equals("call_llm")).toList(); List toolSpans = - spans.stream().filter(s -> s.getName().equals("execute_tool echo_tool")).toList(); + spans.stream().filter(s -> s.getName().equals("execute_tool [echo_tool]")).toList(); assertThat(llmSpans).hasSize(2); assertThat(toolSpans).hasSize(1); @@ -1402,7 +1365,7 @@ public void runLive_createsToolSpansWithCorrectParent() throws Exception { List spans = openTelemetryRule.getSpans(); List llmSpans = spans.stream().filter(s -> s.getName().equals("call_llm")).toList(); List toolSpans = - spans.stream().filter(s -> s.getName().equals("execute_tool echo_tool")).toList(); + spans.stream().filter(s -> s.getName().equals("execute_tool [echo_tool]")).toList(); // In runLive, there is one call_llm span for the execution assertThat(llmSpans).hasSize(1); @@ -1605,107 +1568,6 @@ public void runAsync_withToolConfirmation() { .inOrder(); } - // HITL tool confirmation must resume the originating sub-agent even when wrapped inside a - // non-LlmAgent workflow agent (e.g. SequentialAgent). - @Test - public void runAsync_withToolConfirmation_inSequentialAgentSubAgent_resumesSubAgent() { - TestLlm childTestLlm = - createTestLlm( - createFunctionCallLlmResponse( - "tool_call_id", "echoTool", ImmutableMap.of("message", "hello")), - createTextLlmResponse("Response after observing tool needs confirmation."), - createTextLlmResponse("Response after user confirmed.")); - LlmAgent childAgent = - createTestAgentBuilder(childTestLlm) - .name("child_agent") - .tools(FunctionTool.create(Tools.class, "echoTool", /* requireConfirmation= */ true)) - .build(); - SequentialAgent workflowAgent = - SequentialAgent.builder() - .name("workflow_agent") - .subAgents(ImmutableList.of(childAgent)) - .build(); - // Root transfers to workflow_agent to mirror the bug report's control flow. - TestLlm rootTestLlm = - createTestLlm( - createLlmResponse( - Content.fromParts( - Part.fromFunctionCall( - "transfer_to_agent", ImmutableMap.of("agent_name", "workflow_agent"))))); - LlmAgent rootAgent = - createTestAgentBuilder(rootTestLlm) - .name("root_agent") - .subAgents(ImmutableList.of(workflowAgent)) - .build(); - Runner runner = - Runner.builder().app(App.builder().name("test").rootAgent(rootAgent).build()).build(); - Session session = runner.sessionService().createSession("test", "user").blockingGet(); - - List eventsBeforeConfirmation = - runner - .runAsync("user", session.id(), Content.fromParts(Part.fromText("from user"))) - .toList() - .blockingGet(); - FunctionCall askUserConfirmationFunctionCall = - Iterables.getOnlyElement( - eventsBeforeConfirmation.stream() - .map(Functions::getAskUserConfirmationFunctionCalls) - .filter(functionCalls -> !functionCalls.isEmpty()) - .findFirst() - .get()); - List eventsAfterConfirmation = - runner - .runAsync( - "user", - session.id(), - Content.fromParts( - Part.builder() - .functionResponse( - FunctionResponse.builder() - .id(askUserConfirmationFunctionCall.id().get()) - .name(askUserConfirmationFunctionCall.name().get()) - .response(ImmutableMap.of("confirmed", true))) - .build())) - .toList() - .blockingGet(); - - // The originating child agent (not the root agent) must execute the tool. - assertThat(simplifyEvents(eventsAfterConfirmation)) - .containsExactly( - "child_agent: FunctionResponse(name=echoTool, response={message=hello})", - "child_agent: Response after user confirmed.") - .inOrder(); - } - - // Orphan function responses (id not matching any prior call) should fall back to the root agent. - @Test - public void runAsync_withFunctionResponseNotMatchingAnyCall_fallsBackToRootAgent() { - TestLlm rootLlm = createTestLlm(createTextLlmResponse("after function response")); - LlmAgent rootAgent = createTestAgentBuilder(rootLlm).name("root_agent").build(); - Runner runner = - Runner.builder().app(App.builder().name("test").rootAgent(rootAgent).build()).build(); - Session session = runner.sessionService().createSession("test", "user").blockingGet(); - - // Function response with id that does not match any prior function call. - List events = - runner - .runAsync( - "user", - session.id(), - Content.fromParts( - Part.builder() - .functionResponse( - FunctionResponse.builder() - .id("non_existent_id") - .name("orphanFn") - .response(ImmutableMap.of("x", 1))) - .build())) - .toList() - .blockingGet(); - - assertThat(simplifyEvents(events)).containsExactly("root_agent: after function response"); - } - @Test public void close_closesPluginsAndCodeExecutors() { BasePlugin plugin = mockPlugin("close_test_plugin"); diff --git a/core/src/test/java/com/google/adk/skills/LocalSkillSourceTest.java b/core/src/test/java/com/google/adk/skills/LocalSkillSourceTest.java index 5efdb4ab9..25e7be525 100644 --- a/core/src/test/java/com/google/adk/skills/LocalSkillSourceTest.java +++ b/core/src/test/java/com/google/adk/skills/LocalSkillSourceTest.java @@ -224,8 +224,6 @@ public void testLoadResource_notFound() throws IOException { var single = source.loadResource("my-skill", "non-existent.txt"); RuntimeException exception = assertThrows(RuntimeException.class, single::blockingGet); assertThat(exception).hasCauseThat().isInstanceOf(SkillSourceException.class); - SkillSourceException cause = (SkillSourceException) exception.getCause(); - assertThat(cause.getErrorCode()).isEqualTo(SkillSourceException.RESOURCE_NOT_FOUND); } @Test @@ -236,8 +234,6 @@ public void testLoadFrontmatter_skillNotFound() { var single = source.loadFrontmatter("non-existent"); RuntimeException exception = assertThrows(RuntimeException.class, single::blockingGet); assertThat(exception).hasCauseThat().isInstanceOf(SkillSourceException.class); - SkillSourceException cause = (SkillSourceException) exception.getCause(); - assertThat(cause.getErrorCode()).isEqualTo(SkillSourceException.SKILL_NOT_FOUND); } @Test @@ -253,125 +249,74 @@ public void testListSkillMdPaths_skillSourceException() throws IOException { var single = source.listFrontmatters(); RuntimeException exception = assertThrows(RuntimeException.class, single::blockingGet); assertThat(exception).hasCauseThat().isInstanceOf(SkillSourceException.class); - SkillSourceException cause = (SkillSourceException) exception.getCause(); - assertThat(cause.getErrorCode()).isEqualTo(SkillSourceException.SKILL_LOAD_ERROR); } @Test - public void testLoadFrontmatter_missingStartDashes() throws IOException { + public void testLoadResource_pathTraversalInSkillName() throws IOException { Path skillsBase = tempFolder.getRoot().toPath().resolve("skills"); Files.createDirectory(skillsBase); - Path skillDir = skillsBase.resolve("my-skill"); - Files.createDirectory(skillDir); - Files.writeString( - skillDir.resolve("SKILL.md"), - """ - name: my-skill - description: This is a test skill - --- - body - """); - SkillSource source = new LocalSkillSource(skillsBase); - var single = source.loadFrontmatter("my-skill"); + var single = source.loadResource("../other-dir", "file.txt"); RuntimeException exception = assertThrows(RuntimeException.class, single::blockingGet); assertThat(exception).hasCauseThat().isInstanceOf(SkillSourceException.class); - assertThat(exception) - .hasCauseThat() - .hasMessageThat() - .contains("Skill file must start with ---"); + assertThat(exception).hasCauseThat().hasMessageThat().contains("Path traversal detected"); } @Test - public void testLoadInstructions_missingStartDashes() throws IOException { + public void testLoadResource_pathTraversalInResourcePath() throws IOException { Path skillsBase = tempFolder.getRoot().toPath().resolve("skills"); Files.createDirectory(skillsBase); - Path skillDir = skillsBase.resolve("my-skill"); Files.createDirectory(skillDir); - Files.writeString( - skillDir.resolve("SKILL.md"), - """ - name: my-skill - description: Test - --- - Some Markdown Body - """); SkillSource source = new LocalSkillSource(skillsBase); - var single = source.loadInstructions("my-skill"); + var single = source.loadResource("my-skill", "../../../etc/passwd"); RuntimeException exception = assertThrows(RuntimeException.class, single::blockingGet); assertThat(exception).hasCauseThat().isInstanceOf(SkillSourceException.class); - assertThat(exception) - .hasCauseThat() - .hasMessageThat() - .contains("Skill file must start with ---"); + assertThat(exception).hasCauseThat().hasMessageThat().contains("Path traversal detected"); } @Test - public void testLoadFrontmatter_nameMismatch() throws IOException { + public void testLoadResource_absoluteResourcePath() throws IOException { Path skillsBase = tempFolder.getRoot().toPath().resolve("skills"); Files.createDirectory(skillsBase); - Path skillDir = skillsBase.resolve("my-skill"); Files.createDirectory(skillDir); - Files.writeString( - skillDir.resolve("SKILL.md"), - """ - --- - name: other-skill - description: This is a test skill - --- - body - """); SkillSource source = new LocalSkillSource(skillsBase); - var single = source.loadFrontmatter("my-skill"); + var single = source.loadResource("my-skill", "/etc/passwd"); RuntimeException exception = assertThrows(RuntimeException.class, single::blockingGet); assertThat(exception).hasCauseThat().isInstanceOf(SkillSourceException.class); assertThat(exception) .hasCauseThat() .hasMessageThat() - .contains( - "Skill name in the frontmatter 'other-skill' does not match skill name 'my-skill'."); + .contains("Absolute paths are not allowed"); } @Test - public void testLoadFrontmatter_emptyFile() throws IOException { + public void testListResources_pathTraversalInSkillName() throws IOException { Path skillsBase = tempFolder.getRoot().toPath().resolve("skills"); Files.createDirectory(skillsBase); - Path skillDir = skillsBase.resolve("my-skill"); - Files.createDirectory(skillDir); - Files.writeString(skillDir.resolve("SKILL.md"), ""); - SkillSource source = new LocalSkillSource(skillsBase); - var single = source.loadFrontmatter("my-skill"); + var single = source.listResources("../../etc", "passwd"); RuntimeException exception = assertThrows(RuntimeException.class, single::blockingGet); assertThat(exception).hasCauseThat().isInstanceOf(SkillSourceException.class); - assertThat(exception) - .hasCauseThat() - .hasMessageThat() - .contains("Skill file must start with ---"); + assertThat(exception).hasCauseThat().hasMessageThat().contains("Path traversal detected"); } @Test - public void testLoadInstructions_emptyFile() throws IOException { + public void testListResources_pathTraversalInResourceDirectory() throws IOException { Path skillsBase = tempFolder.getRoot().toPath().resolve("skills"); Files.createDirectory(skillsBase); - Path skillDir = skillsBase.resolve("my-skill"); Files.createDirectory(skillDir); - Files.writeString(skillDir.resolve("SKILL.md"), ""); SkillSource source = new LocalSkillSource(skillsBase); - var single = source.loadInstructions("my-skill"); + var single = source.listResources("my-skill", "../other-skill"); RuntimeException exception = assertThrows(RuntimeException.class, single::blockingGet); assertThat(exception).hasCauseThat().isInstanceOf(SkillSourceException.class); - assertThat(exception) - .hasCauseThat() - .hasMessageThat() - .contains("Skill file must start with ---"); + assertThat(exception).hasCauseThat().hasMessageThat().contains("Path traversal detected"); } } diff --git a/core/src/test/java/com/google/adk/telemetry/ContextPropagationTest.java b/core/src/test/java/com/google/adk/telemetry/ContextPropagationTest.java index 331ae77b2..44877e972 100644 --- a/core/src/test/java/com/google/adk/telemetry/ContextPropagationTest.java +++ b/core/src/test/java/com/google/adk/telemetry/ContextPropagationTest.java @@ -471,7 +471,7 @@ public void testAgentWithToolCallTraceHierarchy() throws InterruptedException { // invocation // └── invoke_agent test_agent // ├── call_llm - // │ └── execute_tool search_flights + // │ └── execute_tool [search_flights] // └── call_llm SearchFlightsTool searchFlightsTool = new SearchFlightsTool(); @@ -499,7 +499,7 @@ public void testAgentWithToolCallTraceHierarchy() throws InterruptedException { SpanData invocation = findSpanByName("invocation"); SpanData invokeAgent = findSpanByName("invoke_agent test_agent"); - SpanData toolResponse = findSpanByName("execute_tool search_flights"); + SpanData toolResponse = findSpanByName("execute_tool [search_flights]"); List callLlmSpans = openTelemetryRule.getSpans().stream() .filter(s -> s.getName().equals("call_llm")) @@ -515,7 +515,7 @@ public void testAgentWithToolCallTraceHierarchy() throws InterruptedException { assertParent(invocation, invokeAgent); // ├── call_llm 1 assertParent(invokeAgent, callLlm1); - // │ └── execute_tool search_flights + // │ └── execute_tool [search_flights] assertParent(callLlm1, toolResponse); // └── call_llm 2 assertParent(invokeAgent, callLlm2); @@ -546,7 +546,7 @@ public void testNestedAgentTraceHierarchy() throws InterruptedException { // invocation // └── invoke_agent AgentA // ├── call_llm - // │ └── execute_tool transfer_to_agent + // │ └── execute_tool [transfer_to_agent] // └── invoke_agent AgentB // └── call_llm TestLlm llm = @@ -573,7 +573,7 @@ public void testNestedAgentTraceHierarchy() throws InterruptedException { SpanData invocation = findSpanByName("invocation"); SpanData agentASpan = findSpanByName("invoke_agent AgentA"); - SpanData executeTool = findSpanByName("execute_tool transfer_to_agent"); + SpanData executeTool = findSpanByName("execute_tool [transfer_to_agent]"); SpanData agentBSpan = findSpanByName("invoke_agent AgentB"); List callLlmSpans = @@ -586,17 +586,10 @@ public void testNestedAgentTraceHierarchy() throws InterruptedException { SpanData agentACallLlm1 = callLlmSpans.get(0); SpanData agentBCallLlm = callLlmSpans.get(1); - // Assert hierarchy: - // invocation - // └── invoke_agent AgentA assertParent(invocation, agentASpan); - // └── call_llm 1 assertParent(agentASpan, agentACallLlm1); - // ├── execute_tool transfer_to_agent assertParent(agentACallLlm1, executeTool); - // └── invoke_agent AgentB - assertParent(agentACallLlm1, agentBSpan); - // └── call_llm 2 + assertParent(agentASpan, agentBSpan); assertParent(agentBSpan, agentBCallLlm); } diff --git a/core/src/test/java/com/google/adk/testing/TestCallback.java b/core/src/test/java/com/google/adk/testing/TestCallback.java index 403e3874a..6f35f5a3c 100644 --- a/core/src/test/java/com/google/adk/testing/TestCallback.java +++ b/core/src/test/java/com/google/adk/testing/TestCallback.java @@ -91,7 +91,7 @@ public Supplier> asRunAsyncImplSupplier(Content content) { Flowable.defer( () -> { markAsCalled(); - return Flowable.just(Event.builder().author("testAgent").content(content).build()); + return Flowable.just(Event.builder().content(content).build()); }); } @@ -111,7 +111,7 @@ public Supplier> asRunLiveImplSupplier(Content content) { Flowable.defer( () -> { markAsCalled(); - return Flowable.just(Event.builder().author("testAgent").content(content).build()); + return Flowable.just(Event.builder().content(content).build()); }); } diff --git a/core/src/test/java/com/google/adk/testing/TestLlm.java b/core/src/test/java/com/google/adk/testing/TestLlm.java index fc9ce3850..aaacf00a0 100644 --- a/core/src/test/java/com/google/adk/testing/TestLlm.java +++ b/core/src/test/java/com/google/adk/testing/TestLlm.java @@ -42,7 +42,7 @@ import java.util.concurrent.atomic.AtomicInteger; import java.util.function.Predicate; import java.util.function.Supplier; -import org.jspecify.annotations.Nullable; +import javax.annotation.Nullable; /** * A test implementation of {@link BaseLlm}. diff --git a/core/src/test/java/com/google/adk/tools/mcp/DefaultMcpTransportBuilderTest.java b/core/src/test/java/com/google/adk/tools/mcp/DefaultMcpTransportBuilderTest.java deleted file mode 100644 index d2c263ac4..000000000 --- a/core/src/test/java/com/google/adk/tools/mcp/DefaultMcpTransportBuilderTest.java +++ /dev/null @@ -1,274 +0,0 @@ -/* - * Copyright 2026 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.google.adk.tools.mcp; - -import static com.google.common.truth.Truth.assertThat; -import static org.junit.Assert.assertThrows; - -import com.google.common.collect.ImmutableMap; -import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; -import io.modelcontextprotocol.client.transport.HttpClientStreamableHttpTransport; -import io.modelcontextprotocol.client.transport.ServerParameters; -import io.modelcontextprotocol.client.transport.StdioClientTransport; -import io.modelcontextprotocol.client.transport.customizer.McpAsyncHttpClientRequestCustomizer; -import io.modelcontextprotocol.common.McpTransportContext; -import io.modelcontextprotocol.spec.McpClientTransport; -import java.lang.reflect.Field; -import java.net.URI; -import java.net.http.HttpRequest; -import java.util.HashMap; -import java.util.Map; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; -import reactor.core.publisher.Mono; - -/** Unit tests for {@link DefaultMcpTransportBuilder}. */ -@RunWith(JUnit4.class) -public final class DefaultMcpTransportBuilderTest { - - private final DefaultMcpTransportBuilder transportBuilder = new DefaultMcpTransportBuilder(); - - @Test - public void build_withServerParameters_returnsStdioTransport() { - ServerParameters params = ServerParameters.builder("test-command").build(); - - McpClientTransport transport = transportBuilder.build(params); - - assertThat(transport).isInstanceOf(StdioClientTransport.class); - } - - @Test - public void build_withSseServerParameters_returnsSseTransport() { - SseServerParameters params = SseServerParameters.builder().url("http://localhost:1234").build(); - - McpClientTransport transport = transportBuilder.build(params); - - assertThat(transport).isInstanceOf(HttpClientSseClientTransport.class); - } - - @Test - public void build_withStreamableHttpServerParameters_returnsStreamableHttpTransport() { - StreamableHttpServerParameters params = - StreamableHttpServerParameters.builder().url("http://localhost:1234").build(); - - McpClientTransport transport = transportBuilder.build(params); - - assertThat(transport).isInstanceOf(HttpClientStreamableHttpTransport.class); - } - - @Test - public void build_withUnknownConnectionParams_throwsIllegalArgumentException() { - Object unknownParams = new Object(); - - IllegalArgumentException ex = - assertThrows(IllegalArgumentException.class, () -> transportBuilder.build(unknownParams)); - - assertThat(ex).hasMessageThat().contains("DefaultMcpTransportBuilder supports only"); - } - - @Test - public void build_withStreamableHttpUrlWithoutPath_usesDefaultEndpoint() throws Exception { - StreamableHttpServerParameters params = - StreamableHttpServerParameters.builder().url("http://localhost:8080").build(); - - HttpClientStreamableHttpTransport transport = - (HttpClientStreamableHttpTransport) transportBuilder.build(params); - - assertThat(getBaseUri(transport)).isEqualTo(URI.create("http://localhost:8080")); - assertThat(getEndpoint(transport)).isEqualTo("/mcp"); - } - - @Test - public void build_withStreamableHttpUrlWithRootPath_usesDefaultEndpoint() throws Exception { - StreamableHttpServerParameters params = - StreamableHttpServerParameters.builder().url("http://localhost:8080/").build(); - - HttpClientStreamableHttpTransport transport = - (HttpClientStreamableHttpTransport) transportBuilder.build(params); - - assertThat(getEndpoint(transport)).isEqualTo("/mcp"); - } - - @Test - public void build_withStreamableHttpCustomEndpointPath_preservesCustomPath() throws Exception { - // Regression test for google/adk-java#1196. - StreamableHttpServerParameters params = - StreamableHttpServerParameters.builder().url("http://localhost:8080/mcp/stream").build(); - - HttpClientStreamableHttpTransport transport = - (HttpClientStreamableHttpTransport) transportBuilder.build(params); - - assertThat(getBaseUri(transport)).isEqualTo(URI.create("http://localhost:8080")); - assertThat(getEndpoint(transport)).isEqualTo("/mcp/stream"); - } - - @Test - public void build_withStreamableHttpCustomEndpoint_resolvesToFullUrl() throws Exception { - StreamableHttpServerParameters params = - StreamableHttpServerParameters.builder().url("http://localhost:8080/mcp/stream").build(); - - HttpClientStreamableHttpTransport transport = - (HttpClientStreamableHttpTransport) transportBuilder.build(params); - - URI resolved = getBaseUri(transport).resolve(getEndpoint(transport)); - assertThat(resolved).isEqualTo(URI.create("http://localhost:8080/mcp/stream")); - } - - @Test - public void build_withStreamableHttpDeepCustomPath_preservesEntirePath() throws Exception { - StreamableHttpServerParameters params = - StreamableHttpServerParameters.builder() - .url("https://example.com/api/v1/mcp/stream") - .build(); - - HttpClientStreamableHttpTransport transport = - (HttpClientStreamableHttpTransport) transportBuilder.build(params); - - assertThat(getBaseUri(transport)).isEqualTo(URI.create("https://example.com")); - assertThat(getEndpoint(transport)).isEqualTo("/api/v1/mcp/stream"); - } - - @Test - public void build_withStreamableHttpQueryAndFragment_preservesQueryAndFragment() - throws Exception { - StreamableHttpServerParameters params = - StreamableHttpServerParameters.builder() - .url("https://example.com/mcp/stream?token=abc#frag") - .build(); - - HttpClientStreamableHttpTransport transport = - (HttpClientStreamableHttpTransport) transportBuilder.build(params); - - assertThat(getBaseUri(transport)).isEqualTo(URI.create("https://example.com")); - assertThat(getEndpoint(transport)).isEqualTo("/mcp/stream?token=abc#frag"); - } - - @Test - public void build_withStreamableHttpEncodedPath_preservesEncoding() throws Exception { - StreamableHttpServerParameters params = - StreamableHttpServerParameters.builder() - .url("https://example.com/mcp%20stream/path") - .build(); - - HttpClientStreamableHttpTransport transport = - (HttpClientStreamableHttpTransport) transportBuilder.build(params); - - assertThat(getBaseUri(transport)).isEqualTo(URI.create("https://example.com")); - assertThat(getEndpoint(transport)).isEqualTo("/mcp%20stream/path"); - } - - @Test - public void build_withStreamableHttpHeaders_customizerForwardsHeadersToRequest() - throws Exception { - StreamableHttpServerParameters params = - StreamableHttpServerParameters.builder() - .url("http://localhost:8080/mcp/stream") - .headers(ImmutableMap.of("X-Custom", "value", "Authorization", "Bearer token")) - .build(); - - HttpClientStreamableHttpTransport transport = - (HttpClientStreamableHttpTransport) transportBuilder.build(params); - McpAsyncHttpClientRequestCustomizer customizer = getCustomizer(transport); - HttpRequest.Builder requestBuilder = HttpRequest.newBuilder().uri(URI.create("http://x/")); - - HttpRequest.Builder returned = - Mono.from( - customizer.customize( - requestBuilder, - "POST", - URI.create("http://x/"), - null, - McpTransportContext.EMPTY)) - .block(); - - assertThat(returned).isSameInstanceAs(requestBuilder); - Map headers = collectHeaders(requestBuilder); - assertThat(headers).containsEntry("X-Custom", "value"); - assertThat(headers).containsEntry("Authorization", "Bearer token"); - } - - @Test - public void build_withStreamableHttpEmptyHeaders_customizerIsNoOp() throws Exception { - StreamableHttpServerParameters params = - StreamableHttpServerParameters.builder() - .url("http://localhost:8080/mcp/stream") - .headers(ImmutableMap.of()) - .build(); - - HttpClientStreamableHttpTransport transport = - (HttpClientStreamableHttpTransport) transportBuilder.build(params); - McpAsyncHttpClientRequestCustomizer customizer = getCustomizer(transport); - HttpRequest.Builder requestBuilder = HttpRequest.newBuilder().uri(URI.create("http://x/")); - - Mono.from( - customizer.customize( - requestBuilder, "POST", URI.create("http://x/"), null, McpTransportContext.EMPTY)) - .block(); - - assertThat(collectHeaders(requestBuilder)).isEmpty(); - } - - @Test - public void build_withStreamableHttpMalformedUrl_doesNotMaskUnderlyingError() { - // Unparseable URL: split helper forwards it as-is so the transport surfaces its own error. - StreamableHttpServerParameters params = - StreamableHttpServerParameters.builder().url("http://example.com/path with space").build(); - - assertThrows(IllegalArgumentException.class, () -> transportBuilder.build(params)); - } - - @Test - public void build_withStreamableHttpSchemelessUrl_forwardsUnchangedAsBaseUri() throws Exception { - // No scheme/authority: split helper forwards the URL as-is and keeps the default endpoint. - StreamableHttpServerParameters params = - StreamableHttpServerParameters.builder().url("relative/path").build(); - - HttpClientStreamableHttpTransport transport = - (HttpClientStreamableHttpTransport) transportBuilder.build(params); - - assertThat(getBaseUri(transport)).isEqualTo(URI.create("relative/path")); - assertThat(getEndpoint(transport)).isEqualTo("/mcp"); - } - - private static URI getBaseUri(HttpClientStreamableHttpTransport transport) throws Exception { - Field field = HttpClientStreamableHttpTransport.class.getDeclaredField("baseUri"); - field.setAccessible(true); - return (URI) field.get(transport); - } - - private static String getEndpoint(HttpClientStreamableHttpTransport transport) throws Exception { - Field field = HttpClientStreamableHttpTransport.class.getDeclaredField("endpoint"); - field.setAccessible(true); - return (String) field.get(transport); - } - - private static McpAsyncHttpClientRequestCustomizer getCustomizer( - HttpClientStreamableHttpTransport transport) throws Exception { - Field field = HttpClientStreamableHttpTransport.class.getDeclaredField("httpRequestCustomizer"); - field.setAccessible(true); - return (McpAsyncHttpClientRequestCustomizer) field.get(transport); - } - - /** Reads back the headers set on a builder by building a throwaway request. */ - private static Map collectHeaders(HttpRequest.Builder builder) { - HttpRequest request = builder.GET().build(); - Map result = new HashMap<>(); - request.headers().map().forEach((key, values) -> result.put(key, String.join(",", values))); - return result; - } -} diff --git a/core/src/test/java/com/google/adk/tools/skills/ListSkillsToolTest.java b/core/src/test/java/com/google/adk/tools/skills/ListSkillsToolTest.java deleted file mode 100644 index fe5b202a2..000000000 --- a/core/src/test/java/com/google/adk/tools/skills/ListSkillsToolTest.java +++ /dev/null @@ -1,151 +0,0 @@ -/* - * Copyright 2026 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.google.adk.tools.skills; - -import static com.google.adk.skills.SkillSourceException.SKILL_LOAD_ERROR; -import static com.google.common.truth.Truth.assertThat; -import static org.junit.Assert.assertThrows; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; - -import com.google.adk.agents.InvocationContext; -import com.google.adk.sessions.Session; -import com.google.adk.skills.Frontmatter; -import com.google.adk.skills.InMemorySkillSource; -import com.google.adk.skills.SkillSource; -import com.google.adk.skills.SkillSourceException; -import com.google.adk.testing.TestBaseAgent; -import com.google.adk.tools.ToolContext; -import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableMap; -import io.reactivex.rxjava3.core.Flowable; -import io.reactivex.rxjava3.core.Single; -import java.util.Map; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; - -@RunWith(JUnit4.class) -public class ListSkillsToolTest { - - @Test - public void call_listSkillsTool_success() { - Frontmatter testFrontmatter = - Frontmatter.builder().name("test-skill").description("test skill").build(); - - TestBaseAgent testAgent = - new TestBaseAgent( - "test agent", "test agent", ImmutableList.of(), ImmutableList.of(), Flowable::empty); - Session session = Session.builder("session").build(); - - InvocationContext invocationContext = mock(InvocationContext.class); - when(invocationContext.agent()).thenReturn(testAgent); - when(invocationContext.session()).thenReturn(session); - - SkillSource skillSource = - InMemorySkillSource.builder() - .skill(testFrontmatter.name()) - .frontmatter(testFrontmatter) - .instructions("Test instructions") - .build(); - ListSkillsTool listSkillsTool = new ListSkillsTool(skillSource); - Map response = - listSkillsTool - .runAsync(ImmutableMap.of(), ToolContext.builder(invocationContext).build()) - .blockingGet(); - - assertThat(response) - .containsExactly( - "skills_xml", "" + testFrontmatter.toXml() + ""); - } - - @Test - public void call_listSkillsTool_empty() { - TestBaseAgent testAgent = - new TestBaseAgent( - "test agent", "test agent", ImmutableList.of(), ImmutableList.of(), Flowable::empty); - Session session = Session.builder("session").build(); - - InvocationContext invocationContext = mock(InvocationContext.class); - when(invocationContext.agent()).thenReturn(testAgent); - when(invocationContext.session()).thenReturn(session); - - ListSkillsTool listSkillsTool = new ListSkillsTool(InMemorySkillSource.builder().build()); - Map response = - listSkillsTool - .runAsync(ImmutableMap.of(), ToolContext.builder(invocationContext).build()) - .blockingGet(); - - assertThat(response).containsExactly("skills_xml", ""); - } - - @Test - public void call_listSkillsTool_skillSourceException() { - TestBaseAgent testAgent = - new TestBaseAgent( - "test agent", "test agent", ImmutableList.of(), ImmutableList.of(), Flowable::empty); - Session session = Session.builder("session").build(); - - InvocationContext invocationContext = mock(InvocationContext.class); - when(invocationContext.agent()).thenReturn(testAgent); - when(invocationContext.session()).thenReturn(session); - - SkillSource skillSource = mock(SkillSource.class); - when(skillSource.listFrontmatters()) - .thenReturn( - Single.error(new SkillSourceException("Failed to list skills", SKILL_LOAD_ERROR))); - - ListSkillsTool listSkillsTool = new ListSkillsTool(skillSource); - Map response = - listSkillsTool - .runAsync(ImmutableMap.of(), ToolContext.builder(invocationContext).build()) - .blockingGet(); - - assertThat(response) - .containsExactly("error", "Failed to list skills", "error_code", "SKILL_LOAD_ERROR"); - } - - @Test - public void call_listSkillsTool_otherException() { - TestBaseAgent testAgent = - new TestBaseAgent( - "test agent", "test agent", ImmutableList.of(), ImmutableList.of(), Flowable::empty); - Session session = Session.builder("session").build(); - - InvocationContext invocationContext = mock(InvocationContext.class); - when(invocationContext.agent()).thenReturn(testAgent); - when(invocationContext.session()).thenReturn(session); - - SkillSource skillSource = mock(SkillSource.class); - RuntimeException expectedException = new RuntimeException("Unexpected error"); - when(skillSource.listFrontmatters()).thenReturn(Single.error(expectedException)); - - ListSkillsTool listSkillsTool = new ListSkillsTool(skillSource); - var single = - listSkillsTool.runAsync(ImmutableMap.of(), ToolContext.builder(invocationContext).build()); - - RuntimeException thrown = assertThrows(RuntimeException.class, single::blockingGet); - - assertThat(thrown).hasMessageThat().contains("Unexpected error"); - } - - @Test - public void call_listSkillsTool_declaration() { - ListSkillsTool listSkillsTool = new ListSkillsTool(mock(SkillSource.class)); - assertThat(listSkillsTool.declaration().get().name()).hasValue("list_skills"); - } -} diff --git a/core/src/test/java/com/google/adk/tools/skills/LoadSkillResourceToolTest.java b/core/src/test/java/com/google/adk/tools/skills/LoadSkillResourceToolTest.java deleted file mode 100644 index 3ab9b2a17..000000000 --- a/core/src/test/java/com/google/adk/tools/skills/LoadSkillResourceToolTest.java +++ /dev/null @@ -1,330 +0,0 @@ -/* - * Copyright 2026 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.google.adk.tools.skills; - -import static com.google.common.truth.Truth.assertThat; -import static java.nio.charset.StandardCharsets.UTF_8; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; - -import com.google.adk.agents.InvocationContext; -import com.google.adk.models.LlmRequest; -import com.google.adk.sessions.Session; -import com.google.adk.skills.Frontmatter; -import com.google.adk.skills.InMemorySkillSource; -import com.google.adk.skills.SkillSource; -import com.google.adk.testing.TestBaseAgent; -import com.google.adk.tools.ToolContext; -import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableMap; -import com.google.genai.types.Content; -import com.google.genai.types.FunctionResponse; -import com.google.genai.types.Part; -import io.reactivex.rxjava3.core.Flowable; -import java.util.List; -import java.util.Map; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; - -@RunWith(JUnit4.class) -public class LoadSkillResourceToolTest { - - @Test - public void call_loadSkillResourceTool_reference_success() { - LoadSkillResourceTool loadSkillResourceTool = - new LoadSkillResourceTool(createTestSkillSource()); - Map response = - loadSkillResourceTool - .runAsync( - ImmutableMap.of("skill_name", "test-skill", "file_path", "references/my_doc.md"), - createToolContext()) - .blockingGet(); - - assertThat(response) - .containsExactly( - "skill_name", "test-skill", - "file_path", "references/my_doc.md", - "mime_type", "text/markdown", - "content", "doc content"); - } - - @Test - public void call_loadSkillResourceTool_asset_success() { - LoadSkillResourceTool loadSkillResourceTool = - new LoadSkillResourceTool(createTestSkillSource()); - Map response = - loadSkillResourceTool - .runAsync( - ImmutableMap.of("skill_name", "test-skill", "file_path", "assets/template.txt"), - createToolContext()) - .blockingGet(); - - assertThat(response) - .containsExactly( - "skill_name", "test-skill", - "file_path", "assets/template.txt", - "mime_type", "text/plain", - "content", "asset content"); - } - - @Test - public void call_loadSkillResourceTool_script_success() { - LoadSkillResourceTool loadSkillResourceTool = - new LoadSkillResourceTool(createTestSkillSource()); - Map response = - loadSkillResourceTool - .runAsync( - ImmutableMap.of("skill_name", "test-skill", "file_path", "scripts/setup.sh"), - createToolContext()) - .blockingGet(); - - assertThat(response) - .containsExactly( - "skill_name", "test-skill", - "file_path", "scripts/setup.sh", - "mime_type", "application/x-sh", - "content", "echo hello"); - } - - @Test - public void call_loadSkillResourceTool_streamDetection_success() { - LoadSkillResourceTool loadSkillResourceTool = - new LoadSkillResourceTool(createTestSkillSource()); - Map response = - loadSkillResourceTool - .runAsync( - ImmutableMap.of("skill_name", "test-skill", "file_path", "assets/data_no_ext"), - createToolContext()) - .blockingGet(); - - assertThat(response) - .containsExactly( - "skill_name", "test-skill", - "file_path", "assets/data_no_ext", - "mime_type", "application/xml", - "content", ""); - } - - @Test - public void call_loadSkillResourceTool_binaryReference_detected() { - LoadSkillResourceTool loadSkillResourceTool = - new LoadSkillResourceTool(createTestSkillSource()); - ToolContext toolContext = createToolContext(); - Map response = - loadSkillResourceTool - .runAsync( - ImmutableMap.of("skill_name", "test-skill", "file_path", "references/binary.dat"), - toolContext) - .blockingGet(); - - Part partFunctionResponse = - Part.builder() - .functionResponse( - FunctionResponse.builder() - .id(toolContext.functionCallId().orElse("")) - .name(loadSkillResourceTool.name()) - .response(response) - .build()) - .build(); - - // Binary data is added as separate part in the next request to LLM - LlmRequest.Builder builder = - LlmRequest.builder() - .contents( - ImmutableList.of( - Content.builder().role("user").parts(partFunctionResponse).build())); - loadSkillResourceTool.processLlmRequest(builder, toolContext).blockingAwait(); - - List contents = builder.build().contents(); - assertThat(contents).hasSize(1); - List parts = contents.get(0).parts().get(); - assertThat(parts).hasSize(2); - - FunctionResponse updatedFunctionResponse = parts.get(0).functionResponse().get(); - assertThat(updatedFunctionResponse.response().get()) - .containsExactly( - "skill_name", - "test-skill", - "file_path", - "references/binary.dat", - "content", - "Binary file detected. The content has been included in the next part of the function" - + " response for you to analyze."); - - Part binaryPart = parts.get(1); - assertThat(binaryPart.inlineData().get().mimeType()).hasValue("application/octet-stream"); - assertThat(binaryPart.inlineData().get().data().get()).isEqualTo(new byte[] {0, 1, 2, 3}); - } - - @Test - public void call_loadSkillResourceTool_nonBinaryReference_notChanged() { - LoadSkillResourceTool loadSkillResourceTool = - new LoadSkillResourceTool(createTestSkillSource()); - ToolContext toolContext = createToolContext(); - Map response = - loadSkillResourceTool - .runAsync( - ImmutableMap.of("skill_name", "test-skill", "file_path", "references/my_doc.md"), - toolContext) - .blockingGet(); - - Part partFunctionResponse = - Part.builder() - .functionResponse( - FunctionResponse.builder() - .id(toolContext.functionCallId().orElse("")) - .name(loadSkillResourceTool.name()) - .response(response) - .build()) - .build(); - - LlmRequest.Builder builder = - LlmRequest.builder() - .contents( - ImmutableList.of( - Content.builder().role("user").parts(partFunctionResponse).build())); - List expectedContents = builder.build().contents(); - - loadSkillResourceTool.processLlmRequest(builder, toolContext).blockingAwait(); - - assertThat(builder.build().contents()).isEqualTo(expectedContents); - } - - @Test - public void call_loadSkillResourceTool_missingSkillName() { - LoadSkillResourceTool loadSkillResourceTool = - new LoadSkillResourceTool(createTestSkillSource()); - Map response = - loadSkillResourceTool - .runAsync(ImmutableMap.of("file_path", "references/my_doc.md"), createToolContext()) - .blockingGet(); - - assertThat(response) - .containsExactly( - "error", "Skill name is required.", - "error_code", "MISSING_SKILL_NAME"); - } - - @Test - public void call_loadSkillResourceTool_missingPath() { - LoadSkillResourceTool loadSkillResourceTool = - new LoadSkillResourceTool(createTestSkillSource()); - Map response = - loadSkillResourceTool - .runAsync(ImmutableMap.of("skill_name", "test-skill"), createToolContext()) - .blockingGet(); - - assertThat(response) - .containsExactly( - "error", "Resource path is required.", - "error_code", "MISSING_RESOURCE_PATH"); - } - - @Test - public void call_loadSkillResourceTool_skillNotFound() { - LoadSkillResourceTool loadSkillResourceTool = - new LoadSkillResourceTool(createTestSkillSource()); - Map response = - loadSkillResourceTool - .runAsync( - ImmutableMap.of("skill_name", "other-skill", "file_path", "references/my_doc.md"), - createToolContext()) - .blockingGet(); - - assertThat(response) - .containsExactly( - "error", "Skill not found: other-skill", - "error_code", "SKILL_NOT_FOUND"); - } - - @Test - public void call_loadSkillResourceTool_invalidPathPrefix() { - LoadSkillResourceTool loadSkillResourceTool = - new LoadSkillResourceTool(createTestSkillSource()); - Map response = - loadSkillResourceTool - .runAsync( - ImmutableMap.of("skill_name", "test-skill", "file_path", "invalid/my_doc.md"), - createToolContext()) - .blockingGet(); - - assertThat(response) - .containsExactly( - "error", "Path must start with 'references/', 'assets/', or 'scripts/'.", - "error_code", "INVALID_RESOURCE_PATH"); - } - - @Test - public void call_loadSkillResourceTool_resourceNotFound() { - LoadSkillResourceTool loadSkillResourceTool = - new LoadSkillResourceTool(createTestSkillSource()); - Map response = - loadSkillResourceTool - .runAsync( - ImmutableMap.of("skill_name", "test-skill", "file_path", "references/missing.md"), - createToolContext()) - .blockingGet(); - - assertThat(response) - .containsExactly( - "error", "Resource not found: references/missing.md", - "error_code", "RESOURCE_NOT_FOUND"); - } - - @Test - public void call_loadSkillResourceTool_declaration() { - LoadSkillResourceTool loadSkillResourceTool = - new LoadSkillResourceTool(mock(SkillSource.class)); - assertThat(loadSkillResourceTool.declaration().get().name()).hasValue("load_skill_resource"); - } - - @Test - public void call_loadSkillResourceTool_processLlmRequest_emptyContents() { - LoadSkillResourceTool loadSkillResourceTool = - new LoadSkillResourceTool(mock(SkillSource.class)); - LlmRequest.Builder builder = LlmRequest.builder().contents(ImmutableList.of()); - loadSkillResourceTool.processLlmRequest(builder, createToolContext()).blockingAwait(); - assertThat(builder.build().contents()).isEmpty(); - } - - private ToolContext createToolContext() { - TestBaseAgent testAgent = - new TestBaseAgent( - "test agent", "test agent", ImmutableList.of(), ImmutableList.of(), Flowable::empty); - Session session = Session.builder("session").build(); - - InvocationContext invocationContext = mock(InvocationContext.class); - when(invocationContext.agent()).thenReturn(testAgent); - when(invocationContext.session()).thenReturn(session); - - return ToolContext.builder(invocationContext).build(); - } - - private SkillSource createTestSkillSource() { - return InMemorySkillSource.builder() - .skill("test-skill") - .frontmatter(Frontmatter.builder().name("test-skill").description("test skill").build()) - .instructions("Test instructions") - .addResource("references/my_doc.md", "doc content".getBytes(UTF_8)) - .addResource("references/binary.dat", new byte[] {0, 1, 2, 3}) - .addResource("assets/template.txt", "asset content".getBytes(UTF_8)) - .addResource("scripts/setup.sh", "echo hello".getBytes(UTF_8)) - .addResource("assets/data_no_ext", "".getBytes(UTF_8)) - .build(); - } -} diff --git a/core/src/test/java/com/google/adk/tools/skills/LoadSkillToolTest.java b/core/src/test/java/com/google/adk/tools/skills/LoadSkillToolTest.java deleted file mode 100644 index 051b0fc11..000000000 --- a/core/src/test/java/com/google/adk/tools/skills/LoadSkillToolTest.java +++ /dev/null @@ -1,162 +0,0 @@ -/* - * Copyright 2026 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.google.adk.tools.skills; - -import static com.google.adk.skills.SkillSourceException.SKILL_NOT_FOUND; -import static com.google.common.truth.Truth.assertThat; -import static org.junit.Assert.assertThrows; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; - -import com.google.adk.agents.InvocationContext; -import com.google.adk.sessions.Session; -import com.google.adk.skills.Frontmatter; -import com.google.adk.skills.InMemorySkillSource; -import com.google.adk.skills.SkillSource; -import com.google.adk.skills.SkillSourceException; -import com.google.adk.testing.TestBaseAgent; -import com.google.adk.tools.ToolContext; -import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableMap; -import io.reactivex.rxjava3.core.Flowable; -import io.reactivex.rxjava3.core.Single; -import java.util.Map; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; - -@RunWith(JUnit4.class) -public class LoadSkillToolTest { - - @Test - public void call_loadSkillTool_success() { - TestBaseAgent testAgent = - new TestBaseAgent( - "test agent", "test agent", ImmutableList.of(), ImmutableList.of(), Flowable::empty); - Session session = Session.builder("session").build(); - - InvocationContext invocationContext = mock(InvocationContext.class); - when(invocationContext.agent()).thenReturn(testAgent); - when(invocationContext.session()).thenReturn(session); - - SkillSource skillSource = - InMemorySkillSource.builder() - .skill("test-skill") - .frontmatter(Frontmatter.builder().name("test-skill").description("test skill").build()) - .instructions("Test instructions") - .build(); - LoadSkillTool loadSkillTool = new LoadSkillTool(skillSource); - Map response = - loadSkillTool - .runAsync( - ImmutableMap.of("skill_name", "test-skill"), - ToolContext.builder(invocationContext).build()) - .blockingGet(); - - assertThat(response) - .containsExactly( - "skill_name", - "test-skill", - "instructions", - "Test instructions", - "frontmatter", - ImmutableMap.of( - "name", "test-skill", "description", "test skill", "metadata", ImmutableMap.of())); - } - - @Test - public void call_loadSkillTool_missingSkillName() { - TestBaseAgent testAgent = - new TestBaseAgent( - "test agent", "test agent", ImmutableList.of(), ImmutableList.of(), Flowable::empty); - Session session = Session.builder("session").build(); - - InvocationContext invocationContext = mock(InvocationContext.class); - when(invocationContext.agent()).thenReturn(testAgent); - when(invocationContext.session()).thenReturn(session); - - SkillSource skillSource = mock(SkillSource.class); - LoadSkillTool loadSkillTool = new LoadSkillTool(skillSource); - Map response = - loadSkillTool - .runAsync(ImmutableMap.of(), ToolContext.builder(invocationContext).build()) - .blockingGet(); - - assertThat(response) - .containsExactly("error", "Skill name is required.", "error_code", "MISSING_SKILL_NAME"); - } - - @Test - public void call_loadSkillTool_skillSourceException() { - TestBaseAgent testAgent = - new TestBaseAgent( - "test agent", "test agent", ImmutableList.of(), ImmutableList.of(), Flowable::empty); - Session session = Session.builder("session").build(); - - InvocationContext invocationContext = mock(InvocationContext.class); - when(invocationContext.agent()).thenReturn(testAgent); - when(invocationContext.session()).thenReturn(session); - - SkillSource skillSource = mock(SkillSource.class); - when(skillSource.loadFrontmatter("test-skill")) - .thenReturn(Single.error(new SkillSourceException("Skill not found", SKILL_NOT_FOUND))); - when(skillSource.loadInstructions("test-skill")).thenReturn(Single.just("instructions")); - - LoadSkillTool loadSkillTool = new LoadSkillTool(skillSource); - Map response = - loadSkillTool - .runAsync( - ImmutableMap.of("skill_name", "test-skill"), - ToolContext.builder(invocationContext).build()) - .blockingGet(); - - assertThat(response).containsExactly("error", "Skill not found", "error_code", SKILL_NOT_FOUND); - } - - @Test - public void call_loadSkillTool_otherException() { - TestBaseAgent testAgent = - new TestBaseAgent( - "test agent", "test agent", ImmutableList.of(), ImmutableList.of(), Flowable::empty); - Session session = Session.builder("session").build(); - - InvocationContext invocationContext = mock(InvocationContext.class); - when(invocationContext.agent()).thenReturn(testAgent); - when(invocationContext.session()).thenReturn(session); - - SkillSource skillSource = mock(SkillSource.class); - RuntimeException expectedException = new RuntimeException("Unexpected error"); - when(skillSource.loadFrontmatter("test-skill")).thenReturn(Single.error(expectedException)); - when(skillSource.loadInstructions("test-skill")).thenReturn(Single.just("instructions")); - - LoadSkillTool loadSkillTool = new LoadSkillTool(skillSource); - var single = - loadSkillTool.runAsync( - ImmutableMap.of("skill_name", "test-skill"), - ToolContext.builder(invocationContext).build()); - - RuntimeException thrown = assertThrows(RuntimeException.class, single::blockingGet); - - assertThat(thrown).hasMessageThat().contains("Unexpected error"); - } - - @Test - public void call_loadSkillTool_declaration() { - LoadSkillTool loadSkillTool = new LoadSkillTool(mock(SkillSource.class)); - assertThat(loadSkillTool.declaration().get().name()).hasValue("load_skill"); - } -} diff --git a/core/src/test/java/com/google/adk/tools/skills/SkillToolsetTest.java b/core/src/test/java/com/google/adk/tools/skills/SkillToolsetTest.java deleted file mode 100644 index 4be781469..000000000 --- a/core/src/test/java/com/google/adk/tools/skills/SkillToolsetTest.java +++ /dev/null @@ -1,143 +0,0 @@ -/* - * Copyright 2026 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.google.adk.tools.skills; - -import static com.google.common.truth.Truth.assertThat; -import static org.mockito.Mockito.mock; - -import com.google.adk.agents.ReadonlyContext; -import com.google.adk.models.LlmRequest; -import com.google.adk.skills.Frontmatter; -import com.google.adk.skills.InMemorySkillSource; -import com.google.adk.skills.SkillSource; -import com.google.adk.tools.BaseTool; -import com.google.adk.tools.BaseToolset; -import com.google.adk.tools.ToolContext; -import com.google.common.collect.ImmutableList; -import com.google.common.truth.Correspondence; -import io.reactivex.rxjava3.core.Flowable; -import java.util.List; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; - -@RunWith(JUnit4.class) -public final class SkillToolsetTest { - - @Test - public void getTools_returnsCoreTools() throws Exception { - SkillSource mockSkillSource = mock(SkillSource.class); - try (SkillToolset toolSet = new SkillToolset(mockSkillSource)) { - Flowable tools = toolSet.getTools(null); - List baseTools = tools.toList().blockingGet(); - - assertThat(baseTools) - .comparingElementsUsing(Correspondence.transforming(BaseTool::name, "Tool name")) - .containsExactly("list_skills", "load_skill", "load_skill_resource"); - } - } - - @Test - public void getTools_withInMemorySkills() throws Exception { - SkillSource skillSource = - InMemorySkillSource.builder() - .skill("test-skill") - .frontmatter(Frontmatter.builder().name("test-skill").description("test skill").build()) - .instructions("Test instructions") - .build(); - try (SkillToolset toolSet = new SkillToolset(skillSource)) { - - Flowable tools = toolSet.getTools(null); - List baseTools = tools.toList().blockingGet(); - - assertThat(baseTools) - .comparingElementsUsing(Correspondence.transforming(BaseTool::name, "Tool name")) - .containsExactly("list_skills", "load_skill", "load_skill_resource"); - } - } - - @Test - public void processLlmRequest_addsInstructions() throws Exception { - SkillSource skillSource = - InMemorySkillSource.builder() - .skill("test-skill") - .frontmatter(Frontmatter.builder().name("test-skill").description("test skill").build()) - .instructions("Test instructions") - .build(); - try (SkillToolset toolSet = new SkillToolset(skillSource)) { - - LlmRequest.Builder requestBuilder = LlmRequest.builder(); - ToolContext mockToolContext = mock(ToolContext.class); - - toolSet.processLlmRequest(requestBuilder, mockToolContext).blockingAwait(); - - LlmRequest request = requestBuilder.build(); - ImmutableList instructions = request.getSystemInstructions(); - - assertThat(instructions).isNotEmpty(); - String instruction = instructions.get(0); - assertThat(instruction) - .contains("You can use specialized 'skills' to help you with complex tasks"); - assertThat(instruction).contains(""); - assertThat(instruction).contains("test-skill"); - } - } - - @Test - public void processLlmRequest_withCustomSystemInstruction_addsCustomInstructions() - throws Exception { - SkillSource skillSource = - InMemorySkillSource.builder() - .skill("test-skill") - .frontmatter(Frontmatter.builder().name("test-skill").description("test skill").build()) - .instructions("Test instructions") - .build(); - String customInstruction = "Custom system instruction for testing."; - try (SkillToolset toolSet = new SkillToolset(skillSource, customInstruction)) { - - LlmRequest.Builder requestBuilder = LlmRequest.builder(); - ToolContext mockToolContext = mock(ToolContext.class); - - toolSet.processLlmRequest(requestBuilder, mockToolContext).blockingAwait(); - - LlmRequest request = requestBuilder.build(); - ImmutableList instructions = request.getSystemInstructions(); - - assertThat(instructions).isNotEmpty(); - String instruction = instructions.get(0); - assertThat(instruction).contains(customInstruction); - assertThat(instruction).contains(""); - assertThat(instruction).contains("test-skill"); - } - } - - @Test - public void baseToolset_defaultProcessLlmRequest() throws Exception { - try (BaseToolset baseToolset = - new BaseToolset() { - @Override - public Flowable getTools(ReadonlyContext context) { - return Flowable.empty(); - } - - @Override - public void close() {} - }) { - baseToolset.processLlmRequest(LlmRequest.builder(), mock(ToolContext.class)).blockingAwait(); - } - } -}