Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion contrib/sarvam-ai/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
<parent>
<groupId>com.google.adk</groupId>
<artifactId>google-adk-parent</artifactId>
<version>1.1.0</version><!-- {x-version-update:google-adk:current} -->
<version>1.3.1-SNAPSHOT</version><!-- {x-version-update:google-adk:current} -->
<relativePath>../../pom.xml</relativePath>
</parent>

Expand Down
137 changes: 41 additions & 96 deletions core/src/test/java/com/google/adk/models/GeminiLlmConnectionTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,14 @@
import com.google.common.collect.ImmutableList;
import com.google.genai.types.Content;
import com.google.genai.types.FunctionCall;
import com.google.genai.types.GenerateContentResponseUsageMetadata;
import com.google.genai.types.LiveServerContent;
import com.google.genai.types.LiveServerMessage;
import com.google.genai.types.LiveServerSetupComplete;
import com.google.genai.types.LiveServerToolCall;
import com.google.genai.types.LiveServerToolCallCancellation;
import com.google.genai.types.Part;
import com.google.genai.types.UsageMetadata;
import io.reactivex.rxjava3.observers.TestObserver;
import java.util.List;
import java.util.Optional;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
Expand All @@ -48,14 +46,11 @@ public void convertToServerResponse_withInterruptedTrue_mapsInterruptedField() {
.build();

LiveServerMessage message = LiveServerMessage.builder().serverContent(serverContent).build();
TestObserver<LlmResponse> testObserver = new TestObserver<>();

GeminiLlmConnection.convertToServerResponse(message).subscribe(testObserver);

testObserver.assertValueCount(1);
testObserver.assertComplete();
LlmResponse response = testObserver.values().get(0);
Optional<LlmResponse> result = GeminiLlmConnection.convertToServerResponse(message);

assertThat(result.isPresent()).isTrue();
LlmResponse response = result.get();
assertThat(response.content()).isPresent();
assertThat(response.content().get().text()).isEqualTo("Model response");
assertThat(response.partial()).hasValue(true);
Expand All @@ -74,13 +69,10 @@ public void convertToServerResponse_withInterruptedFalse_mapsInterruptedField()

LiveServerMessage message = LiveServerMessage.builder().serverContent(serverContent).build();

TestObserver<LlmResponse> testObserver = new TestObserver<>();

GeminiLlmConnection.convertToServerResponse(message).subscribe(testObserver);
Optional<LlmResponse> result = GeminiLlmConnection.convertToServerResponse(message);

testObserver.assertValueCount(1);
testObserver.assertComplete();
LlmResponse response = testObserver.values().get(0);
assertThat(result.isPresent()).isTrue();
LlmResponse response = result.get();
assertThat(response.interrupted()).hasValue(false);
assertThat(response.turnComplete()).hasValue(false);
}
Expand All @@ -95,13 +87,10 @@ public void convertToServerResponse_withoutInterruptedField_mapsEmptyOptional()

LiveServerMessage message = LiveServerMessage.builder().serverContent(serverContent).build();

TestObserver<LlmResponse> testObserver = new TestObserver<>();

GeminiLlmConnection.convertToServerResponse(message).subscribe(testObserver);
Optional<LlmResponse> result = GeminiLlmConnection.convertToServerResponse(message);

