-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathAbstractRetrievalAugmentedGenerationService.java
More file actions
50 lines (39 loc) · 2.06 KB
/
AbstractRetrievalAugmentedGenerationService.java
File metadata and controls
50 lines (39 loc) · 2.06 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
package de.cofinpro.springai.retrieval_augmented_generation;
import org.springframework.ai.document.Document;
import org.springframework.ai.openai.OpenAiChatClient;
import org.springframework.ai.prompt.Prompt;
import org.springframework.ai.prompt.SystemPromptTemplate;
import org.springframework.ai.prompt.messages.UserMessage;
import org.springframework.ai.reader.JsonReader;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.core.io.Resource;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
/**
 * Base class for services that answer questions using retrieval-augmented generation.
 *
 * <p>{@link #ingestDocuments()} reads a JSON resource and adds the resulting documents to a
 * {@link VectorStore}. {@link #retrievalAugmentedGeneration(String)} looks up documents similar to
 * a question and renders their contents into a system prompt. It then sends that system prompt,
 * together with the question as a user message, to the OpenAI chat client and returns the
 * generated text.
 */
public abstract class AbstractRetrievalAugmentedGenerationService {

    /** Template variable that receives the newline-joined contents of the retrieved documents. */
    private static final String DOCUMENTS_TEMPLATE_VARIABLE = "documents";

    /** JSON keys passed to the {@link JsonReader} when ingesting the bikes resource. */
    private static final String[] BIKE_JSON_KEYS = {"name", "price", "shortDescription"};

    /** Store that ingested documents are added to and that similarity searches run against. */
    private final VectorStore vectorStore;

    /** Chat client that produces the final answer from the assembled prompt. */
    private final OpenAiChatClient openAiChatClient;

    /** System prompt template, rendered with the retrieved documents for each request. */
    private final SystemPromptTemplate systemPromptTemplate;

    /** JSON resource that {@link #ingestDocuments()} reads. */
    private final Resource bikesResource;

    /**
     * Creates the service.
     *
     * <p>The constructor is {@code protected} because the class is abstract and can only be
     * instantiated through a subclass.
     *
     * @param vectorStore store used for ingestion and similarity search
     * @param openAiChatClient chat client used to generate answers
     * @param systemPromptTemplateResource resource holding the system prompt template; it is rendered
     *     with a {@code documents} variable containing the retrieved document contents
     * @param bikesResource JSON resource read by {@link #ingestDocuments()}
     * @throws NullPointerException if any argument is {@code null}
     */
    protected AbstractRetrievalAugmentedGenerationService(VectorStore vectorStore, OpenAiChatClient openAiChatClient,
            Resource systemPromptTemplateResource, Resource bikesResource) {
        // Fail fast here rather than with an obscure NPE on the first ingest or generation call.
        this.vectorStore = Objects.requireNonNull(vectorStore, "vectorStore");
        this.openAiChatClient = Objects.requireNonNull(openAiChatClient, "openAiChatClient");
        this.systemPromptTemplate = new SystemPromptTemplate(
                Objects.requireNonNull(systemPromptTemplateResource, "systemPromptTemplateResource"));
        this.bikesResource = Objects.requireNonNull(bikesResource, "bikesResource");
    }

    /**
     * Returns the vector store this service ingests into and searches.
     *
     * @return the vector store passed to the constructor
     */
    public VectorStore getVectorStore() {
        return vectorStore;
    }

    /**
     * Reads the bikes JSON resource and adds the resulting documents to the vector store.
     *
     * <p>Each call reads the resource again and adds its documents again. NOTE(review): whether
     * repeated calls create duplicates depends on the {@link VectorStore} implementation, so confirm
     * this before calling the method more than once.
     */
    public void ingestDocuments() {
        final var jsonReader = new JsonReader(bikesResource, BIKE_JSON_KEYS);
        vectorStore.add(jsonReader.get());
    }

    /**
     * Answers {@code message} using documents retrieved from the vector store as context.
     *
     * <p>The contents of the documents returned by the similarity search are joined with
     * {@code "\n"} and inserted into the system prompt. The system prompt and {@code message} (as a
     * user message) are then sent to the chat client. If no documents are found, the
     * {@code documents} variable is rendered as an empty string.
     *
     * @param message the user's question
     * @return the content of the chat client's generation
     * @throws NullPointerException if {@code message} is {@code null}
     */
    public String retrievalAugmentedGeneration(String message) {
        Objects.requireNonNull(message, "message");

        final var similarDocuments = vectorStore.similaritySearch(message);
        final var joinedDocuments = similarDocuments.stream()
                .map(Document::getContent)
                .collect(Collectors.joining("\n"));

        final var systemMessage = systemPromptTemplate.createMessage(
                Map.of(DOCUMENTS_TEMPLATE_VARIABLE, joinedDocuments));
        final var prompt = new Prompt(List.of(systemMessage, new UserMessage(message)));
        return openAiChatClient.generate(prompt).getGeneration().getContent();
    }
}