Skip to content

Commit 0b8b35b

Browse files
google-genai-botcopybara-github
authored andcommitted
fix: propagate thought signatures in BaseLlmFlow
PiperOrigin-RevId: 836673368
1 parent 626f171 commit 0b8b35b

1 file changed

Lines changed: 46 additions & 5 deletions

File tree

core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java

Lines changed: 46 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,9 @@
4040
import com.google.adk.tools.ToolContext;
4141
import com.google.common.collect.ImmutableList;
4242
import com.google.common.collect.Iterables;
43+
import com.google.genai.types.Content;
4344
import com.google.genai.types.FunctionResponse;
45+
import com.google.genai.types.Part;
4446
import io.opentelemetry.api.trace.Span;
4547
import io.opentelemetry.api.trace.StatusCode;
4648
import io.opentelemetry.context.Context;
@@ -56,6 +58,7 @@
5658
import java.util.List;
5759
import java.util.Optional;
5860
import java.util.Set;
61+
import java.util.concurrent.atomic.AtomicReference;
5962
import org.slf4j.Logger;
6063
import org.slf4j.LoggerFactory;
6164

@@ -138,7 +141,8 @@ protected Flowable<Event> postprocess(
138141
InvocationContext context,
139142
Event baseEventForLlmResponse,
140143
LlmRequest llmRequest,
141-
LlmResponse llmResponse) {
144+
LlmResponse llmResponse,
145+
AtomicReference<Optional<byte[]>> thoughtSignatureState) {
142146

143147
List<Iterable<Event>> eventIterables = new ArrayList<>();
144148
Single<LlmResponse> currentLlmResponse = Single.just(llmResponse);
@@ -167,7 +171,8 @@ protected Flowable<Event> postprocess(
167171
}
168172

169173
Event modelResponseEvent =
170-
buildModelResponseEvent(baseEventForLlmResponse, llmRequest, updatedResponse);
174+
buildModelResponseEvent(
175+
baseEventForLlmResponse, llmRequest, updatedResponse, thoughtSignatureState);
171176

172177
Flowable<Event> modelEventStream = Flowable.just(modelResponseEvent);
173178

@@ -335,6 +340,8 @@ private Single<LlmResponse> handleAfterModelCallback(
335340
*/
336341
private Flowable<Event> runOneStep(InvocationContext context) {
337342
LlmRequest initialLlmRequest = LlmRequest.builder().build();
343+
AtomicReference<Optional<byte[]>> thoughtSignatureState =
344+
new AtomicReference<>(Optional.empty());
338345

339346
return preprocess(context, initialLlmRequest)
340347
.flatMapPublisher(
@@ -373,7 +380,8 @@ private Flowable<Event> runOneStep(InvocationContext context) {
373380
context,
374381
mutableEventTemplate,
375382
llmRequestAfterPreprocess,
376-
llmResponse)
383+
llmResponse,
384+
thoughtSignatureState)
377385
.doFinally(
378386
() -> {
379387
String oldId = mutableEventTemplate.id();
@@ -453,6 +461,8 @@ public Flowable<Event> run(InvocationContext invocationContext) {
453461
@Override
454462
public Flowable<Event> runLive(InvocationContext invocationContext) {
455463
LlmRequest llmRequest = LlmRequest.builder().build();
464+
AtomicReference<Optional<byte[]>> thoughtSignatureState =
465+
new AtomicReference<>(Optional.empty());
456466

457467
return preprocess(invocationContext, llmRequest)
458468
.flatMapPublisher(
@@ -568,7 +578,8 @@ public void onError(Throwable e) {
568578
invocationContext,
569579
baseEventForThisLlmResponse,
570580
llmRequestAfterPreprocess,
571-
llmResponse);
581+
llmResponse,
582+
thoughtSignatureState);
572583
})
573584
.flatMap(
574585
event -> {
@@ -624,7 +635,10 @@ public void onError(Throwable e) {
624635
* @return A fully constructed {@link Event} representing the LLM response.
625636
*/
626637
private Event buildModelResponseEvent(
627-
Event baseEventForLlmResponse, LlmRequest llmRequest, LlmResponse llmResponse) {
638+
Event baseEventForLlmResponse,
639+
LlmRequest llmRequest,
640+
LlmResponse llmResponse,
641+
AtomicReference<Optional<byte[]>> thoughtSignatureState) {
628642
Event.Builder eventBuilder =
629643
baseEventForLlmResponse.toBuilder()
630644
.content(llmResponse.content())
@@ -638,6 +652,33 @@ private Event buildModelResponseEvent(
638652
.finishReason(llmResponse.finishReason())
639653
.usageMetadata(llmResponse.usageMetadata());
640654

655+
List<Part> parts = llmResponse.content().flatMap(Content::parts).orElse(ImmutableList.of());
656+
if (!parts.isEmpty()) {
657+
boolean signaturePresentInResponse = false;
658+
for (Part part : parts) {
659+
if (part.thoughtSignature().isPresent()) {
660+
signaturePresentInResponse = true;
661+
thoughtSignatureState.set(part.thoughtSignature());
662+
}
663+
}
664+
665+
if (!signaturePresentInResponse
666+
&& parts.stream().anyMatch(p -> p.functionCall().isPresent())) {
667+
ArrayList<Part> updatedParts = new ArrayList<>();
668+
for (Part part : parts) {
669+
Part partToAdd = part;
670+
if (part.functionCall().isPresent() && part.thoughtSignature().isEmpty()) {
671+
Optional<byte[]> signatureToApply = thoughtSignatureState.get();
672+
if (signatureToApply.isPresent()) {
673+
partToAdd = part.toBuilder().thoughtSignature(signatureToApply.get()).build();
674+
}
675+
}
676+
updatedParts.add(partToAdd);
677+
}
678+
eventBuilder.content(llmResponse.content().get().toBuilder().parts(updatedParts).build());
679+
}
680+
}
681+
641682
Event event = eventBuilder.build();
642683

643684
if (!event.functionCalls().isEmpty()) {

0 commit comments

Comments
 (0)