testObserver.assertValueCount(1);
testObserver.assertComplete();
LlmResponse response = testObserver.values().get(0);
assertThat(result.isPresent()).isTrue();
LlmResponse response = result.get();
assertThat(response.interrupted()).isEmpty();
assertThat(response.turnComplete()).hasValue(true);
}
Expand All @@ -116,13 +105,10 @@ public void convertToServerResponse_withTurnCompleteTrue_mapsPartialFalse() {

LiveServerMessage message = LiveServerMessage.builder().serverContent(serverContent).build();

TestObserver<LlmResponse> testObserver = new TestObserver<>();
Optional<LlmResponse> result = GeminiLlmConnection.convertToServerResponse(message);

GeminiLlmConnection.convertToServerResponse(message).subscribe(testObserver);

testObserver.assertValueCount(1);
testObserver.assertComplete();
LlmResponse response = testObserver.values().get(0);
assertThat(result.isPresent()).isTrue();
LlmResponse response = result.get();
assertThat(response.partial()).hasValue(false);
assertThat(response.turnComplete()).hasValue(true);
}
Expand All @@ -137,13 +123,10 @@ public void convertToServerResponse_withTurnCompleteFalse_mapsPartialTrue() {

LiveServerMessage message = LiveServerMessage.builder().serverContent(serverContent).build();

TestObserver<LlmResponse> testObserver = new TestObserver<>();

GeminiLlmConnection.convertToServerResponse(message).subscribe(testObserver);
Optional<LlmResponse> result = GeminiLlmConnection.convertToServerResponse(message);

testObserver.assertValueCount(1);
testObserver.assertComplete();
LlmResponse response = testObserver.values().get(0);
assertThat(result.isPresent()).isTrue();
LlmResponse response = result.get();
assertThat(response.partial()).hasValue(true);
assertThat(response.turnComplete()).hasValue(false);
}
Expand All @@ -156,13 +139,10 @@ public void convertToServerResponse_withToolCall_mapsContentWithFunctionCall() {

LiveServerMessage message = LiveServerMessage.builder().toolCall(toolCall).build();

TestObserver<LlmResponse> testObserver = new TestObserver<>();

GeminiLlmConnection.convertToServerResponse(message).subscribe(testObserver);
Optional<LlmResponse> result = GeminiLlmConnection.convertToServerResponse(message);

testObserver.assertValueCount(1);
testObserver.assertComplete();
LlmResponse response = testObserver.values().get(0);
assertThat(result.isPresent()).isTrue();
LlmResponse response = result.get();
assertThat(response.content()).isPresent();
assertThat(response.content().get().parts()).isPresent();
assertThat(response.content().get().parts().get()).hasSize(1);
Expand All @@ -172,7 +152,7 @@ public void convertToServerResponse_withToolCall_mapsContentWithFunctionCall() {
}

@Test
public void convertToServerResponse_withUsageMetadata_mapsGenerateResponseUsageMetadata() {
public void convertToServerResponse_withUsageMetadata_returnsEmpty() {
LiveServerMessage message =
LiveServerMessage.builder()
.usageMetadata(
Expand All @@ -183,68 +163,52 @@ public void convertToServerResponse_withUsageMetadata_mapsGenerateResponseUsageM
.build())
.build();

TestObserver<LlmResponse> testObserver = new TestObserver<>();
Optional<LlmResponse> result = GeminiLlmConnection.convertToServerResponse(message);

GeminiLlmConnection.convertToServerResponse(message).subscribe(testObserver);
testObserver.assertValueCount(1);
testObserver.assertComplete();
LlmResponse response = testObserver.values().get(0);
assertThat(response.usageMetadata()).isPresent();
GenerateContentResponseUsageMetadata expectedUsageMetadata =
GenerateContentResponseUsageMetadata.builder()
.promptTokenCount(10)
.candidatesTokenCount(20)
.totalTokenCount(30)
.build();
assertThat(response.usageMetadata()).hasValue(expectedUsageMetadata);
assertThat(result.isPresent()).isFalse();
}

@Test
public void convertToServerResponse_withToolCallCancellation_returnsNoValues() {
public void convertToServerResponse_withToolCallCancellation_returnsInterrupted() {
LiveServerMessage message =
LiveServerMessage.builder()
.toolCallCancellation(LiveServerToolCallCancellation.builder().build())
.build();

TestObserver<LlmResponse> testObserver = new TestObserver<>();
Optional<LlmResponse> result = GeminiLlmConnection.convertToServerResponse(message);

GeminiLlmConnection.convertToServerResponse(message).subscribe(testObserver);
testObserver.assertNoValues();
testObserver.assertComplete();
assertThat(result.isPresent()).isTrue();
LlmResponse response = result.get();
assertThat(response.interrupted()).hasValue(true);
assertThat(response.turnComplete()).hasValue(true);
}

@Test
public void convertToServerResponse_withSetupComplete_returnsNoValues() {
public void convertToServerResponse_withSetupComplete_returnsEmpty() {
LiveServerMessage message =
LiveServerMessage.builder()
.setupComplete(LiveServerSetupComplete.builder().build())
.build();

TestObserver<LlmResponse> testObserver = new TestObserver<>();

GeminiLlmConnection.convertToServerResponse(message).subscribe(testObserver);
Optional<LlmResponse> result = GeminiLlmConnection.convertToServerResponse(message);

testObserver.assertNoValues();
testObserver.assertComplete();
assertThat(result.isPresent()).isFalse();
}

@Test
public void convertToServerResponse_withUnknownMessage_returnsErrorResponse() {
LiveServerMessage message = LiveServerMessage.builder().build();

TestObserver<LlmResponse> testObserver = new TestObserver<>();
Optional<LlmResponse> result = GeminiLlmConnection.convertToServerResponse(message);

GeminiLlmConnection.convertToServerResponse(message).subscribe(testObserver);

testObserver.assertValueCount(1);
testObserver.assertComplete();
LlmResponse response = testObserver.values().get(0);
assertThat(result.isPresent()).isTrue();
LlmResponse response = result.get();
assertThat(response.errorCode()).isPresent();
assertThat(response.errorMessage()).hasValue("Received unknown server message.");
}

@Test
public void convertToServerResponse_withContentAndUsageMetadata_emitsMultiple() {
public void convertToServerResponse_withContentAndUsageMetadata_returnsContentOnly() {
LiveServerContent serverContent =
LiveServerContent.builder()
.modelTurn(Content.fromParts(Part.fromText("Model response")))
Expand All @@ -264,31 +228,12 @@ public void convertToServerResponse_withContentAndUsageMetadata_emitsMultiple()
.usageMetadata(usageMetadata)
.build();

TestObserver<LlmResponse> testObserver = new TestObserver<>();

GeminiLlmConnection.convertToServerResponse(message).subscribe(testObserver);

testObserver.assertValueCount(2);
testObserver.assertComplete();
Optional<LlmResponse> result = GeminiLlmConnection.convertToServerResponse(message);

List<LlmResponse> responses = testObserver.values();

// Check for ServerContent response
LlmResponse contentResponse = responses.get(0);
assertThat(contentResponse.content()).isPresent();
assertThat(contentResponse.content().get().text()).isEqualTo("Model response");
assertThat(contentResponse.usageMetadata()).isEmpty();

// Check for UsageMetadata response
LlmResponse usageResponse = responses.get(1);
assertThat(usageResponse.content()).isEmpty();
assertThat(usageResponse.usageMetadata()).isPresent();
GenerateContentResponseUsageMetadata expectedUsageMetadata =
GenerateContentResponseUsageMetadata.builder()
.promptTokenCount(10)
.candidatesTokenCount(20)
.totalTokenCount(30)
.build();
assertThat(usageResponse.usageMetadata()).hasValue(expectedUsageMetadata);
assertThat(result.isPresent()).isTrue();
LlmResponse response = result.get();
assertThat(response.content()).isPresent();
assertThat(response.content().get().text()).isEqualTo("Model response");
assertThat(response.turnComplete()).hasValue(true);
}
}
Loading