-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
executable file
·116 lines (95 loc) · 2.9 KB
/
main.py
File metadata and controls
executable file
·116 lines (95 loc) · 2.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import os
import argparse
from groq import Groq
from langchain.vectorstores.chroma import Chroma
from langchain.prompts import ChatPromptTemplate
from langchain_community.llms.ollama import Ollama
from embedding_functions import embedding_api, embedding_llm
# Shared Groq API client used by query_api. The key comes from the
# RAG_API_KEY environment variable (None is passed through when it is unset).
client = Groq(api_key=os.getenv("RAG_API_KEY"))
# Persona / style instructions sent as the system message of every
# conversation. The sentences are joined with a space; adjacent string
# literals would otherwise run together ("Docs.AI.Maintain ...").
instructions = " ".join(
    (
        "You're an AI chatbot named Docs.AI.",
        "Maintain a professional and friendly tone.",
        "Provide as much information as possible.",
        "Answer any follow-up questions for clarification.",
        "Aim to provide comprehensive information in your responses.",
    )
)
# Directory of the persisted Chroma vector store opened by search_db.
CHROMA_PATH = "chroma"

# Retrieval-augmented prompt: ``{context}`` receives the retrieved document
# chunks and ``{question}`` the user's query. Written as adjacent literals so
# every line break in the template is explicit.
PROMPT_TEMPLATE = (
    "\n"
    "Answer the question based only on the following context:\n"
    "{context}\n"
    "---\n"
    "Answer the question based on the above context: {question}\n"
)
def main():
    """Command-line entry point.

    Reads one positional ``query_text`` argument, answers it with
    ``query_api`` starting from a fresh (empty) conversation history, and
    prints the reply prefixed with ``Response: ``.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("query_text", type=str, help="The query text.")
    options = cli.parse_args()

    answer = query_api(options.query_text, [])
    print(f"Response: {answer}")
def search_db(query_text: str):
    """Build a retrieval-augmented prompt for *query_text*.

    Opens the Chroma store persisted under ``CHROMA_PATH`` (using the
    embedding function from ``embedding_api()``), fetches the 5 most similar
    chunks, joins their text with ``---`` separators, and fills
    ``PROMPT_TEMPLATE`` with that context and the original question.

    Args:
        query_text: The user's question; used both as the similarity-search
            query and as the ``{question}`` slot of the template.

    Returns:
        The prompt rendered by ``ChatPromptTemplate.format``.
    """
    store = Chroma(
        persist_directory=CHROMA_PATH,
        embedding_function=embedding_api(),
    )

    # Similarity scores are returned alongside each document but not used.
    matches = store.similarity_search_with_score(query_text, k=5)
    context_text = "\n\n---\n\n".join(
        document.page_content for document, _ in matches
    )

    template = ChatPromptTemplate.from_template(PROMPT_TEMPLATE)
    return template.format(context=context_text, question=query_text)
# Querying the llm using API
def query_api(
    query_text: str,
    history: list,
    model: str = "llama-3.1-70b-versatile",
):
    """Answer *query_text* through the Groq chat-completions API.

    The question is first wrapped in a retrieval-augmented prompt by
    ``search_db`` and then sent together with the running conversation.

    Args:
        query_text: The user's question.
        history: List of ``{"role": ..., "content": ...}`` message dicts from
            earlier turns. Mutated in place: when empty, the system
            instructions are added first; after a successful call the new
            user prompt and the assistant reply are appended. If the API call
            raises, the new user prompt is NOT added, so the history never
            ends with an unanswered user turn.
        model: Groq model identifier used for the completion.

    Returns:
        The content of the first choice's message.
    """
    prompt = search_db(query_text)

    if not history:
        # Instructions belong in a "system" message; as an "assistant" turn
        # the model would read them as something it had already said.
        history.append({"role": "system", "content": instructions})

    user_message = {"role": "user", "content": prompt}
    chat_completion = client.chat.completions.create(
        messages=[*history, user_message],
        model=model,
    )
    response = chat_completion.choices[0].message.content

    # Commit the turn only once the request has succeeded.
    history.append(user_message)
    history.append({"role": "assistant", "content": response})
    return response
# Querying the local llm
def query_llm(query_text: str, history: list, model: str = "llama3.1"):
    """Answer *query_text* with a local Ollama model, using retrieved context.

    Mirrors ``query_api``. The question is wrapped in a retrieval-augmented
    prompt by ``search_db``, and the whole conversation is sent to the model:
    the instructions, the earlier turns and the new prompt. Previously only
    the bare prompt was sent, so the instructions and history were collected
    but silently ignored.

    Args:
        query_text: The user's question.
        history: List of ``{"role": ..., "content": ...}`` message dicts from
            earlier turns. Mutated in place: when empty, the system
            instructions are added first; after a successful call the new
            user prompt and the model's reply are appended. If the model call
            raises, the new user prompt is NOT added.
        model: Name of the Ollama model to run.

    Returns:
        The text returned by ``Ollama.invoke``.
    """
    prompt = search_db(query_text)

    if not history:
        # Same role as in query_api: instructions are a system message.
        history.append({"role": "system", "content": instructions})

    user_message = {"role": "user", "content": prompt}

    # The Ollama wrapper imported from langchain_community.llms is invoked
    # with a single string, so flatten the conversation into a role-tagged
    # transcript and end with an "assistant:" cue for the reply.
    transcript = "\n\n".join(
        f"{message['role']}: {message['content']}"
        for message in [*history, user_message]
    )
    llm = Ollama(model=model)
    response = llm.invoke(f"{transcript}\n\nassistant:")

    # Commit the turn only once the model call has succeeded.
    history.append(user_message)
    history.append({"role": "assistant", "content": response})
    return response
if __name__ == "__main__":
    # Run the CLI only when this file is executed as a script, not on import.
    main()