Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions core/src/main/java/com/google/adk/JsonBaseModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import com.fasterxml.jackson.databind.PropertyNamingStrategies;
import com.fasterxml.jackson.datatype.jdk8.Jdk8Module;
import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule;
import java.util.Optional;

/** The base class for the types that needs JSON serialization/deserialization capability. */
public abstract class JsonBaseModel {
Expand All @@ -32,11 +33,15 @@ public abstract class JsonBaseModel {

static {
objectMapper
.setSerializationInclusion(JsonInclude.Include.NON_ABSENT)
.setSerializationInclusion(JsonInclude.Include.ALWAYS)
.setPropertyNamingStrategy(PropertyNamingStrategies.LOWER_CAMEL_CASE)
.registerModule(new Jdk8Module())
.registerModule(new JavaTimeModule()) // TODO: echo sec module replace, locale
.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false)
.configOverride(Optional.class)
.setInclude(
JsonInclude.Value.construct(
JsonInclude.Include.NON_ABSENT, JsonInclude.Include.NON_ABSENT));
}

/** Serializes an object to a Json string. */
Expand Down
8 changes: 4 additions & 4 deletions core/src/main/java/com/google/adk/agents/RunConfig.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import com.google.genai.types.AudioTranscriptionConfig;
import com.google.genai.types.Modality;
import com.google.genai.types.SpeechConfig;
import org.jspecify.annotations.Nullable;
import javax.annotation.Nullable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand Down Expand Up @@ -78,7 +78,7 @@ public static Builder builder(RunConfig runConfig) {
public abstract static class Builder {

@CanIgnoreReturnValue
public abstract Builder setSpeechConfig(SpeechConfig speechConfig);
public abstract Builder setSpeechConfig(@Nullable SpeechConfig speechConfig);

@CanIgnoreReturnValue
public abstract Builder setResponseModalities(Iterable<Modality> responseModalities);
Expand All @@ -91,11 +91,11 @@ public abstract static class Builder {

@CanIgnoreReturnValue
public abstract Builder setOutputAudioTranscription(
AudioTranscriptionConfig outputAudioTranscription);
@Nullable AudioTranscriptionConfig outputAudioTranscription);

@CanIgnoreReturnValue
public abstract Builder setInputAudioTranscription(
AudioTranscriptionConfig inputAudioTranscription);
@Nullable AudioTranscriptionConfig inputAudioTranscription);

@CanIgnoreReturnValue
public abstract Builder setMaxLlmCalls(int maxLlmCalls);
Expand Down
62 changes: 46 additions & 16 deletions core/src/main/java/com/google/adk/flows/llmflows/AgentTransfer.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import com.google.adk.agents.LlmAgent;
import com.google.adk.events.EventActions;
import com.google.adk.models.LlmRequest;
import com.google.adk.tools.Annotations.Schema;
import com.google.adk.tools.FunctionTool;
import com.google.adk.tools.ToolContext;
import com.google.common.collect.ImmutableList;
Expand Down Expand Up @@ -69,32 +70,47 @@ public Single<RequestProcessor.RequestProcessingResult> processRequest(
/** Builds a string with the target agent’s name and description. */
private String buildTargetAgentsInfo(BaseAgent targetAgent) {
return String.format(
"Agent name: %s\nAgent description: %s", targetAgent.name(), targetAgent.description());
"\nAgent name: %s\nAgent description: %s", targetAgent.name(), targetAgent.description());
}

/** Builds LLM instructions about when and how to transfer to another agent. */
private String buildTargetAgentsInstructions(LlmAgent agent, List<BaseAgent> transferTargets) {
StringBuilder sb = new StringBuilder();
sb.append("You have a list of other agents to transfer to:\n");
sb.append("\nYou have a list of other agents to transfer to:");
sb.append("\n\n");
List<String> agentNames = new ArrayList<>();
for (BaseAgent targetAgent : transferTargets) {
agentNames.add("`" + targetAgent.name() + "`");
sb.append(buildTargetAgentsInfo(targetAgent));
sb.append("\n");
sb.append("\n\n");
}
sb.append(
"If you are the best to answer the question according to your description, you can answer"
+ " it.\n");
sb.append(
"If another agent is better for answering the question according to its description, call"
+ " `transferToAgent` function to transfer the question to that agent. When"
+ " transferring, do not generate any text other than the function call.\n");
"""

If you are the best to answer the question according to your description, you
can answer it.

If another agent is better for answering the question according to its
description, call `transfer_to_agent` function to transfer the
question to that agent. When transferring, do not generate any text other than
the function call.

**NOTE**: the only available agents for `transfer_to_agent` function are\
""");
sb.append(" ");
agentNames.sort(String::compareTo);
sb.append(String.join(", ", agentNames));
sb.append(".\n");

if (agent.parentAgent() != null && !agent.disallowTransferToParent()) {
sb.append("Your parent agent is ");
sb.append(agent.parentAgent().name());
sb.append(
".If neither the other agents nor you are best for answering the question according to"
+ " the descriptions, transfer to your parent agent. If you don't have parent agent,"
+ " try answer by yourself.\n");
"\n"
+ "If neither you nor the other agents are best for the question, transfer to your"
+ " parent agent ");
sb.append(agent.parentAgent().name());
sb.append(".\n");
}

return sb.toString();
}

