From eff8f8d14a881c692cc85ee10cf371d755eef999 Mon Sep 17 00:00:00 2001 From: Damian Momot Date: Wed, 27 May 2026 02:16:31 -0700 Subject: [PATCH] fix: Fix ADK Runner race condition for sequential tool execution BaseLlmFlow.run() now appends each event synchronously inside the per-step concatMap, so the next runOneStep sees the previous step's events. Without this, a tool relying on prior events (e.g. a BeforeToolCallback producing a function response) could see stale history and re-call the tool or hallucinate its result. PiperOrigin-RevId: 921989444 --- .../java/com/google/adk/agents/LlmAgent.java | 4 +- .../adk/flows/llmflows/BaseLlmFlow.java | 72 +++- .../java/com/google/adk/runner/Runner.java | 55 ++- .../com/google/adk/sessions/SessionUtils.java | 23 ++ .../com/google/adk/agents/LlmAgentTest.java | 114 ++++++ .../com/google/adk/runner/RunnerTest.java | 349 ++++++++++++++++++ 6 files changed, 601 insertions(+), 16 deletions(-) diff --git a/core/src/main/java/com/google/adk/agents/LlmAgent.java b/core/src/main/java/com/google/adk/agents/LlmAgent.java index 98bba4606..7886b4493 100644 --- a/core/src/main/java/com/google/adk/agents/LlmAgent.java +++ b/core/src/main/java/com/google/adk/agents/LlmAgent.java @@ -654,7 +654,9 @@ private static boolean isThought(Part part) { @Override protected Flowable runAsyncImpl(InvocationContext invocationContext) { - return llmFlow.run(invocationContext).doOnNext(this::maybeSaveOutputToState); + // maybeSaveOutputToState runs as a pre-persist finalizer so the outputKey stateDelta is + // part of the persisted append performed inside BaseLlmFlow.run. 
+ return llmFlow.run(invocationContext, this::maybeSaveOutputToState); } @Override diff --git a/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java b/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java index dffba0e80..f50c85861 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java @@ -37,6 +37,7 @@ import com.google.adk.models.LlmRegistry; import com.google.adk.models.LlmRequest; import com.google.adk.models.LlmResponse; +import com.google.adk.sessions.SessionUtils; import com.google.adk.telemetry.Tracing; import com.google.adk.tools.BaseTool; import com.google.adk.tools.BaseToolset; @@ -60,8 +61,10 @@ import java.util.List; import java.util.Optional; import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicReference; import java.util.function.BiFunction; +import java.util.function.Consumer; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -498,12 +501,54 @@ private Flowable runOneStep(Context spanContext, InvocationContext contex */ @Override public Flowable run(InvocationContext invocationContext) { - return run(Context.current(), invocationContext, 0); + return run(invocationContext, event -> {}); + } + + /** + * Same as {@link #run(InvocationContext)} but invokes {@code eventPreFinalize} on each event + * authored by this agent immediately before it is appended, so any mutation (e.g. {@code + * stateDelta} updates from {@link LlmAgent}'s {@code outputKey}) is part of the persisted append. 
+ */ + public Flowable run( + InvocationContext invocationContext, Consumer eventPreFinalize) { + return run(Context.current(), invocationContext, 0, eventPreFinalize); } private Flowable run( - Context spanContext, InvocationContext invocationContext, int stepsCompleted) { - Flowable currentStepEvents = runOneStep(spanContext, invocationContext).cache(); + Context spanContext, + InvocationContext invocationContext, + int stepsCompleted, + Consumer eventPreFinalize) { + // Append each event synchronously so the next runOneStep sees prior events (avoids racing the + // Runner's append). Record each appended id so the Runner skips re-appending it; skip ids + // already recorded (e.g. bubbling up from a sub-agent's flow). Emit the original event, not the + // service's return (which may be a mock sentinel). + String thisAgentName = invocationContext.agent().name(); + Flowable currentStepEvents = + runOneStep(spanContext, invocationContext) + .concatMap( + event -> { + String eventId = event.id(); + if (eventId != null + && inFlightAppendedEventIds(invocationContext).contains(eventId)) { + return Flowable.just(event); + } + if (thisAgentName != null && thisAgentName.equals(event.author())) { + eventPreFinalize.accept(event); + } + return SessionUtils.safeAppendEvent( + invocationContext.sessionService(), invocationContext.session(), event) + .ignoreElement() + .andThen( + Flowable.fromCallable( + () -> { + if (eventId != null) { + inFlightAppendedEventIds(invocationContext).add(eventId); + } + return event; + })); + }) + .cache(); if (stepsCompleted + 1 >= maxSteps) { logger.debug("Ending flow execution because max steps reached."); return currentStepEvents; @@ -523,11 +568,30 @@ private Flowable run( return Flowable.empty(); } else { logger.debug("Continuing to next step of the flow."); - return run(spanContext, invocationContext, stepsCompleted + 1); + return run( + spanContext, invocationContext, stepsCompleted + 1, eventPreFinalize); } })); } + private static final 
String IN_FLIGHT_APPENDED_EVENT_IDS_KEY = + "com.google.adk.internal.inFlightAppendedEventIds"; + + /** + * Returns the transient, per-invocation set of event ids appended by the flow but not yet + * consumed by the Runner, lazily creating it. Ids are added here on append and removed by the + * Runner on consume, so this is hand-off state -- not a record of all persisted events. + */ + @SuppressWarnings("unchecked") + private static Set inFlightAppendedEventIds(InvocationContext invocationContext) { + return (Set) + invocationContext + .callbackContextData() + .computeIfAbsent( + IN_FLIGHT_APPENDED_EVENT_IDS_KEY, + unusedKey -> ConcurrentHashMap.newKeySet()); + } + /** * Executes the LLM flow in streaming mode. * diff --git a/core/src/main/java/com/google/adk/runner/Runner.java b/core/src/main/java/com/google/adk/runner/Runner.java index 1ab101398..8ca4b4cf7 100644 --- a/core/src/main/java/com/google/adk/runner/Runner.java +++ b/core/src/main/java/com/google/adk/runner/Runner.java @@ -38,6 +38,7 @@ import com.google.adk.sessions.InMemorySessionService; import com.google.adk.sessions.Session; import com.google.adk.sessions.SessionKey; +import com.google.adk.sessions.SessionUtils; import com.google.adk.summarizer.EventsCompactionConfig; import com.google.adk.summarizer.LlmEventSummarizer; import com.google.adk.summarizer.SlidingWindowEventCompactor; @@ -70,6 +71,7 @@ import java.util.Map; import java.util.Objects; import java.util.Optional; +import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; import org.jspecify.annotations.Nullable; @@ -581,19 +583,32 @@ private Flowable runAgentWithUpdatedSession( .agent() .runAsync(contextWithUpdatedSession) .concatMap( - agentEvent -> - this.sessionService - .appendEvent(updatedSession, agentEvent) - .flatMap( - registeredEvent -> { - // TODO: remove this hack after deprecating runAsync with Session. 
- copySessionStates(updatedSession, initialContext.session()); - return contextWithUpdatedSession + agentEvent -> { + // Skip events already appended by BaseLlmFlow.run (id is in the shared set), + // removing the id to keep the set bounded. Append everything else (agent-callback + // events, non-LlmAgent leaves) here, so each event is appended exactly once. + String agentEventId = agentEvent.id(); + boolean alreadyPersisted = + agentEventId != null + && inFlightAppendedEventIds(contextWithUpdatedSession) + .remove(agentEventId); + Single appendResult = + alreadyPersisted + ? Single.just(agentEvent) + : SessionUtils.safeAppendEvent( + this.sessionService, updatedSession, agentEvent); + return appendResult + .flatMap( + registeredEvent -> { + // TODO: remove this hack after deprecating runAsync with Session. + copySessionStates(updatedSession, initialContext.session()); + return contextWithUpdatedSession .pluginManager() .onEventCallback(contextWithUpdatedSession, registeredEvent) - .defaultIfEmpty(registeredEvent); - }) - .toFlowable()); + .defaultIfEmpty(registeredEvent); + }) + .toFlowable(); + }); // If beforeRunCallback returns content, emit it and skip agent Context capturedContext = Context.current(); @@ -619,6 +634,24 @@ private void copySessionStates(Session source, Session target) { target.state().putAll(source.state()); } + private static final String IN_FLIGHT_APPENDED_EVENT_IDS_KEY = + "com.google.adk.internal.inFlightAppendedEventIds"; + + /** + * Returns the transient, per-invocation set of event ids appended by the flow but not yet + * consumed by the Runner, lazily creating it. Ids are added by the flow on append and removed + * here on consume, so this is hand-off state -- not a record of all persisted events. 
+ */ + @SuppressWarnings("unchecked") + private static Set inFlightAppendedEventIds(InvocationContext invocationContext) { + return (Set) + invocationContext + .callbackContextData() + .computeIfAbsent( + IN_FLIGHT_APPENDED_EVENT_IDS_KEY, + unusedKey -> ConcurrentHashMap.newKeySet()); + } + /** * Creates an {@link InvocationContext} for a live (streaming) run. * diff --git a/core/src/main/java/com/google/adk/sessions/SessionUtils.java b/core/src/main/java/com/google/adk/sessions/SessionUtils.java index 1aeca98c9..acb8a6556 100644 --- a/core/src/main/java/com/google/adk/sessions/SessionUtils.java +++ b/core/src/main/java/com/google/adk/sessions/SessionUtils.java @@ -16,10 +16,12 @@ package com.google.adk.sessions; +import com.google.adk.events.Event; import com.google.common.collect.ImmutableList; import com.google.genai.types.Blob; import com.google.genai.types.Content; import com.google.genai.types.Part; +import io.reactivex.rxjava3.core.Single; import java.util.ArrayList; import java.util.Base64; import java.util.List; @@ -31,6 +33,27 @@ public final class SessionUtils { public SessionUtils() {} + /** + * Appends {@code event} via {@code service}, or just to {@code session.events()} when the session + * is partial (no {@code appName}). The partial-session bypass exists for unit tests that build + * {@code Session.builder(id).build()} and bypass {@code Runner}; most production services + * (including {@link InMemorySessionService}) {@code requireNonNull(appName)}. Production callers + * always pass fully-formed sessions and hit the unchanged {@code service.appendEvent} path. + */ + public static Single safeAppendEvent( + BaseSessionService service, Session session, Event event) { + if (session.appName() == null) { + List events = session.events(); + if (events != null) { + synchronized (events) { + events.add(event); + } + } + return Single.just(event); + } + return service.appendEvent(session, event); + } + /** Base64-encodes inline blobs in content. 
*/ public static Content encodeContent(Content content) { List encodedParts = new ArrayList<>(); diff --git a/core/src/test/java/com/google/adk/agents/LlmAgentTest.java b/core/src/test/java/com/google/adk/agents/LlmAgentTest.java index 26843bb56..1ae0ee6a8 100644 --- a/core/src/test/java/com/google/adk/agents/LlmAgentTest.java +++ b/core/src/test/java/com/google/adk/agents/LlmAgentTest.java @@ -39,7 +39,11 @@ import com.google.adk.models.LlmRequest; import com.google.adk.models.LlmResponse; import com.google.adk.models.Model; +import com.google.adk.sessions.BaseSessionService; +import com.google.adk.sessions.GetSessionConfig; import com.google.adk.sessions.InMemorySessionService; +import com.google.adk.sessions.ListEventsResponse; +import com.google.adk.sessions.ListSessionsResponse; import com.google.adk.sessions.Session; import com.google.adk.telemetry.Tracing; import com.google.adk.testing.TestLlm; @@ -58,11 +62,14 @@ import io.opentelemetry.api.trace.Tracer; import io.opentelemetry.sdk.testing.junit4.OpenTelemetryRule; import io.opentelemetry.sdk.trace.data.SpanData; +import io.reactivex.rxjava3.core.Completable; import io.reactivex.rxjava3.core.Flowable; import io.reactivex.rxjava3.core.Maybe; import io.reactivex.rxjava3.core.Single; import java.util.List; +import java.util.Optional; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; import java.util.concurrent.atomic.AtomicBoolean; import org.junit.After; import org.junit.Before; @@ -184,6 +191,113 @@ public void testRun_withoutOutputKey_doesNotSaveState() { assertThat(events.get(0).actions().stateDelta()).isEmpty(); } + /** + * Partial-session bypass: tests that build {@code Session.builder(id).build()} and call {@code + * agent.runAsync} directly (bypassing Runner) must not trip the {@code requireNonNull(appName)} + * in {@code InMemorySessionService.appendEvent}. The event is tracked in {@code session.events()} + * so subsequent steps see prior events. 
Surfaced originally by orcas {@code LlmAgentActionTest}, + * dataworkeragent and asterix small tests. + */ + @Test + public void testRun_partialSessionWithoutAppName_doesNotThrow() { + Content modelContent = Content.fromParts(Part.fromText("Agent Response")); + TestLlm testLlm = createTestLlm(createLlmResponse(modelContent)); + LlmAgent agent = createTestAgentBuilder(testLlm).build(); + + Session partialSession = Session.builder("session-id").build(); + InvocationContext invocationContext = + InvocationContext.builder() + .sessionService(new InMemorySessionService()) + .agent(agent) + .session(partialSession) + .invocationId("invocation-id") + .runConfig(RunConfig.builder().build()) + .userContent(Content.fromParts(Part.fromText("hello"))) + .build(); + + List events = agent.runAsync(invocationContext).toList().blockingGet(); + + assertThat(events).hasSize(1); + assertThat(events.get(0).content()).hasValue(modelContent); + // Event tracked in session.events() so subsequent steps see prior events. + assertThat(partialSession.events()).hasSize(1); + assertThat(partialSession.events().get(0).id()).isEqualTo(events.get(0).id()); + } + + /** + * Mirrors ads-ux researchagent {@code GenerateReportFromSourcesActionTest}: an action drives + * {@code agent.runAsync(...).collect { ... }} directly (bypassing Runner) with a {@link + * BaseSessionService} stub whose {@code appendEvent} returns a sentinel empty Event. The original + * LLM-derived event must flow downstream with its content intact, not be swapped for the + * service's return value. 
+ */ + @Test + public void testRun_appendEventReturnsSentinel_originalEventFlowsDownstream() { + Content modelContent = Content.fromParts(Part.fromText("generated report content")); + TestLlm testLlm = createTestLlm(createLlmResponse(modelContent)); + LlmAgent agent = createTestAgentBuilder(testLlm).build(); + + BaseSessionService sentinelReturningSessionService = new SentinelReturningSessionService(); + // appName set so safeAppendEvent calls the service (we want to verify its return is ignored). + Session sessionWithAppName = + Session.builder("session-id").appName("test").userId("user").build(); + InvocationContext invocationContext = + InvocationContext.builder() + .sessionService(sentinelReturningSessionService) + .agent(agent) + .session(sessionWithAppName) + .invocationId("invocation-id") + .runConfig(RunConfig.builder().build()) + .userContent(Content.fromParts(Part.fromText("hello"))) + .build(); + + List events = agent.runAsync(invocationContext).toList().blockingGet(); + + assertThat(events).hasSize(1); + // Must be the original LLM-derived event, not the sentinel returned by appendEvent. + assertThat(events.get(0).content()).hasValue(modelContent); + } + + /** + * Stub returning an empty Event from {@code appendEvent}, mirroring the shape used by ads-ux + * researchagent and nbu paisa tests ({@code + * Mockito.when(...).thenReturn(Event.builder().build())}). 
+ */ + private static final class SentinelReturningSessionService implements BaseSessionService { + @Override + public Single createSession( + String appName, String userId, ConcurrentMap state, String sessionId) { + return Single.just(Session.builder("session-id").build()); + } + + @Override + public Maybe getSession( + String appName, String userId, String sessionId, Optional config) { + return Maybe.just(Session.builder(sessionId).build()); + } + + @Override + public Single listSessions(String appName, String userId) { + return Single.just(ListSessionsResponse.builder().build()); + } + + @Override + public Completable deleteSession(String appName, String userId, String sessionId) { + return Completable.complete(); + } + + @Override + public Single listEvents(String appName, String userId, String sessionId) { + return Single.just(ListEventsResponse.builder().build()); + } + + @Override + public Single appendEvent(Session session, Event event) { + // Sentinel return value, mirroring downstream test mocks. 
+ return Single.just(Event.builder().build()); + } + } + @Test public void run_withToolsAndMaxSteps_stopsAfterMaxSteps() { ImmutableMap echoArgs = ImmutableMap.of("arg", "value"); diff --git a/core/src/test/java/com/google/adk/runner/RunnerTest.java b/core/src/test/java/com/google/adk/runner/RunnerTest.java index 95718e3e0..10665aa59 100644 --- a/core/src/test/java/com/google/adk/runner/RunnerTest.java +++ b/core/src/test/java/com/google/adk/runner/RunnerTest.java @@ -38,6 +38,8 @@ import static org.mockito.Mockito.when; import com.google.adk.agents.BaseAgent; +import com.google.adk.agents.Callbacks; +import com.google.adk.agents.Callbacks.AfterModelCallback; import com.google.adk.agents.InvocationContext; import com.google.adk.agents.LiveRequestQueue; import com.google.adk.agents.LlmAgent; @@ -47,9 +49,14 @@ import com.google.adk.artifacts.BaseArtifactService; import com.google.adk.events.Event; import com.google.adk.flows.llmflows.Functions; +import com.google.adk.models.LlmRequest; import com.google.adk.models.LlmResponse; import com.google.adk.plugins.BasePlugin; import com.google.adk.sessions.BaseSessionService; +import com.google.adk.sessions.GetSessionConfig; +import com.google.adk.sessions.InMemorySessionService; +import com.google.adk.sessions.ListEventsResponse; +import com.google.adk.sessions.ListSessionsResponse; import com.google.adk.sessions.Session; import com.google.adk.sessions.SessionKey; import com.google.adk.summarizer.EventsCompactionConfig; @@ -85,6 +92,7 @@ import java.util.Optional; import java.util.UUID; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; import org.junit.After; @@ -630,6 +638,9 @@ public void runAsync_passesSessionSnapshotToPersistenceService() { // Mock agent to return one event BaseAgent mockAgent = mock(BaseAgent.class); 
when(mockAgent.runAsync(any())).thenReturn(Flowable.just(agentEvent)); + // BaseAgent.rootAgent() walks the parent chain and returns `this` for a parent-less agent; + // the mock would return null by default, so stub it to mirror real behavior. + when(mockAgent.rootAgent()).thenReturn(mockAgent); // Mock session service Session testSession = Session.builder("session-id").appName("test").userId("user").build(); @@ -676,6 +687,7 @@ public void runAsync_multiEventExecution_lastUpdateTimeProgresses() throws Excep BaseAgent mockAgent = mock(BaseAgent.class); when(mockAgent.runAsync(any())).thenReturn(Flowable.just(event1, event2)); + when(mockAgent.rootAgent()).thenReturn(mockAgent); // Initial session with timestamp 100 Session testSession = @@ -737,6 +749,7 @@ public void runAsync_concurrentCalls_staleRead() throws Exception { BaseAgent mockAgent = mock(BaseAgent.class); when(mockAgent.runAsync(any())).thenReturn(Flowable.just(agentEvent)); + when(mockAgent.rootAgent()).thenReturn(mockAgent); Session initialSession = Session.builder("session-id").appName("test").userId("user").build(); AtomicReference dbSession = new AtomicReference<>(initialSession); @@ -811,6 +824,7 @@ public void runAsync_concurrentCalls_firstFails_secondSucceeds() throws Exceptio when(mockAgent.runAsync(any())) .thenReturn(Flowable.error(new RuntimeException("Agent failed"))) .thenReturn(Flowable.just(agentEvent)); + when(mockAgent.rootAgent()).thenReturn(mockAgent); Session initialSession = Session.builder("session-id").appName("test").userId("user").build(); AtomicReference dbSession = new AtomicReference<>(initialSession); @@ -860,6 +874,341 @@ public void runAsync_concurrentCalls_firstFails_secondSucceeds() throws Exceptio subscriber2.assertValue(agentEvent); } + /** + * A slow appendEvent must not let the next LLM step start with a stale session missing the + * previous step's function-response event. 
+ */ + @Test + public void runAsync_slowAppendEvent_doesNotCauseStaleSessionInNextStep() throws Exception { + TestLlm raceTestLlm = + createTestLlm( + createFunctionCallLlmResponse("call_1", echoTool.name(), ImmutableMap.of("arg", "v1")), + createTextLlmResponse("done")); + + LlmAgent agentForRace = + createTestAgentBuilder(raceTestLlm).tools(ImmutableList.of(echoTool)).build(); + + BaseSessionService delayedSessionService = + new AppendDelayingSessionService(new InMemorySessionService(), 50); + + Runner runnerForRace = + Runner.builder() + .app(App.builder().name("test").rootAgent(agentForRace).build()) + .sessionService(delayedSessionService) + .build(); + Session raceSession = + runnerForRace.sessionService().createSession("test", "user").blockingGet(); + + var unused = + runnerForRace + .runAsync("user", raceSession.id(), createContent("start")) + .toList() + .blockingGet(); + + ImmutableList requests = ImmutableList.copyOf(raceTestLlm.getRequests()); + assertThat(requests).hasSize(2); + + // Second LLM request must see the function response from step 1. + boolean foundToolResponse = + requests.get(1).contents().stream() + .flatMap(c -> c.parts().stream().flatMap(List::stream)) + .anyMatch(part -> part.functionResponse().isPresent()); + assertThat(foundToolResponse).isTrue(); + } + + /** + * When an LlmAgent transfers control to a sub-LlmAgent, the sub-agent's events flow back up + * through the parent's {@code BaseLlmFlow.run()} pipeline. Each event must be appended to the + * session exactly once. + */ + @Test + public void runAsync_transferToSubAgent_eventsAppendedOnce() throws Exception { + LlmAgent subAgent = + createTestAgentBuilder(createTestLlm(createTextLlmResponse("sub response"))) + .name("sub-agent") + .build(); + + // Force a transfer to sub-agent using an afterModelCallback. 
+ AfterModelCallback transferCallback = + (ctx, response) -> { + ctx.eventActions().setTransferToAgent(subAgent.name()); + return Maybe.empty(); + }; + + TestLlm rootTestLlm = createTestLlm(createTextLlmResponse("initial")); + LlmAgent rootAgent = + createTestAgentBuilder(rootTestLlm) + .subAgents(subAgent) + .afterModelCallback(ImmutableList.of(transferCallback)) + .build(); + + Runner transferRunner = + Runner.builder().app(App.builder().name("test").rootAgent(rootAgent).build()).build(); + Session transferSession = + transferRunner.sessionService().createSession("test", "user").blockingGet(); + + var unused = + transferRunner + .runAsync("user", transferSession.id(), createContent("start")) + .toList() + .blockingGet(); + + Session finalSession = + transferRunner + .sessionService() + .getSession( + transferSession.appName(), + transferSession.userId(), + transferSession.id(), + Optional.empty()) + .blockingGet(); + + // Each event id should appear at most once in the session. + List eventIds = finalSession.events().stream().map(Event::id).toList(); + assertThat(eventIds).containsNoDuplicates(); + } + + /** {@link BaseSessionService} that delays {@link #appendEvent} to surface ordering bugs. */ + private static final class AppendDelayingSessionService implements BaseSessionService { + private final BaseSessionService delegate; + private final long appendDelayMs; + + AppendDelayingSessionService(BaseSessionService delegate, long appendDelayMs) { + this.delegate = delegate; + this.appendDelayMs = appendDelayMs; + } + + // Delegates to the underlying BaseSessionService createSession overload, which is itself + // deprecated; suppressed because the wrapper must preserve the same signature. 
+ @SuppressWarnings("deprecation") + @Override + public Single createSession( + String appName, String userId, ConcurrentMap state, String sessionId) { + return delegate.createSession(appName, userId, state, sessionId); + } + + @Override + public Maybe getSession( + String appName, String userId, String sessionId, Optional config) { + return delegate.getSession(appName, userId, sessionId, config); + } + + @Override + public Single listSessions(String appName, String userId) { + return delegate.listSessions(appName, userId); + } + + @Override + public Completable deleteSession(String appName, String userId, String sessionId) { + return delegate.deleteSession(appName, userId, sessionId); + } + + @Override + public Single listEvents(String appName, String userId, String sessionId) { + return delegate.listEvents(appName, userId, sessionId); + } + + @Override + public Single appendEvent(Session session, Event event) { + return delegate.appendEvent(session, event).delay(appendDelayMs, MILLISECONDS); + } + } + + /** + * Regression test: {@code outputKey} state delta must reach {@code session.state()}. {@code + * LlmAgent} passes {@code maybeSaveOutputToState} as a pre-finalize hook so the delta is + * populated before {@code BaseLlmFlow.run}'s per-step append. Surfaced originally by wallet's + * {@code ActionPredictorAgentTest} and legalholdagent's {@code AgentRunnerTest}. 
+ */ + @Test + public void runAsync_llmAgentWithOutputKey_writesValueToSessionState() { + Content modelContent = Content.fromParts(Part.fromText("Saved output")); + TestLlm outputKeyTestLlm = createTestLlm(createLlmResponse(modelContent)); + LlmAgent outputKeyAgent = + createTestAgentBuilder(outputKeyTestLlm).outputKey("myOutput").build(); + + Runner outputKeyRunner = + Runner.builder().app(App.builder().name("test").rootAgent(outputKeyAgent).build()).build(); + Session outputKeySession = + outputKeyRunner.sessionService().createSession("test", "user").blockingGet(); + + var unused = + outputKeyRunner + .runAsync("user", outputKeySession.id(), createContent("hi")) + .toList() + .blockingGet(); + + Session persistedSession = + outputKeyRunner + .sessionService() + .getSession("test", "user", outputKeySession.id(), Optional.empty()) + .blockingGet(); + assertThat(persistedSession.state()).containsEntry("myOutput", "Saved output"); + } + + /** + * Regression test: each LlmAgent event reaches {@code BaseSessionService.appendEvent} exactly + * once. Expected count for a single-step run is 2 (user msg + agent event); without the + * Runner-side skip for LlmAgent-authored events it would be 3. Surfaced originally by shopping's + * {@code AdkAgentTest.run_withPersistentSessionStorage_usesOrcasStorageSessionService} and nbu + * paisa's {@code AgentRunnerImplTest}. 
+ */ + @Test + public void runAsync_serviceAppendEventCalledOncePerEvent() { + TestLlm idempotencyTestLlm = createTestLlm(createLlmResponse(createContent("from agent"))); + LlmAgent llmAgent = createTestAgentBuilder(idempotencyTestLlm).build(); + + InMemorySessionService realSessionService = new InMemorySessionService(); + BaseSessionService mockSessionService = mock(BaseSessionService.class); + Session realSession = realSessionService.createSession("test", "user").blockingGet(); + when(mockSessionService.createSession(anyString(), anyString())) + .thenReturn(Single.just(realSession)); + when(mockSessionService.getSession(anyString(), anyString(), anyString(), any())) + .thenAnswer(invocation -> Maybe.just(realSession)); + when(mockSessionService.appendEvent(any(), any())) + .thenAnswer( + invocation -> + realSessionService.appendEvent( + invocation.getArgument(0), invocation.getArgument(1))); + + Runner countingRunner = + Runner.builder() + .app(App.builder().name("test").rootAgent(llmAgent).build()) + .sessionService(mockSessionService) + .build(); + + var unused = + countingRunner + .runAsync("user", realSession.id(), createContent("user message")) + .toList() + .blockingGet(); + + // Two calls only: user message + agent response. Without the LlmAgent-aware skip in Runner, + // BaseLlmFlow.run and the outer Runner would each call appendEvent for the agent event -> 3. + verify(mockSessionService, times(2)).appendEvent(any(), any()); + } + + /** + * Regression test for dropped agent-callback events: an {@code afterAgentCallback} that mutates + * state emits a state-delta event authored by the agent but emitted outside {@code + * BaseLlmFlow.run}. It must still be persisted (conformance {@code core/after_agent_callback_001} + * regressed to 2 events instead of 3). Exercised through the Runner, unlike {@code + * CallbacksTest}. 
+ */ + @Test + public void runAsync_afterAgentCallbackWritesState_callbackEventIsPersisted() { + TestLlm callbackTestLlm = createTestLlm(createLlmResponse(createContent("from agent"))); + Callbacks.AfterAgentCallback writeState = + callbackContext -> { + var unused = callbackContext.state().put("after_agent_callback_state_key", "value1"); + return Maybe.empty(); + }; + LlmAgent callbackAgent = + createTestAgentBuilder(callbackTestLlm).afterAgentCallback(writeState).build(); + + Runner callbackRunner = + Runner.builder().app(App.builder().name("test").rootAgent(callbackAgent).build()).build(); + Session session = callbackRunner.sessionService().createSession("test", "user").blockingGet(); + + var unused = + callbackRunner.runAsync("user", session.id(), createContent("hi")).toList().blockingGet(); + + Session persisted = + callbackRunner + .sessionService() + .getSession("test", "user", session.id(), Optional.empty()) + .blockingGet(); + + // user message + model response + after-agent-callback state-delta event (dropped pre-fix -> 2; + // a double append -> 4). + assertThat(persisted.events()).hasSize(3); + Event callbackEvent = persisted.events().get(2); + assertThat(callbackEvent.author()).isEqualTo(callbackAgent.name()); + assertThat(callbackEvent.actions().stateDelta()) + .containsEntry("after_agent_callback_state_key", "value1"); + assertThat(persisted.state()).containsEntry("after_agent_callback_state_key", "value1"); + } + + /** + * Mirrors paisa {@code AgentRunnerImplTest}: pure-mock {@link BaseSessionService} returning a + * sentinel from {@code appendEvent}, verifies exact call count. Expected: 2 (user msg + agent + * event). Without the LlmAgent-aware skip in Runner it would be 3. 
+ */ + @Test + public void runAsync_pureMockSessionService_appendEventCalledOncePerLlmAgentEvent() { + Event sentinelEvent = + Event.builder() + .id("sentinel") + .author("test agent") + .content(createContent("sentinel response")) + .build(); + BaseSessionService pureMockSessionService = mock(BaseSessionService.class); + Session backingSession = Session.builder("session-id").appName("test").userId("user").build(); + when(pureMockSessionService.createSession(anyString(), anyString())) + .thenReturn(Single.just(backingSession)); + when(pureMockSessionService.getSession(anyString(), anyString(), anyString(), any())) + .thenAnswer(invocation -> Maybe.just(backingSession)); + when(pureMockSessionService.appendEvent(any(), any())).thenReturn(Single.just(sentinelEvent)); + + TestLlm pureMockLlm = createTestLlm(createLlmResponse(createContent("from agent"))); + LlmAgent pureMockLlmAgent = createTestAgentBuilder(pureMockLlm).build(); + Runner pureMockRunner = + Runner.builder() + .app(App.builder().name("test").rootAgent(pureMockLlmAgent).build()) + .sessionService(pureMockSessionService) + .build(); + + var unused = + pureMockRunner + .runAsync("user", backingSession.id(), createContent("user message")) + .toList() + .blockingGet(); + + // Exactly 2: user message + agent event. Without the LlmAgent-aware skip in Runner, it would + // be 3 (one extra from Runner's own append on the agent event). + verify(pureMockSessionService, times(2)).appendEvent(any(), any()); + } + + /** + * Multi-turn variant of the paisa shape: tool call + final response = 2 agent events. Expected + * append count is 1 (user msg) + N (agent events), never 1 + 2N. 
+ */ + @Test + public void runAsync_pureMockSessionService_multiStepLlmAgent_appendsExactlyOncePerEvent() { + Event sentinelEvent = Event.builder().id("sentinel").author("test agent").build(); + BaseSessionService pureMockSessionService = mock(BaseSessionService.class); + Session backingSession = Session.builder("session-id").appName("test").userId("user").build(); + when(pureMockSessionService.createSession(anyString(), anyString())) + .thenReturn(Single.just(backingSession)); + when(pureMockSessionService.getSession(anyString(), anyString(), anyString(), any())) + .thenAnswer(invocation -> Maybe.just(backingSession)); + when(pureMockSessionService.appendEvent(any(), any())).thenReturn(Single.just(sentinelEvent)); + + // Function call, then function-response triggers a second LLM call returning the final text. + TestLlm twoStepLlm = + createTestLlm( + createFunctionCallLlmResponse( + "call_1", new EchoTool().name(), ImmutableMap.of("arg", "v1")), + createTextLlmResponse("final answer")); + LlmAgent twoStepLlmAgent = + createTestAgentBuilder(twoStepLlm).tools(ImmutableList.of(new EchoTool())).build(); + Runner twoStepRunner = + Runner.builder() + .app(App.builder().name("test").rootAgent(twoStepLlmAgent).build()) + .sessionService(pureMockSessionService) + .build(); + + var emittedEvents = + twoStepRunner + .runAsync("user", backingSession.id(), createContent("start")) + .toList() + .blockingGet(); + + // 1 (user msg) + N (agent events). Without the LlmAgent-aware skip: 1 + 2N. + int expectedAppendCount = 1 + emittedEvents.size(); + verify(pureMockSessionService, times(expectedAppendCount)).appendEvent(any(), any()); + } + @Test public void runAsync_withSessionKey_success() { var events =