diff --git a/core/src/main/java/com/google/adk/agents/RunConfig.java b/core/src/main/java/com/google/adk/agents/RunConfig.java index 747454bed..001dfe870 100644 --- a/core/src/main/java/com/google/adk/agents/RunConfig.java +++ b/core/src/main/java/com/google/adk/agents/RunConfig.java @@ -41,17 +41,26 @@ public enum StreamingMode { /** * Execution mode when the model requests multiple tools. * - *

NONE: defaults to SEQUENTIAL. + *

NONE: defaults to PARALLEL. * - *

SEQUENTIAL: tools execute in request order on the caller thread. + *

SEQUENTIAL: tools execute strictly in request order on the caller thread; each tool must + * complete (including any asynchronous work) before the next one is subscribed to. * - *

PARALLEL: tools execute concurrently on worker threads. Tool implementations must be - * thread-safe. + *

PARALLEL: tools are subscribed to eagerly on the caller thread (i.e. all are kicked off + * up-front), but no worker threads are introduced. Tools that are truly asynchronous (e.g. they + * return a {@code Single} backed by I/O or another scheduler) will run concurrently; tools that + * block the subscribing thread (e.g. {@code Single.fromCallable} that performs blocking work) + * will still execute sequentially. This preserves the historical default behavior. + * + *

PARALLEL_SUBSCRIBE: like {@code PARALLEL}, but every tool is additionally subscribed on a + * worker thread, so blocking tools also run concurrently. Tool implementations must be + * thread-safe. The worker is the agent's executor when set, otherwise the RxJava IO scheduler. */ public enum ToolExecutionMode { NONE, SEQUENTIAL, - PARALLEL + PARALLEL, + PARALLEL_SUBSCRIBE } public abstract @Nullable SpeechConfig speechConfig(); diff --git a/core/src/main/java/com/google/adk/flows/llmflows/Functions.java b/core/src/main/java/com/google/adk/flows/llmflows/Functions.java index 19dfbc5dc..4aa20798d 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/Functions.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/Functions.java @@ -236,23 +236,40 @@ public static Maybe handleFunctionCallsLive( } /** - * Sequential by default; only {@link ToolExecutionMode#PARALLEL} with multiple calls dispatches - * tools on workers (using {@code concatMapEager} to preserve input order). + * Builds the tool-execution {@link Observable} for the configured {@link ToolExecutionMode}. + * + *

*/ private static Observable buildToolExecutionObservable( InvocationContext invocationContext, List validFunctionCalls, Function> functionCallMapper) { - boolean parallel = - invocationContext.runConfig().toolExecutionMode() == ToolExecutionMode.PARALLEL - && validFunctionCalls.size() > 1; - if (!parallel) { + ToolExecutionMode mode = invocationContext.runConfig().toolExecutionMode(); + boolean sequential = mode == ToolExecutionMode.SEQUENTIAL || validFunctionCalls.size() <= 1; + if (sequential) { return Observable.fromIterable(validFunctionCalls).concatMapMaybe(functionCallMapper); } - Scheduler scheduler = resolveToolExecutionScheduler(invocationContext); + if (mode == ToolExecutionMode.PARALLEL_SUBSCRIBE) { + Scheduler scheduler = resolveToolExecutionScheduler(invocationContext); + return Observable.fromIterable(validFunctionCalls) + .concatMapEager( + call -> functionCallMapper.apply(call).toObservable().subscribeOn(scheduler)); + } + // PARALLEL (and NONE, which defaults to PARALLEL): eager subscribe on the caller thread, + // without offloading to a worker. Async tools run concurrently; blocking tools still block. return Observable.fromIterable(validFunctionCalls) - .concatMapEager( - call -> functionCallMapper.apply(call).toObservable().subscribeOn(scheduler)); + .concatMapEager(call -> functionCallMapper.apply(call).toObservable()); } /** Agent executor if set, otherwise the IO scheduler. */ diff --git a/core/src/test/java/com/google/adk/flows/llmflows/FunctionsTest.java b/core/src/test/java/com/google/adk/flows/llmflows/FunctionsTest.java index 3ce31deb4..b2ffc4443 100644 --- a/core/src/test/java/com/google/adk/flows/llmflows/FunctionsTest.java +++ b/core/src/test/java/com/google/adk/flows/llmflows/FunctionsTest.java @@ -398,9 +398,10 @@ public void getAskUserConfirmationFunctionCalls_eventWithConfirmationFunctionCal assertThat(result).containsExactly(confirmationCall1, confirmationCall2); } - // Default ToolExecutionMode.NONE must execute tools sequentially. + // Default ToolExecutionMode.NONE behaves like PARALLEL: blocking tools still execute serially + // on the caller thread (no worker scheduler is used), preserving the historical default. @Test - public void handleFunctionCalls_defaultMode_blockingTools_runSequentially() { + public void handleFunctionCalls_defaultMode_blockingTools_runSerially() { long sleepMillis = 300L; int toolCount = 2; InvocationContext invocationContext = @@ -435,29 +436,69 @@ public void handleFunctionCalls_defaultMode_blockingTools_runSequentially() { assertThat(durationMillis).isAtLeast((long) toolCount * sleepMillis); } + // PARALLEL mode does NOT introduce worker threads; blocking tools still run serially on the + // caller thread. PARALLEL_SUBSCRIBE is the mode that runs blocking tools concurrently. @Test - public void handleFunctionCalls_parallel_blockingTools_runConcurrently_twoTools() { - runParallelBlockingToolsTest(/* toolCount= */ 2); + public void handleFunctionCalls_parallel_blockingTools_runSerially() { + long sleepMillis = 300L; + int toolCount = 2; + InvocationContext invocationContext = + createInvocationContext( + createRootAgent(), + RunConfig.builder().setToolExecutionMode(ToolExecutionMode.PARALLEL).build()); + + Map tools = new LinkedHashMap<>(); + List callParts = new ArrayList<>(); + for (int i = 1; i <= toolCount; i++) { + String toolName = "slow_tool_" + i; + tools.put(toolName, new SleepingTool(toolName, sleepMillis)); + callParts.add( + Part.builder() + .functionCall( + FunctionCall.builder() + .id("call_" + i) + .name(toolName) + .args(ImmutableMap.of()) + .build()) + .build()); + } + Event event = + createEvent("event").toBuilder() + .content(Content.fromParts(callParts.toArray(new Part[0]))) + .build(); + + long start = System.currentTimeMillis(); + Event functionResponseEvent = + Functions.handleFunctionCalls(invocationContext, event, tools).blockingGet(); + long durationMillis = System.currentTimeMillis() - start; + + assertThat(functionResponseEvent).isNotNull(); + assertThat(durationMillis).isAtLeast((long) toolCount * sleepMillis); } @Test - public void handleFunctionCalls_parallel_blockingTools_runConcurrently_threeTools() { - runParallelBlockingToolsTest(/* toolCount= */ 3); + public void handleFunctionCalls_parallelSubscribe_blockingTools_runConcurrently_twoTools() { + runParallelSubscribeBlockingToolsTest(/* toolCount= */ 2); } @Test - public void handleFunctionCalls_parallel_blockingTools_runConcurrently_fiveTools() { - runParallelBlockingToolsTest(/* toolCount= */ 5); + public void handleFunctionCalls_parallelSubscribe_blockingTools_runConcurrently_threeTools() { + runParallelSubscribeBlockingToolsTest(/* toolCount= */ 3); + } + + @Test + public void handleFunctionCalls_parallelSubscribe_blockingTools_runConcurrently_fiveTools() { + runParallelSubscribeBlockingToolsTest(/* toolCount= */ 5); } /** Single-tool case bypasses the parallel scheduler path; must still return the correct event. */ @Test - public void handleFunctionCalls_parallel_blockingTool_singleTool() { + public void handleFunctionCalls_parallelSubscribe_blockingTool_singleTool() { long sleepMillis = 200L; InvocationContext invocationContext = createInvocationContext( createRootAgent(), - RunConfig.builder().setToolExecutionMode(ToolExecutionMode.PARALLEL).build()); + RunConfig.builder().setToolExecutionMode(ToolExecutionMode.PARALLEL_SUBSCRIBE).build()); SleepingTool tool = new SleepingTool("slow_tool_1", sleepMillis); Event event = createEvent("event").toBuilder() @@ -491,13 +532,16 @@ public void handleFunctionCalls_parallel_blockingTool_singleTool() { .build()); } - /** Asserts that {@code toolCount} blocking tools in PARALLEL mode run faster than sequential. */ - private static void runParallelBlockingToolsTest(int toolCount) { + /** + * Asserts that {@code toolCount} blocking tools in PARALLEL_SUBSCRIBE mode run faster than + * sequential, since each tool is subscribed on a worker thread. + */ + private static void runParallelSubscribeBlockingToolsTest(int toolCount) { long sleepMillis = 500L; InvocationContext invocationContext = createInvocationContext( createRootAgent(), - RunConfig.builder().setToolExecutionMode(ToolExecutionMode.PARALLEL).build()); + RunConfig.builder().setToolExecutionMode(ToolExecutionMode.PARALLEL_SUBSCRIBE).build()); Map tools = new LinkedHashMap<>(); List callParts = new ArrayList<>();