Expand Down Expand Up @@ -124,8 +140,22 @@ private List<BaseAgent> getTransferTargets(LlmAgent agent) {
return transferTargets;
}

/** Marks the target agent for transfer using the tool context. */
public static void transferToAgent(String agentName, ToolContext toolContext) {
@Schema(
name = "transfer_to_agent",
description =
"""
Transfer the question to another agent.

This tool hands off control to another agent when it's more suitable to
answer the user's question according to the agent's description.

Args:
agent_name: the agent name to transfer to.
\
""")
public static void transferToAgent(
@Schema(name = "agent_name") String agentName,
@Schema(optional = true) ToolContext toolContext) {
EventActions eventActions = toolContext.eventActions();
toolContext.setActions(eventActions.toBuilder().transferToAgent(agentName).build());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,9 @@ public static FunctionDeclaration buildFunctionDeclaration(
Schema.builder().required(required).properties(properties).type("OBJECT").build());

Type returnType = func.getGenericReturnType();
if (returnType != Void.TYPE) {
if (returnType == Void.TYPE || returnType == Void.class) {
builder.response(Schema.builder().type("NULL").build());
} else {
Type actualReturnType = returnType;
if (returnType instanceof ParameterizedType parameterizedReturnType) {
String rawTypeName = ((Class<?>) parameterizedReturnType.getRawType()).getName();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,11 +53,11 @@
@RunWith(JUnit4.class)
public final class AgentTransferTest {
public static Part createTransferCallPart(String agentName) {
return Part.fromFunctionCall("transferToAgent", ImmutableMap.of("agentName", agentName));
return Part.fromFunctionCall("transfer_to_agent", ImmutableMap.of("agent_name", agentName));
}

public static Part createTransferResponsePart() {
return Part.fromFunctionResponse("transferToAgent", ImmutableMap.<String, Object>of());
return Part.fromFunctionResponse("transfer_to_agent", ImmutableMap.<String, Object>of());
}

// Helper tool for testing LoopAgent
Expand Down Expand Up @@ -121,8 +121,8 @@ public void testAutoToAuto() {

assertThat(simplifyEvents(actualEvents))
.containsExactly(
"root_agent: FunctionCall(name=transferToAgent, args={agentName=sub_agent_1})",
"root_agent: FunctionResponse(name=transferToAgent, response={})",
"root_agent: FunctionCall(name=transfer_to_agent, args={agent_name=sub_agent_1})",
"root_agent: FunctionResponse(name=transfer_to_agent, response={})",
"sub_agent_1: response1")
.inOrder();

Expand Down Expand Up @@ -161,8 +161,8 @@ public void testAutoToSingle() {

assertThat(simplifyEvents(actualEvents))
.containsExactly(
"root_agent: FunctionCall(name=transferToAgent, args={agentName=sub_agent_1})",
"root_agent: FunctionResponse(name=transferToAgent, response={})",
"root_agent: FunctionCall(name=transfer_to_agent, args={agent_name=sub_agent_1})",
"root_agent: FunctionResponse(name=transfer_to_agent, response={})",
"sub_agent_1: response1")
.inOrder();

Expand Down Expand Up @@ -208,10 +208,10 @@ public void testAutoToAutoToSingle() {

assertThat(simplifyEvents(actualEvents))
.containsExactly(
"root_agent: FunctionCall(name=transferToAgent, args={agentName=sub_agent_1})",
"root_agent: FunctionResponse(name=transferToAgent, response={})",
"sub_agent_1: FunctionCall(name=transferToAgent, args={agentName=sub_agent_1_1})",
"sub_agent_1: FunctionResponse(name=transferToAgent, response={})",
"root_agent: FunctionCall(name=transfer_to_agent, args={agent_name=sub_agent_1})",
"root_agent: FunctionResponse(name=transfer_to_agent, response={})",
"sub_agent_1: FunctionCall(name=transfer_to_agent, args={agent_name=sub_agent_1_1})",
"sub_agent_1: FunctionResponse(name=transfer_to_agent, response={})",
"sub_agent_1_1: response1")
.inOrder();

Expand Down Expand Up @@ -265,8 +265,8 @@ public void testAutoToSequential() {

assertThat(simplifyEvents(actualEvents))
.containsExactly(
"root_agent: FunctionCall(name=transferToAgent, args={agentName=sub_agent_1})",
"root_agent: FunctionResponse(name=transferToAgent, response={})",
"root_agent: FunctionCall(name=transfer_to_agent, args={agent_name=sub_agent_1})",
"root_agent: FunctionResponse(name=transfer_to_agent, response={})",
"sub_agent_1_1: response1",
"sub_agent_1_2: response2")
.inOrder();
Expand Down Expand Up @@ -330,11 +330,12 @@ public void testAutoToSequentialToAuto() {

assertThat(simplifyEvents(actualEvents))
.containsExactly(
"root_agent: FunctionCall(name=transferToAgent, args={agentName=sub_agent_1})",
"root_agent: FunctionResponse(name=transferToAgent, response={})",
"root_agent: FunctionCall(name=transfer_to_agent, args={agent_name=sub_agent_1})",
"root_agent: FunctionResponse(name=transfer_to_agent, response={})",
"sub_agent_1_1: response1",
"sub_agent_1_2: FunctionCall(name=transferToAgent, args={agentName=sub_agent_1_2_1})",
"sub_agent_1_2: FunctionResponse(name=transferToAgent, response={})",
"sub_agent_1_2: FunctionCall(name=transfer_to_agent,"
+ " args={agent_name=sub_agent_1_2_1})",
"sub_agent_1_2: FunctionResponse(name=transfer_to_agent, response={})",
"sub_agent_1_2_1: response2",
"sub_agent_1_3: response3")
.inOrder();
Expand Down Expand Up @@ -396,8 +397,8 @@ public void testAutoToLoop() {

assertThat(simplifyEvents(actualEvents))
.containsExactly(
"root_agent: FunctionCall(name=transferToAgent, args={agentName=sub_agent_1})",
"root_agent: FunctionResponse(name=transferToAgent, response={})",
"root_agent: FunctionCall(name=transfer_to_agent, args={agent_name=sub_agent_1})",
"root_agent: FunctionResponse(name=transfer_to_agent, response={})",
"sub_agent_1_1: response1",
"sub_agent_1_2: response2",
"sub_agent_1_1: response3",
Expand Down
3 changes: 3 additions & 0 deletions core/src/test/java/com/google/adk/tools/FunctionToolTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ public void create_withStaticMethod_success() throws NoSuchMethodException {
.properties(ImmutableMap.of())
.required(ImmutableList.of())
.build())
.response(Schema.builder().type("NULL").build())
.build());
}

Expand Down Expand Up @@ -100,6 +101,7 @@ public void create_withClassAndStaticMethodName_success() {
.build()))
.required(ImmutableList.of("first_param", "second_param"))
.build())
.response(Schema.builder().type("NULL").build())
.build());
}

Expand Down Expand Up @@ -137,6 +139,7 @@ public void create_withInstanceAndNonStaticMethodName_success() throws NoSuchMet
.properties(ImmutableMap.of())
.required(ImmutableList.of())
.build())
.response(Schema.builder().type("NULL").build())
.build());
}

Expand Down
80 changes: 76 additions & 4 deletions dev/src/main/java/com/google/adk/plugins/LlmRequestComparator.java
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@
import com.google.genai.types.LiveConnectConfig;
import java.util.Map;
import java.util.Optional;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import java.util.stream.StreamSupport;

/**
* Compares LlmRequest objects for equality, excluding fields that can vary between runs.
Expand All @@ -39,6 +43,10 @@
* </ul>
*/
class LlmRequestComparator {
private static final Pattern TEXT_PATH_PATTERN =
Pattern.compile("/contents/\\d+/parts/\\d+/text");
private static final Pattern PARAMS_RESULT_PATTERN =
Pattern.compile("^(.*(?:parameters|result): )(\\{.*\\})(.*)$", Pattern.DOTALL);
private final ObjectMapper objectMapper;

LlmRequestComparator() {
Expand All @@ -58,9 +66,10 @@ class LlmRequestComparator {
* @return true if the requests match (excluding runtime-variable fields)
*/
boolean equals(LlmRequest recorded, LlmRequest current) {
JsonNode recordedNode = objectMapper.valueToTree(recorded);
JsonNode currentNode = objectMapper.valueToTree(current);
JsonNode recordedNode = toJsonNode(recorded);
JsonNode currentNode = toJsonNode(current);
JsonNode patch = JsonDiff.asJson(recordedNode, currentNode);
patch = filterPatch(patch, recordedNode, currentNode);
return patch.isEmpty();
}

Expand All @@ -72,9 +81,10 @@ boolean equals(LlmRequest recorded, LlmRequest current) {
* @return a string describing the differences, or empty string if they match
*/
String diff(LlmRequest recorded, LlmRequest current) {
JsonNode recordedNode = objectMapper.valueToTree(recorded);
JsonNode currentNode = objectMapper.valueToTree(current);
JsonNode recordedNode = toJsonNode(recorded);
JsonNode currentNode = toJsonNode(current);
JsonNode patch = JsonDiff.asJson(recordedNode, currentNode);
patch = filterPatch(patch, recordedNode, currentNode);
if (patch.isEmpty()) {
return "";
}
Expand Down Expand Up @@ -106,6 +116,68 @@ String diff(LlmRequest recorded, LlmRequest current) {
return sb.toString();
}

private JsonNode toJsonNode(LlmRequest request) {
try {
return objectMapper.readTree(objectMapper.writeValueAsString(request));
} catch (Exception e) {
throw new RuntimeException("Failed to serialize request to JSON.", e);
}
}

private JsonNode filterPatch(JsonNode patch, JsonNode recordedNode, JsonNode currentNode) {
var filteredOps =
StreamSupport.stream(patch.spliterator(), false)
.filter(op -> !isEquivalentChange(op, recordedNode, currentNode))
.collect(Collectors.toList());
return objectMapper.valueToTree(filteredOps);
}

private boolean isEquivalentChange(JsonNode op, JsonNode recordedNode, JsonNode currentNode) {
if (!op.get("op").asText().equals("replace")) {
return false;
}
String path = op.get("path").asText();
if (TEXT_PATH_PATTERN.matcher(path).matches()) {
String recordedText = recordedNode.at(path).asText();
String currentText = currentNode.at(path).asText();
return areTextValuesEquivalent(recordedText, currentText);
}
return false;
}

private boolean areTextValuesEquivalent(String recorded, String current) {
Matcher recordedMatcher = PARAMS_RESULT_PATTERN.matcher(recorded);
Matcher currentMatcher = PARAMS_RESULT_PATTERN.matcher(current);

if (recordedMatcher.matches() && currentMatcher.matches()) {
if (!recordedMatcher.group(1).equals(currentMatcher.group(1))
|| !recordedMatcher.group(3).equals(currentMatcher.group(3))) {
return false; // prefix or suffix differ
}
String recordedJson = recordedMatcher.group(2);
String currentJson = currentMatcher.group(2);
return compareJsonDictStrings(recordedJson, currentJson);
}
return recorded.equals(current);
}

private boolean compareJsonDictStrings(String recorded, String current) {
String rStr = recorded.replace('\'', '"').replace("None", "null");
String cStr = current.replace('\'', '"').replace("None", "null");
try {
JsonNode rNode = objectMapper.readTree(rStr);
JsonNode cNode = objectMapper.readTree(cStr);

if (rNode.equals(cNode)) {
return true;
}
} catch (Exception e) {
return false;
}

return false;
}

/** Mix-in to exclude GenerateContentConfig fields that vary between runs. */
abstract static class GenerateContentConfigMixin {
@JsonIgnore
Expand Down
Loading