diff --git a/contrib/sarvam-ai/pom.xml b/contrib/sarvam-ai/pom.xml index 0c32593f0..199d0222a 100644 --- a/contrib/sarvam-ai/pom.xml +++ b/contrib/sarvam-ai/pom.xml @@ -20,7 +20,7 @@ com.google.adk google-adk-parent - 1.1.0 + 1.3.1-SNAPSHOT ../../pom.xml diff --git a/core/src/test/java/com/google/adk/models/GeminiLlmConnectionTest.java b/core/src/test/java/com/google/adk/models/GeminiLlmConnectionTest.java index a3ac09fe5..d031572aa 100644 --- a/core/src/test/java/com/google/adk/models/GeminiLlmConnectionTest.java +++ b/core/src/test/java/com/google/adk/models/GeminiLlmConnectionTest.java @@ -21,7 +21,6 @@ import com.google.common.collect.ImmutableList; import com.google.genai.types.Content; import com.google.genai.types.FunctionCall; -import com.google.genai.types.GenerateContentResponseUsageMetadata; import com.google.genai.types.LiveServerContent; import com.google.genai.types.LiveServerMessage; import com.google.genai.types.LiveServerSetupComplete; @@ -29,8 +28,7 @@ import com.google.genai.types.LiveServerToolCallCancellation; import com.google.genai.types.Part; import com.google.genai.types.UsageMetadata; -import io.reactivex.rxjava3.observers.TestObserver; -import java.util.List; +import java.util.Optional; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -48,14 +46,11 @@ public void convertToServerResponse_withInterruptedTrue_mapsInterruptedField() { .build(); LiveServerMessage message = LiveServerMessage.builder().serverContent(serverContent).build(); - TestObserver testObserver = new TestObserver<>(); - GeminiLlmConnection.convertToServerResponse(message).subscribe(testObserver); - - testObserver.assertValueCount(1); - testObserver.assertComplete(); - LlmResponse response = testObserver.values().get(0); + Optional result = GeminiLlmConnection.convertToServerResponse(message); + assertThat(result.isPresent()).isTrue(); + LlmResponse response = result.get(); assertThat(response.content()).isPresent(); assertThat(response.content().get().text()).isEqualTo("Model response"); assertThat(response.partial()).hasValue(true); @@ -74,13 +69,10 @@ public void convertToServerResponse_withInterruptedFalse_mapsInterruptedField() LiveServerMessage message = LiveServerMessage.builder().serverContent(serverContent).build(); - TestObserver testObserver = new TestObserver<>(); - - GeminiLlmConnection.convertToServerResponse(message).subscribe(testObserver); + Optional result = GeminiLlmConnection.convertToServerResponse(message); - testObserver.assertValueCount(1); - testObserver.assertComplete(); - LlmResponse response = testObserver.values().get(0); + assertThat(result.isPresent()).isTrue(); + LlmResponse response = result.get(); assertThat(response.interrupted()).hasValue(false); assertThat(response.turnComplete()).hasValue(false); } @@ -95,13 +87,10 @@ public void convertToServerResponse_withoutInterruptedField_mapsEmptyOptional() LiveServerMessage message = LiveServerMessage.builder().serverContent(serverContent).build(); - TestObserver testObserver = new TestObserver<>(); - - GeminiLlmConnection.convertToServerResponse(message).subscribe(testObserver); + Optional result = GeminiLlmConnection.convertToServerResponse(message); - testObserver.assertValueCount(1); - testObserver.assertComplete(); - LlmResponse response = testObserver.values().get(0); + assertThat(result.isPresent()).isTrue(); + LlmResponse response = result.get(); assertThat(response.interrupted()).isEmpty(); assertThat(response.turnComplete()).hasValue(true); } @@ -116,13 +105,10 @@ public void convertToServerResponse_withTurnCompleteTrue_mapsPartialFalse() { LiveServerMessage message = LiveServerMessage.builder().serverContent(serverContent).build(); - TestObserver testObserver = new TestObserver<>(); + Optional result = GeminiLlmConnection.convertToServerResponse(message); - GeminiLlmConnection.convertToServerResponse(message).subscribe(testObserver); - - testObserver.assertValueCount(1); - testObserver.assertComplete(); - LlmResponse response = testObserver.values().get(0); + assertThat(result.isPresent()).isTrue(); + LlmResponse response = result.get(); assertThat(response.partial()).hasValue(false); assertThat(response.turnComplete()).hasValue(true); } @@ -137,13 +123,10 @@ public void convertToServerResponse_withTurnCompleteFalse_mapsPartialTrue() { LiveServerMessage message = LiveServerMessage.builder().serverContent(serverContent).build(); - TestObserver testObserver = new TestObserver<>(); - - GeminiLlmConnection.convertToServerResponse(message).subscribe(testObserver); + Optional result = GeminiLlmConnection.convertToServerResponse(message); - testObserver.assertValueCount(1); - testObserver.assertComplete(); - LlmResponse response = testObserver.values().get(0); + assertThat(result.isPresent()).isTrue(); + LlmResponse response = result.get(); assertThat(response.partial()).hasValue(true); assertThat(response.turnComplete()).hasValue(false); } @@ -156,13 +139,10 @@ public void convertToServerResponse_withToolCall_mapsContentWithFunctionCall() { LiveServerMessage message = LiveServerMessage.builder().toolCall(toolCall).build(); - TestObserver testObserver = new TestObserver<>(); - - GeminiLlmConnection.convertToServerResponse(message).subscribe(testObserver); + Optional result = GeminiLlmConnection.convertToServerResponse(message); - testObserver.assertValueCount(1); - testObserver.assertComplete(); - LlmResponse response = testObserver.values().get(0); + assertThat(result.isPresent()).isTrue(); + LlmResponse response = result.get(); assertThat(response.content()).isPresent(); assertThat(response.content().get().parts()).isPresent(); assertThat(response.content().get().parts().get()).hasSize(1); @@ -172,7 +152,7 @@ public void convertToServerResponse_withToolCall_mapsContentWithFunctionCall() { } @Test - public void convertToServerResponse_withUsageMetadata_mapsGenerateResponseUsageMetadata() { + public void convertToServerResponse_withUsageMetadata_returnsEmpty() { LiveServerMessage message = LiveServerMessage.builder() .usageMetadata( @@ -183,68 +163,52 @@ public void convertToServerResponse_withUsageMetadata_mapsGenerateResponseUsageM .build()) .build(); - TestObserver testObserver = new TestObserver<>(); + Optional result = GeminiLlmConnection.convertToServerResponse(message); - GeminiLlmConnection.convertToServerResponse(message).subscribe(testObserver); - testObserver.assertValueCount(1); - testObserver.assertComplete(); - LlmResponse response = testObserver.values().get(0); - assertThat(response.usageMetadata()).isPresent(); - GenerateContentResponseUsageMetadata expectedUsageMetadata = - GenerateContentResponseUsageMetadata.builder() - .promptTokenCount(10) - .candidatesTokenCount(20) - .totalTokenCount(30) - .build(); - assertThat(response.usageMetadata()).hasValue(expectedUsageMetadata); + assertThat(result.isPresent()).isFalse(); } @Test - public void convertToServerResponse_withToolCallCancellation_returnsNoValues() { + public void convertToServerResponse_withToolCallCancellation_returnsInterrupted() { LiveServerMessage message = LiveServerMessage.builder() .toolCallCancellation(LiveServerToolCallCancellation.builder().build()) .build(); - TestObserver testObserver = new TestObserver<>(); + Optional result = GeminiLlmConnection.convertToServerResponse(message); - GeminiLlmConnection.convertToServerResponse(message).subscribe(testObserver); - testObserver.assertNoValues(); - testObserver.assertComplete(); + assertThat(result.isPresent()).isTrue(); + LlmResponse response = result.get(); + assertThat(response.interrupted()).hasValue(true); + assertThat(response.turnComplete()).hasValue(true); } @Test - public void convertToServerResponse_withSetupComplete_returnsNoValues() { + public void convertToServerResponse_withSetupComplete_returnsEmpty() { LiveServerMessage message = LiveServerMessage.builder() .setupComplete(LiveServerSetupComplete.builder().build()) .build(); - TestObserver testObserver = new TestObserver<>(); - - GeminiLlmConnection.convertToServerResponse(message).subscribe(testObserver); + Optional result = GeminiLlmConnection.convertToServerResponse(message); - testObserver.assertNoValues(); - testObserver.assertComplete(); + assertThat(result.isPresent()).isFalse(); } @Test public void convertToServerResponse_withUnknownMessage_returnsErrorResponse() { LiveServerMessage message = LiveServerMessage.builder().build(); - TestObserver testObserver = new TestObserver<>(); + Optional result = GeminiLlmConnection.convertToServerResponse(message); - GeminiLlmConnection.convertToServerResponse(message).subscribe(testObserver); - - testObserver.assertValueCount(1); - testObserver.assertComplete(); - LlmResponse response = testObserver.values().get(0); + assertThat(result.isPresent()).isTrue(); + LlmResponse response = result.get(); assertThat(response.errorCode()).isPresent(); assertThat(response.errorMessage()).hasValue("Received unknown server message."); } @Test - public void convertToServerResponse_withContentAndUsageMetadata_emitsMultiple() { + public void convertToServerResponse_withContentAndUsageMetadata_returnsContentOnly() { LiveServerContent serverContent = LiveServerContent.builder() .modelTurn(Content.fromParts(Part.fromText("Model response"))) @@ -264,31 +228,12 @@ public void convertToServerResponse_withContentAndUsageMetadata_emitsMultiple() .usageMetadata(usageMetadata) .build(); - TestObserver testObserver = new TestObserver<>(); - - GeminiLlmConnection.convertToServerResponse(message).subscribe(testObserver); - - testObserver.assertValueCount(2); - testObserver.assertComplete(); + Optional result = GeminiLlmConnection.convertToServerResponse(message); - List responses = testObserver.values(); - - // Check for ServerContent response - LlmResponse contentResponse = responses.get(0); - assertThat(contentResponse.content()).isPresent(); - assertThat(contentResponse.content().get().text()).isEqualTo("Model response"); - assertThat(contentResponse.usageMetadata()).isEmpty(); - - // Check for UsageMetadata response - LlmResponse usageResponse = responses.get(1); - assertThat(usageResponse.content()).isEmpty(); - assertThat(usageResponse.usageMetadata()).isPresent(); - GenerateContentResponseUsageMetadata expectedUsageMetadata = - GenerateContentResponseUsageMetadata.builder() - .promptTokenCount(10) - .candidatesTokenCount(20) - .totalTokenCount(30) - .build(); - assertThat(usageResponse.usageMetadata()).hasValue(expectedUsageMetadata); + assertThat(result.isPresent()).isTrue(); + LlmResponse response = result.get(); + assertThat(response.content()).isPresent(); + assertThat(response.content().get().text()).isEqualTo("Model response"); + assertThat(response.turnComplete()).hasValue(true); } }