Skip to content

Commit 924fb71

Browse files
tilgalascopybara-github
authored andcommitted
feat: add support for gemini models in VertexAiRagRetrieval
PiperOrigin-RevId: 879585211
1 parent 0b9057c commit 924fb71

3 files changed

Lines changed: 23 additions & 4 deletions

File tree

core/pom.xml

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,10 +209,29 @@
209209
<artifactId>maven-surefire-plugin</artifactId>
210210
<executions>
211211
<execution>
212+
<!-- run all tests with standard env -->
212213
<id>basic</id>
213214
<goals>
214215
<goal>test</goal>
215216
</goals>
217+
<configuration>
218+
<environmentVariables>
219+
<GOOGLE_GENAI_USE_VERTEXAI>false</GOOGLE_GENAI_USE_VERTEXAI>
220+
</environmentVariables>
221+
</configuration>
222+
</execution>
223+
<execution>
224+
<id>vertex-ai-rag-retrieval</id>
225+
<goals>
226+
<goal>test</goal>
227+
</goals>
228+
<configuration>
229+
<environmentVariables>
230+
<GOOGLE_GENAI_USE_VERTEXAI>true</GOOGLE_GENAI_USE_VERTEXAI>
231+
</environmentVariables>
232+
<!-- run a second variant of the following tests -->
233+
<test>VertexAiRagRetrievalTest#processLlmRequest_gemini2Model_addVertexRagStoreToConfig, VertexAiRagRetrievalTest#processLlmRequest_otherModel_doNotAddVertexRagStoreToConfig</test>
234+
</configuration>
216235
</execution>
217236
<execution>
218237
<id>apigee-llm</id>

core/src/main/java/com/google/adk/tools/retrieval/VertexAiRagRetrieval.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
import com.google.adk.models.LlmRequest;
2222
import com.google.adk.tools.ToolContext;
23+
import com.google.adk.utils.ModelNameUtils;
2324
import com.google.cloud.aiplatform.v1.RagContexts;
2425
import com.google.cloud.aiplatform.v1.RagQuery;
2526
import com.google.cloud.aiplatform.v1.RetrieveContextsRequest;
@@ -105,10 +106,9 @@ public VertexAiRagRetrieval(
105106
public Completable processLlmRequest(
106107
LlmRequest.Builder llmRequestBuilder, ToolContext toolContext) {
107108
LlmRequest llmRequest = llmRequestBuilder.build();
108-
// Use Gemini built-in Vertex AI RAG tool for Gemini 2 models or when using Vertex AI API Model
109+
// Use Gemini built-in Vertex AI RAG tool for Gemini models when using Vertex AI API Model
109110
boolean useVertexAi = Boolean.parseBoolean(System.getenv("GOOGLE_GENAI_USE_VERTEXAI"));
110-
if (useVertexAi
111-
&& (llmRequest.model().isPresent() && llmRequest.model().get().startsWith("gemini-2"))) {
111+
if (useVertexAi && llmRequest.model().filter(ModelNameUtils::isGeminiModel).isPresent()) {
112112
GenerateContentConfig config =
113113
llmRequest.config().orElseGet(() -> GenerateContentConfig.builder().build());
114114
ImmutableList.Builder<Tool> toolsBuilder = ImmutableList.builder();

core/src/test/java/com/google/adk/tools/retrieval/VertexAiRagRetrievalTest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,7 @@ public void processLlmRequest_otherModel_doNotAddVertexRagStoreToConfig() {
208208
"projects/test-project/locations/us-central1",
209209
ragResources,
210210
vectorDistanceThreshold);
211-
LlmRequest.Builder llmRequestBuilder = LlmRequest.builder().model("gemini-1-pro");
211+
LlmRequest.Builder llmRequestBuilder = LlmRequest.builder().model("other-model");
212212
ToolContext toolContext = buildToolContext();
213213
GenerateContentConfig initialConfig = GenerateContentConfig.builder().build();
214214
llmRequestBuilder.config(initialConfig);

0 commit comments

Comments
 (0)