diff --git a/core/src/main/java/com/google/adk/agents/RunConfig.java b/core/src/main/java/com/google/adk/agents/RunConfig.java index f5fe5cb34..2f8e417d7 100644 --- a/core/src/main/java/com/google/adk/agents/RunConfig.java +++ b/core/src/main/java/com/google/adk/agents/RunConfig.java @@ -38,6 +38,22 @@ public enum StreamingMode { BIDI } + /** + * Tool execution mode for the runner, when they are multiple tools requested (by the models or + * callbacks). + * + *

NONE: default to PARALLEL. + * + *

SEQUENTIAL: Multiple tools are executed in the order they are requested. + * + *

PARALLEL: Multiple tools are executed in parallel. + */ + public enum ToolExecutionMode { + NONE, + SEQUENTIAL, + PARALLEL + } + public abstract @Nullable SpeechConfig speechConfig(); public abstract ImmutableList responseModalities(); @@ -46,6 +62,8 @@ public enum StreamingMode { public abstract StreamingMode streamingMode(); + public abstract ToolExecutionMode toolExecutionMode(); + public abstract @Nullable AudioTranscriptionConfig outputAudioTranscription(); public abstract @Nullable AudioTranscriptionConfig inputAudioTranscription(); @@ -59,6 +77,7 @@ public static Builder builder() { .setSaveInputBlobsAsArtifacts(false) .setResponseModalities(ImmutableList.of()) .setStreamingMode(StreamingMode.NONE) + .setToolExecutionMode(ToolExecutionMode.NONE) .setMaxLlmCalls(500); } @@ -66,6 +85,7 @@ public static Builder builder(RunConfig runConfig) { return new AutoValue_RunConfig.Builder() .setSaveInputBlobsAsArtifacts(runConfig.saveInputBlobsAsArtifacts()) .setStreamingMode(runConfig.streamingMode()) + .setToolExecutionMode(runConfig.toolExecutionMode()) .setMaxLlmCalls(runConfig.maxLlmCalls()) .setResponseModalities(runConfig.responseModalities()) .setSpeechConfig(runConfig.speechConfig()) @@ -89,6 +109,9 @@ public abstract static class Builder { @CanIgnoreReturnValue public abstract Builder setStreamingMode(StreamingMode streamingMode); + @CanIgnoreReturnValue + public abstract Builder setToolExecutionMode(ToolExecutionMode toolExecutionMode); + @CanIgnoreReturnValue public abstract Builder setOutputAudioTranscription( @Nullable AudioTranscriptionConfig outputAudioTranscription); diff --git a/core/src/main/java/com/google/adk/flows/llmflows/Functions.java b/core/src/main/java/com/google/adk/flows/llmflows/Functions.java index 234283c64..61cbab5a1 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/Functions.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/Functions.java @@ -25,6 +25,7 @@ import com.google.adk.agents.Callbacks.BeforeToolCallback; import com.google.adk.agents.InvocationContext; import com.google.adk.agents.LlmAgent; +import com.google.adk.agents.RunConfig.ToolExecutionMode; import com.google.adk.events.Event; import com.google.adk.events.EventActions; import com.google.adk.tools.BaseTool; @@ -198,7 +199,13 @@ public static Maybe handleFunctionCalls( functionResponseEvents.add(maybeFunctionResponseEvent); } - return Maybe.merge(functionResponseEvents) + Flowable functionResponseEventsFlowable; + if (invocationContext.runConfig().toolExecutionMode() == ToolExecutionMode.SEQUENTIAL) { + functionResponseEventsFlowable = Maybe.concat(functionResponseEvents); + } else { + functionResponseEventsFlowable = Maybe.merge(functionResponseEvents); + } + return functionResponseEventsFlowable .toList() .flatMapMaybe( events -> { @@ -296,7 +303,13 @@ public static Maybe handleFunctionCallsLive( responseEvents.add(maybeFunctionResponseEvent); } - return Maybe.merge(responseEvents) + Flowable responseEventsFlowable; + if (invocationContext.runConfig().toolExecutionMode() == ToolExecutionMode.SEQUENTIAL) { + responseEventsFlowable = Maybe.concat(responseEvents); + } else { + responseEventsFlowable = Maybe.merge(responseEvents); + } + return responseEventsFlowable .toList() .flatMapMaybe( events -> {