diff --git a/core/src/main/java/com/google/adk/runner/Runner.java b/core/src/main/java/com/google/adk/runner/Runner.java index 26a523fdd..1ab101398 100644 --- a/core/src/main/java/com/google/adk/runner/Runner.java +++ b/core/src/main/java/com/google/adk/runner/Runner.java @@ -16,6 +16,8 @@ package com.google.adk.runner; +import static com.google.common.collect.ImmutableSet.toImmutableSet; + import com.google.adk.agents.ActiveStreamingTool; import com.google.adk.agents.BaseAgent; import com.google.adk.agents.ContextCacheConfig; @@ -45,6 +47,8 @@ import com.google.adk.utils.CollectionUtils; import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Iterables; import com.google.common.collect.MapMaker; import com.google.errorprone.annotations.CanIgnoreReturnValue; import com.google.genai.types.AudioTranscriptionConfig; @@ -64,6 +68,7 @@ import java.util.Collections; import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.Optional; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; @@ -772,12 +777,15 @@ private boolean isTransferableAcrossAgentTree(BaseAgent agentToRun) { return true; } - /** - * Returns the agent that should handle the next request based on session history. - * - * @return agent to run. - */ + /** Returns the agent that should handle the next request based on session history. */ private BaseAgent findAgentToRun(Session session, BaseAgent rootAgent) { + // Route function responses back to the originating function-call author so HITL tool + // confirmations resume the sub-agent even through non-LlmAgent ancestors. + Optional functionCallAuthor = findFunctionCallAuthor(session, rootAgent); + if (functionCallAuthor.isPresent()) { + return functionCallAuthor.get(); + } + List events = new ArrayList<>(session.events()); Collections.reverse(events); @@ -808,6 +816,39 @@ private BaseAgent findAgentToRun(Session session, BaseAgent rootAgent) { return rootAgent; } + /** + * If the last event is a function response, returns the agent that emitted the matching function + * call (by id), or empty if no match is found in the agent tree. + */ + private static Optional findFunctionCallAuthor(Session session, BaseAgent rootAgent) { + List events = session.events(); + if (events.isEmpty()) { + return Optional.empty(); + } + ImmutableSet functionResponseIds = + Iterables.getLast(events).functionResponses().stream() + .map(fr -> fr.id().orElse(null)) + .filter(Objects::nonNull) + .collect(toImmutableSet()); + + // Iterate in reverse to prefer the most recent matching call, mirroring Python ADK's + // find_event_by_function_call_id. Function call IDs are unique in normal flows, so this + // is defense-in-depth and not covered by mutation testing. + List precedingEvents = new ArrayList<>(events.subList(0, events.size() - 1)); + Collections.reverse(precedingEvents); + for (Event event : precedingEvents) { + boolean matches = + event.functionCalls().stream() + .map(fc -> fc.id().orElse(null)) + .filter(Objects::nonNull) + .anyMatch(functionResponseIds::contains); + if (matches && event.author() != null) { + return rootAgent.findAgent(event.author()); + } + } + return Optional.empty(); + } + private void addActiveStreamingTools(InvocationContext invocationContext, List tools) { tools.stream() .filter(FunctionTool.class::isInstance) diff --git a/core/src/test/java/com/google/adk/runner/RunnerTest.java b/core/src/test/java/com/google/adk/runner/RunnerTest.java index 00d5d63bf..95718e3e0 100644 --- a/core/src/test/java/com/google/adk/runner/RunnerTest.java +++ b/core/src/test/java/com/google/adk/runner/RunnerTest.java @@ -42,6 +42,7 @@ import com.google.adk.agents.LiveRequestQueue; import com.google.adk.agents.LlmAgent; import com.google.adk.agents.RunConfig; +import com.google.adk.agents.SequentialAgent; import com.google.adk.apps.App; import com.google.adk.artifacts.BaseArtifactService; import com.google.adk.events.Event; @@ -1604,6 +1605,107 @@ public void runAsync_withToolConfirmation() { .inOrder(); } + // HITL tool confirmation must resume the originating sub-agent even when wrapped inside a + // non-LlmAgent workflow agent (e.g. SequentialAgent). + @Test + public void runAsync_withToolConfirmation_inSequentialAgentSubAgent_resumesSubAgent() { + TestLlm childTestLlm = + createTestLlm( + createFunctionCallLlmResponse( + "tool_call_id", "echoTool", ImmutableMap.of("message", "hello")), + createTextLlmResponse("Response after observing tool needs confirmation."), + createTextLlmResponse("Response after user confirmed.")); + LlmAgent childAgent = + createTestAgentBuilder(childTestLlm) + .name("child_agent") + .tools(FunctionTool.create(Tools.class, "echoTool", /* requireConfirmation= */ true)) + .build(); + SequentialAgent workflowAgent = + SequentialAgent.builder() + .name("workflow_agent") + .subAgents(ImmutableList.of(childAgent)) + .build(); + // Root transfers to workflow_agent to mirror the bug report's control flow. + TestLlm rootTestLlm = + createTestLlm( + createLlmResponse( + Content.fromParts( + Part.fromFunctionCall( + "transfer_to_agent", ImmutableMap.of("agent_name", "workflow_agent"))))); + LlmAgent rootAgent = + createTestAgentBuilder(rootTestLlm) + .name("root_agent") + .subAgents(ImmutableList.of(workflowAgent)) + .build(); + Runner runner = + Runner.builder().app(App.builder().name("test").rootAgent(rootAgent).build()).build(); + Session session = runner.sessionService().createSession("test", "user").blockingGet(); + + List eventsBeforeConfirmation = + runner + .runAsync("user", session.id(), Content.fromParts(Part.fromText("from user"))) + .toList() + .blockingGet(); + FunctionCall askUserConfirmationFunctionCall = + Iterables.getOnlyElement( + eventsBeforeConfirmation.stream() + .map(Functions::getAskUserConfirmationFunctionCalls) + .filter(functionCalls -> !functionCalls.isEmpty()) + .findFirst() + .get()); + List eventsAfterConfirmation = + runner + .runAsync( + "user", + session.id(), + Content.fromParts( + Part.builder() + .functionResponse( + FunctionResponse.builder() + .id(askUserConfirmationFunctionCall.id().get()) + .name(askUserConfirmationFunctionCall.name().get()) + .response(ImmutableMap.of("confirmed", true))) + .build())) + .toList() + .blockingGet(); + + // The originating child agent (not the root agent) must execute the tool. + assertThat(simplifyEvents(eventsAfterConfirmation)) + .containsExactly( + "child_agent: FunctionResponse(name=echoTool, response={message=hello})", + "child_agent: Response after user confirmed.") + .inOrder(); + } + + // Orphan function responses (id not matching any prior call) should fall back to the root agent. + @Test + public void runAsync_withFunctionResponseNotMatchingAnyCall_fallsBackToRootAgent() { + TestLlm rootLlm = createTestLlm(createTextLlmResponse("after function response")); + LlmAgent rootAgent = createTestAgentBuilder(rootLlm).name("root_agent").build(); + Runner runner = + Runner.builder().app(App.builder().name("test").rootAgent(rootAgent).build()).build(); + Session session = runner.sessionService().createSession("test", "user").blockingGet(); + + // Function response with id that does not match any prior function call. + List events = + runner + .runAsync( + "user", + session.id(), + Content.fromParts( + Part.builder() + .functionResponse( + FunctionResponse.builder() + .id("non_existent_id") + .name("orphanFn") + .response(ImmutableMap.of("x", 1))) + .build())) + .toList() + .blockingGet(); + + assertThat(simplifyEvents(events)).containsExactly("root_agent: after function response"); + } + @Test public void close_closesPluginsAndCodeExecutors() { BasePlugin plugin = mockPlugin("close_test_plugin");