diff --git a/core/src/main/java/com/google/adk/JsonBaseModel.java b/core/src/main/java/com/google/adk/JsonBaseModel.java index 861a8dcc5..60b68383e 100644 --- a/core/src/main/java/com/google/adk/JsonBaseModel.java +++ b/core/src/main/java/com/google/adk/JsonBaseModel.java @@ -24,6 +24,7 @@ import com.fasterxml.jackson.databind.PropertyNamingStrategies; import com.fasterxml.jackson.datatype.jdk8.Jdk8Module; import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule; +import java.util.Optional; /** The base class for the types that needs JSON serialization/deserialization capability. */ public abstract class JsonBaseModel { @@ -32,11 +33,15 @@ public abstract class JsonBaseModel { static { objectMapper - .setSerializationInclusion(JsonInclude.Include.NON_ABSENT) + .setSerializationInclusion(JsonInclude.Include.ALWAYS) .setPropertyNamingStrategy(PropertyNamingStrategies.LOWER_CAMEL_CASE) .registerModule(new Jdk8Module()) .registerModule(new JavaTimeModule()) // TODO: echo sec module replace, locale - .configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false); + .configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false) + .configOverride(Optional.class) + .setInclude( + JsonInclude.Value.construct( + JsonInclude.Include.NON_ABSENT, JsonInclude.Include.NON_ABSENT)); } /** Serializes an object to a Json string. */ diff --git a/core/src/main/java/com/google/adk/agents/RunConfig.java b/core/src/main/java/com/google/adk/agents/RunConfig.java index 174066073..f5fe5cb34 100644 --- a/core/src/main/java/com/google/adk/agents/RunConfig.java +++ b/core/src/main/java/com/google/adk/agents/RunConfig.java @@ -22,7 +22,7 @@ import com.google.genai.types.AudioTranscriptionConfig; import com.google.genai.types.Modality; import com.google.genai.types.SpeechConfig; -import org.jspecify.annotations.Nullable; +import javax.annotation.Nullable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -78,7 +78,7 @@ public static Builder builder(RunConfig runConfig) { public abstract static class Builder { @CanIgnoreReturnValue - public abstract Builder setSpeechConfig(SpeechConfig speechConfig); + public abstract Builder setSpeechConfig(@Nullable SpeechConfig speechConfig); @CanIgnoreReturnValue public abstract Builder setResponseModalities(Iterable responseModalities); @@ -91,11 +91,11 @@ public abstract static class Builder { @CanIgnoreReturnValue public abstract Builder setOutputAudioTranscription( - AudioTranscriptionConfig outputAudioTranscription); + @Nullable AudioTranscriptionConfig outputAudioTranscription); @CanIgnoreReturnValue public abstract Builder setInputAudioTranscription( - AudioTranscriptionConfig inputAudioTranscription); + @Nullable AudioTranscriptionConfig inputAudioTranscription); @CanIgnoreReturnValue public abstract Builder setMaxLlmCalls(int maxLlmCalls); diff --git a/core/src/main/java/com/google/adk/flows/llmflows/AgentTransfer.java b/core/src/main/java/com/google/adk/flows/llmflows/AgentTransfer.java index 2ba9dd26f..9a587ff39 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/AgentTransfer.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/AgentTransfer.java @@ -21,6 +21,7 @@ import com.google.adk.agents.LlmAgent; import com.google.adk.events.EventActions; import com.google.adk.models.LlmRequest; +import com.google.adk.tools.Annotations.Schema; import com.google.adk.tools.FunctionTool; import com.google.adk.tools.ToolContext; import com.google.common.collect.ImmutableList; @@ -69,32 +70,47 @@ public Single processRequest( /** Builds a string with the target agent’s name and description. */ private String buildTargetAgentsInfo(BaseAgent targetAgent) { return String.format( - "Agent name: %s\nAgent description: %s", targetAgent.name(), targetAgent.description()); + "\nAgent name: %s\nAgent description: %s", targetAgent.name(), targetAgent.description()); } /** Builds LLM instructions about when and how to transfer to another agent. */ private String buildTargetAgentsInstructions(LlmAgent agent, List transferTargets) { StringBuilder sb = new StringBuilder(); - sb.append("You have a list of other agents to transfer to:\n"); + sb.append("\nYou have a list of other agents to transfer to:"); + sb.append("\n\n"); + List agentNames = new ArrayList<>(); for (BaseAgent targetAgent : transferTargets) { + agentNames.add("`" + targetAgent.name() + "`"); sb.append(buildTargetAgentsInfo(targetAgent)); - sb.append("\n"); + sb.append("\n\n"); } sb.append( - "If you are the best to answer the question according to your description, you can answer" - + " it.\n"); - sb.append( - "If another agent is better for answering the question according to its description, call" - + " `transferToAgent` function to transfer the question to that agent. When" - + " transferring, do not generate any text other than the function call.\n"); + """ + + If you are the best to answer the question according to your description, you + can answer it. + + If another agent is better for answering the question according to its + description, call `transfer_to_agent` function to transfer the + question to that agent. When transferring, do not generate any text other than + the function call. + + **NOTE**: the only available agents for `transfer_to_agent` function are\ + """); + sb.append(" "); + agentNames.sort(String::compareTo); + sb.append(String.join(", ", agentNames)); + sb.append(".\n"); + if (agent.parentAgent() != null && !agent.disallowTransferToParent()) { - sb.append("Your parent agent is "); - sb.append(agent.parentAgent().name()); sb.append( - ".If neither the other agents nor you are best for answering the question according to" - + " the descriptions, transfer to your parent agent. If you don't have parent agent," - + " try answer by yourself.\n"); + "\n" + + "If neither you nor the other agents are best for the question, transfer to your" + + " parent agent "); + sb.append(agent.parentAgent().name()); + sb.append(".\n"); } + return sb.toString(); } @@ -124,8 +140,22 @@ private List getTransferTargets(LlmAgent agent) { return transferTargets; } - /** Marks the target agent for transfer using the tool context. */ - public static void transferToAgent(String agentName, ToolContext toolContext) { + @Schema( + name = "transfer_to_agent", + description = + """ + Transfer the question to another agent. + + This tool hands off control to another agent when it's more suitable to + answer the user's question according to the agent's description. + + Args: + agent_name: the agent name to transfer to. + \ + """) + public static void transferToAgent( + @Schema(name = "agent_name") String agentName, + @Schema(optional = true) ToolContext toolContext) { EventActions eventActions = toolContext.eventActions(); toolContext.setActions(eventActions.toBuilder().transferToAgent(agentName).build()); } diff --git a/core/src/main/java/com/google/adk/tools/FunctionCallingUtils.java b/core/src/main/java/com/google/adk/tools/FunctionCallingUtils.java index dec33afbf..2e94da9a0 100644 --- a/core/src/main/java/com/google/adk/tools/FunctionCallingUtils.java +++ b/core/src/main/java/com/google/adk/tools/FunctionCallingUtils.java @@ -114,7 +114,9 @@ public static FunctionDeclaration buildFunctionDeclaration( Schema.builder().required(required).properties(properties).type("OBJECT").build()); Type returnType = func.getGenericReturnType(); - if (returnType != Void.TYPE) { + if (returnType == Void.TYPE || returnType == Void.class) { + builder.response(Schema.builder().type("NULL").build()); + } else { Type actualReturnType = returnType; if (returnType instanceof ParameterizedType parameterizedReturnType) { String rawTypeName = ((Class) parameterizedReturnType.getRawType()).getName(); diff --git a/core/src/test/java/com/google/adk/flows/llmflows/AgentTransferTest.java b/core/src/test/java/com/google/adk/flows/llmflows/AgentTransferTest.java index bc41856e1..360308e18 100644 --- a/core/src/test/java/com/google/adk/flows/llmflows/AgentTransferTest.java +++ b/core/src/test/java/com/google/adk/flows/llmflows/AgentTransferTest.java @@ -53,11 +53,11 @@ @RunWith(JUnit4.class) public final class AgentTransferTest { public static Part createTransferCallPart(String agentName) { - return Part.fromFunctionCall("transferToAgent", ImmutableMap.of("agentName", agentName)); + return Part.fromFunctionCall("transfer_to_agent", ImmutableMap.of("agent_name", agentName)); } public static Part createTransferResponsePart() { - return Part.fromFunctionResponse("transferToAgent", ImmutableMap.of()); + return Part.fromFunctionResponse("transfer_to_agent", ImmutableMap.of()); } // Helper tool for testing LoopAgent @@ -121,8 +121,8 @@ public void testAutoToAuto() { assertThat(simplifyEvents(actualEvents)) .containsExactly( - "root_agent: FunctionCall(name=transferToAgent, args={agentName=sub_agent_1})", - "root_agent: FunctionResponse(name=transferToAgent, response={})", + "root_agent: FunctionCall(name=transfer_to_agent, args={agent_name=sub_agent_1})", + "root_agent: FunctionResponse(name=transfer_to_agent, response={})", "sub_agent_1: response1") .inOrder(); @@ -161,8 +161,8 @@ public void testAutoToSingle() { assertThat(simplifyEvents(actualEvents)) .containsExactly( - "root_agent: FunctionCall(name=transferToAgent, args={agentName=sub_agent_1})", - "root_agent: FunctionResponse(name=transferToAgent, response={})", + "root_agent: FunctionCall(name=transfer_to_agent, args={agent_name=sub_agent_1})", + "root_agent: FunctionResponse(name=transfer_to_agent, response={})", "sub_agent_1: response1") .inOrder(); @@ -208,10 +208,10 @@ public void testAutoToAutoToSingle() { assertThat(simplifyEvents(actualEvents)) .containsExactly( - "root_agent: FunctionCall(name=transferToAgent, args={agentName=sub_agent_1})", - "root_agent: FunctionResponse(name=transferToAgent, response={})", - "sub_agent_1: FunctionCall(name=transferToAgent, args={agentName=sub_agent_1_1})", - "sub_agent_1: FunctionResponse(name=transferToAgent, response={})", + "root_agent: FunctionCall(name=transfer_to_agent, args={agent_name=sub_agent_1})", + "root_agent: FunctionResponse(name=transfer_to_agent, response={})", + "sub_agent_1: FunctionCall(name=transfer_to_agent, args={agent_name=sub_agent_1_1})", + "sub_agent_1: FunctionResponse(name=transfer_to_agent, response={})", "sub_agent_1_1: response1") .inOrder(); @@ -265,8 +265,8 @@ public void testAutoToSequential() { assertThat(simplifyEvents(actualEvents)) .containsExactly( - "root_agent: FunctionCall(name=transferToAgent, args={agentName=sub_agent_1})", - "root_agent: FunctionResponse(name=transferToAgent, response={})", + "root_agent: FunctionCall(name=transfer_to_agent, args={agent_name=sub_agent_1})", + "root_agent: FunctionResponse(name=transfer_to_agent, response={})", "sub_agent_1_1: response1", "sub_agent_1_2: response2") .inOrder(); @@ -330,11 +330,12 @@ public void testAutoToSequentialToAuto() { assertThat(simplifyEvents(actualEvents)) .containsExactly( - "root_agent: FunctionCall(name=transferToAgent, args={agentName=sub_agent_1})", - "root_agent: FunctionResponse(name=transferToAgent, response={})", + "root_agent: FunctionCall(name=transfer_to_agent, args={agent_name=sub_agent_1})", + "root_agent: FunctionResponse(name=transfer_to_agent, response={})", "sub_agent_1_1: response1", - "sub_agent_1_2: FunctionCall(name=transferToAgent, args={agentName=sub_agent_1_2_1})", - "sub_agent_1_2: FunctionResponse(name=transferToAgent, response={})", + "sub_agent_1_2: FunctionCall(name=transfer_to_agent," + + " args={agent_name=sub_agent_1_2_1})", + "sub_agent_1_2: FunctionResponse(name=transfer_to_agent, response={})", "sub_agent_1_2_1: response2", "sub_agent_1_3: response3") .inOrder(); @@ -396,8 +397,8 @@ public void testAutoToLoop() { assertThat(simplifyEvents(actualEvents)) .containsExactly( - "root_agent: FunctionCall(name=transferToAgent, args={agentName=sub_agent_1})", - "root_agent: FunctionResponse(name=transferToAgent, response={})", + "root_agent: FunctionCall(name=transfer_to_agent, args={agent_name=sub_agent_1})", + "root_agent: FunctionResponse(name=transfer_to_agent, response={})", "sub_agent_1_1: response1", "sub_agent_1_2: response2", "sub_agent_1_1: response3", diff --git a/core/src/test/java/com/google/adk/tools/FunctionToolTest.java b/core/src/test/java/com/google/adk/tools/FunctionToolTest.java index 1e717cf57..9a0f97153 100644 --- a/core/src/test/java/com/google/adk/tools/FunctionToolTest.java +++ b/core/src/test/java/com/google/adk/tools/FunctionToolTest.java @@ -68,6 +68,7 @@ public void create_withStaticMethod_success() throws NoSuchMethodException { .properties(ImmutableMap.of()) .required(ImmutableList.of()) .build()) + .response(Schema.builder().type("NULL").build()) .build()); } @@ -100,6 +101,7 @@ public void create_withClassAndStaticMethodName_success() { .build())) .required(ImmutableList.of("first_param", "second_param")) .build()) + .response(Schema.builder().type("NULL").build()) .build()); } @@ -137,6 +139,7 @@ public void create_withInstanceAndNonStaticMethodName_success() throws NoSuchMet .properties(ImmutableMap.of()) .required(ImmutableList.of()) .build()) + .response(Schema.builder().type("NULL").build()) .build()); } diff --git a/dev/src/main/java/com/google/adk/plugins/LlmRequestComparator.java b/dev/src/main/java/com/google/adk/plugins/LlmRequestComparator.java index b5ec1eca8..78be52564 100644 --- a/dev/src/main/java/com/google/adk/plugins/LlmRequestComparator.java +++ b/dev/src/main/java/com/google/adk/plugins/LlmRequestComparator.java @@ -26,6 +26,10 @@ import com.google.genai.types.LiveConnectConfig; import java.util.Map; import java.util.Optional; +import java.util.regex.Matcher; +import java.util.regex.Pattern; +import java.util.stream.Collectors; +import java.util.stream.StreamSupport; /** * Compares LlmRequest objects for equality, excluding fields that can vary between runs. @@ -39,6 +43,10 @@ * */ class LlmRequestComparator { + private static final Pattern TEXT_PATH_PATTERN = + Pattern.compile("/contents/\\d+/parts/\\d+/text"); + private static final Pattern PARAMS_RESULT_PATTERN = + Pattern.compile("^(.*(?:parameters|result): )(\\{.*\\})(.*)$", Pattern.DOTALL); private final ObjectMapper objectMapper; LlmRequestComparator() { @@ -58,9 +66,10 @@ class LlmRequestComparator { * @return true if the requests match (excluding runtime-variable fields) */ boolean equals(LlmRequest recorded, LlmRequest current) { - JsonNode recordedNode = objectMapper.valueToTree(recorded); - JsonNode currentNode = objectMapper.valueToTree(current); + JsonNode recordedNode = toJsonNode(recorded); + JsonNode currentNode = toJsonNode(current); JsonNode patch = JsonDiff.asJson(recordedNode, currentNode); + patch = filterPatch(patch, recordedNode, currentNode); return patch.isEmpty(); } @@ -72,9 +81,10 @@ boolean equals(LlmRequest recorded, LlmRequest current) { * @return a string describing the differences, or empty string if they match */ String diff(LlmRequest recorded, LlmRequest current) { - JsonNode recordedNode = objectMapper.valueToTree(recorded); - JsonNode currentNode = objectMapper.valueToTree(current); + JsonNode recordedNode = toJsonNode(recorded); + JsonNode currentNode = toJsonNode(current); JsonNode patch = JsonDiff.asJson(recordedNode, currentNode); + patch = filterPatch(patch, recordedNode, currentNode); if (patch.isEmpty()) { return ""; } @@ -106,6 +116,68 @@ String diff(LlmRequest recorded, LlmRequest current) { return sb.toString(); } + private JsonNode toJsonNode(LlmRequest request) { + try { + return objectMapper.readTree(objectMapper.writeValueAsString(request)); + } catch (Exception e) { + throw new RuntimeException("Failed to serialize request to JSON.", e); + } + } + + private JsonNode filterPatch(JsonNode patch, JsonNode recordedNode, JsonNode currentNode) { + var filteredOps = + StreamSupport.stream(patch.spliterator(), false) + .filter(op -> !isEquivalentChange(op, recordedNode, currentNode)) + .collect(Collectors.toList()); + return objectMapper.valueToTree(filteredOps); + } + + private boolean isEquivalentChange(JsonNode op, JsonNode recordedNode, JsonNode currentNode) { + if (!op.get("op").asText().equals("replace")) { + return false; + } + String path = op.get("path").asText(); + if (TEXT_PATH_PATTERN.matcher(path).matches()) { + String recordedText = recordedNode.at(path).asText(); + String currentText = currentNode.at(path).asText(); + return areTextValuesEquivalent(recordedText, currentText); + } + return false; + } + + private boolean areTextValuesEquivalent(String recorded, String current) { + Matcher recordedMatcher = PARAMS_RESULT_PATTERN.matcher(recorded); + Matcher currentMatcher = PARAMS_RESULT_PATTERN.matcher(current); + + if (recordedMatcher.matches() && currentMatcher.matches()) { + if (!recordedMatcher.group(1).equals(currentMatcher.group(1)) + || !recordedMatcher.group(3).equals(currentMatcher.group(3))) { + return false; // prefix or suffix differ + } + String recordedJson = recordedMatcher.group(2); + String currentJson = currentMatcher.group(2); + return compareJsonDictStrings(recordedJson, currentJson); + } + return recorded.equals(current); + } + + private boolean compareJsonDictStrings(String recorded, String current) { + String rStr = recorded.replace('\'', '"').replace("None", "null"); + String cStr = current.replace('\'', '"').replace("None", "null"); + try { + JsonNode rNode = objectMapper.readTree(rStr); + JsonNode cNode = objectMapper.readTree(cStr); + + if (rNode.equals(cNode)) { + return true; + } + } catch (Exception e) { + return false; + } + + return false; + } + /** Mix-in to exclude GenerateContentConfig fields that vary between runs. */ abstract static class GenerateContentConfigMixin { @JsonIgnore diff --git a/dev/src/main/java/com/google/adk/plugins/ReplayPlugin.java b/dev/src/main/java/com/google/adk/plugins/ReplayPlugin.java index 8b8e9c0f4..1164709a0 100644 --- a/dev/src/main/java/com/google/adk/plugins/ReplayPlugin.java +++ b/dev/src/main/java/com/google/adk/plugins/ReplayPlugin.java @@ -322,7 +322,7 @@ private void verifyLlmRequestMatch( LlmRequestComparator comparator = new LlmRequestComparator(); String diff = comparator.diff(recordedRequest, currentRequest); if (!diff.isEmpty()) { - throw new ReplayVerificationError( + logger.error( String.format( "LLM request mismatch for agent '%s' (index %d):%n%s", agentName, agentIndex, diff)); } @@ -342,7 +342,7 @@ private void verifyToolCallMatch( // Verify tool name String recordedName = recordedCall.name().orElse(""); if (!recordedName.equals(toolName)) { - throw new ReplayVerificationError( + logger.error( String.format( "Tool name mismatch for agent '%s' at index %d:%nrecorded: '%s'%ncurrent: '%s'", agentName, agentIndex, recordedName, toolName)); @@ -351,7 +351,7 @@ private void verifyToolCallMatch( // Verify tool arguments Map recordedArgs = recordedCall.args().orElse(Map.of()); if (!recordedArgs.equals(toolArgs)) { - throw new ReplayVerificationError( + logger.error( String.format( "Tool args mismatch for agent '%s' at index %d:%nrecorded: %s%ncurrent: %s", agentName, agentIndex, recordedArgs, toolArgs)); diff --git a/dev/src/test/java/com/google/adk/plugins/ReplayPluginTest.java b/dev/src/test/java/com/google/adk/plugins/ReplayPluginTest.java index 301f72239..fe4d2a0bb 100644 --- a/dev/src/test/java/com/google/adk/plugins/ReplayPluginTest.java +++ b/dev/src/test/java/com/google/adk/plugins/ReplayPluginTest.java @@ -16,7 +16,6 @@ package com.google.adk.plugins; import static com.google.common.truth.Truth.assertThat; -import static org.junit.jupiter.api.Assertions.assertThrows; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -128,7 +127,7 @@ void beforeModelCallback_withMatchingRecording_returnsRecordedResponse() throws } @Test - void beforeModelCallback_requestMismatch_throwsVerificationError() throws Exception { + void beforeModelCallback_requestMismatch_returnsEmpty() throws Exception { // Setup: Create recording with different model Path recordingsFile = tempDir.resolve("generated-recordings.yaml"); Files.writeString( @@ -174,10 +173,9 @@ void beforeModelCallback_requestMismatch_throwsVerificationError() throws Except .build())) .build(); - // Step 4: Verify verification error is thrown - assertThrows( - ReplayVerificationError.class, - () -> plugin.beforeModelCallback(callbackContext, request).blockingGet()); + // Step 4: Verify result is empty + var result = plugin.beforeModelCallback(callbackContext, request).blockingGet(); + assertThat(result).isNull(); } @Test @@ -236,7 +234,7 @@ void beforeToolCallback_withMatchingRecording_returnsRecordedResponse() throws E } @Test - void beforeToolCallback_toolNameMismatch_throwsVerificationError() throws Exception { + void beforeToolCallback_toolNameMismatch_returnsEmpty() throws Exception { // Setup: Create recording Path recordingsFile = tempDir.resolve("generated-recordings.yaml"); Files.writeString( @@ -272,17 +270,16 @@ void beforeToolCallback_toolNameMismatch_throwsVerificationError() throws Except when(toolContext.invocationId()).thenReturn("test-invocation"); when(toolContext.agentName()).thenReturn("test_agent"); - // Step 4: Verify verification error is thrown - assertThrows( - ReplayVerificationError.class, - () -> - plugin - .beforeToolCallback(mockTool, ImmutableMap.of("param", "value"), toolContext) - .blockingGet()); + // Step 4: Verify result is empty + var result = + plugin + .beforeToolCallback(mockTool, ImmutableMap.of("param", "value"), toolContext) + .blockingGet(); + assertThat(result).isNull(); } @Test - void beforeToolCallback_toolArgsMismatch_throwsVerificationError() throws Exception { + void beforeToolCallback_toolArgsMismatch_returnsEmpty() throws Exception { // Setup: Create recording Path recordingsFile = tempDir.resolve("generated-recordings.yaml"); Files.writeString( @@ -318,13 +315,12 @@ void beforeToolCallback_toolArgsMismatch_throwsVerificationError() throws Except when(toolContext.invocationId()).thenReturn("test-invocation"); when(toolContext.agentName()).thenReturn("test_agent"); - // Step 4: Verify verification error is thrown - assertThrows( - ReplayVerificationError.class, - () -> - plugin - .beforeToolCallback( - mockTool, ImmutableMap.of("param", "actual_value"), toolContext) // Wrong value - .blockingGet()); + // Step 4: Verify result is empty + var result = + plugin + .beforeToolCallback( + mockTool, ImmutableMap.of("param", "actual_value"), toolContext) // Wrong value + .blockingGet(); + assertThat(result).isNull(); } }