-
Notifications
You must be signed in to change notification settings - Fork 24
Expand file tree
/
Copy pathlangchain_rag_agent.py
More file actions
135 lines (106 loc) · 4.15 KB
/
langchain_rag_agent.py
File metadata and controls
135 lines (106 loc) · 4.15 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import traceback
from urllib.parse import urlparse
import nltk
import requests
import validators
from ai_engine import UAgentResponse, UAgentResponseType
from bs4 import BeautifulSoup
from langchain.prompts import ChatPromptTemplate
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import CohereRerank
from langchain_community.document_loaders import UnstructuredURLLoader
from langchain_community.vectorstores import FAISS
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from messages.requests import RagRequest
from uagents import Agent, Context, Protocol
from uagents.setup import fund_agent_if_low
# One-time download of NLTK data used by UnstructuredURLLoader when it
# partitions scraped pages into documents (network access required once).
nltk.download("punkt")
nltk.download("averaged_perceptron_tagger")

# Seed makes the agent's address deterministic across restarts.
# NOTE(review): placeholder value — replace before deploying.
LANGCHAIN_RAG_SEED = "YOUR_LANGCHAIN_RAG_SEED"

# mailbox=True connects the agent through an Agentverse mailbox so it is
# reachable without a public endpoint.
agent = Agent(
    name="langchain_rag_agent",
    seed=LANGCHAIN_RAG_SEED,
    mailbox=True
)

# Top up the agent's wallet if its balance is too low to register.
fund_agent_if_low(agent.wallet.address())

# Protocol under which the RAG request handler below is registered.
docs_bot_protocol = Protocol("DocsBot")

# Prompt fed to the chat model: scraped context first, then the question.
PROMPT_TEMPLATE = """
Answer the question based only on the following context:
{context}
---
Answer the question based on the above context: {question}
"""
def create_retriever(
    ctx: Context, url: str, deep_read: bool
) -> ContextualCompressionRetriever:
    """Build a Cohere-reranked retriever over the content at *url*.

    When ``deep_read`` is True, recursively collects same-site links reachable
    from *url* first. The collected pages are loaded with
    UnstructuredURLLoader, embedded with OpenAI embeddings into a FAISS index,
    and wrapped in a ContextualCompressionRetriever using CohereRerank.

    Args:
        ctx: Agent context, used only for logging.
        url: Root page to index; also the prefix that scraped links must match.
        deep_read: Whether to crawl links under *url* before indexing.

    Returns:
        The retriever, or ``None`` when loading or indexing fails (the error
        is logged). Callers must handle the ``None`` case.
    """
    visited = [url]

    def scrape(site: str) -> None:
        # Depth-first crawl: record every same-site link under `url` in
        # `visited` and recurse into it.
        if not validators.url(site):
            ctx.logger.info(f"Url {site} is not valid")
            return
        # Timeout prevents a single slow host from hanging the agent forever.
        r = requests.get(site, timeout=10)
        soup = BeautifulSoup(r.text, "html.parser")
        parsed_url = urlparse(url)
        base_domain = parsed_url.scheme + "://" + parsed_url.netloc
        for link in soup.find_all("a"):
            href: str = link.get("href", "")
            if len(href) == 0:
                continue
            # Root-relative hrefs are resolved against the root url's domain.
            current_site = f"{base_domain}{href}" if href.startswith("/") else href
            if (
                ".php" in current_site
                or "#" in current_site
                or not current_site.startswith(url)
                or current_site in visited
            ):
                continue
            visited.append(current_site)
            scrape(current_site)

    if deep_read:
        scrape(url)
        ctx.logger.info(f"After deep scraping - urls to parse: {visited}")
    try:
        loader = UnstructuredURLLoader(urls=visited)
        docs = loader.load_and_split()
        db = FAISS.from_documents(docs, OpenAIEmbeddings())
        compression_retriever = ContextualCompressionRetriever(
            base_compressor=CohereRerank(), base_retriever=db.as_retriever()
        )
        return compression_retriever
    except Exception as exc:
        ctx.logger.error(f"Error happened: {exc}")
        # format_exception returns the traceback lines; the original discarded
        # them — join and log so the failure is actually diagnosable.
        ctx.logger.error("".join(traceback.format_exception(exc)))
        return None
@docs_bot_protocol.on_message(model=RagRequest, replies={UAgentResponse})
async def answer_question(ctx: Context, sender: str, msg: RagRequest):
    """Handle a RagRequest: retrieve context from ``msg.url`` and answer
    ``msg.question`` with the chat model, replying via UAgentResponse.

    Args:
        ctx: Agent context (logging and message sending).
        sender: Address of the requesting agent; all replies go here.
        msg: Request carrying ``url``, ``question`` and ``deep_read``
            (deep scraping is enabled when ``deep_read == "yes"``).
    """
    ctx.logger.info(f"Received message from {sender}, session: {ctx.session}")
    ctx.logger.info(
        f"input url: {msg.url}, question: {msg.question}, is deep scraping: {msg.deep_read}"
    )
    # Reject urls lacking a scheme or host before doing any network work.
    parsed_url = urlparse(msg.url)
    if not parsed_url.scheme or not parsed_url.netloc:
        ctx.logger.error("invalid input url")
        await ctx.send(
            sender,
            UAgentResponse(
                message="Input url is not valid",
                type=UAgentResponseType.FINAL,
            ),
        )
        return
    base_domain = parsed_url.scheme + "://" + parsed_url.netloc
    ctx.logger.info(f"Base domain: {base_domain}")
    retriever = create_retriever(ctx, url=msg.url, deep_read=msg.deep_read == "yes")
    # create_retriever returns None when loading/indexing fails; without this
    # guard the call below raises AttributeError and the sender never hears back.
    if retriever is None:
        await ctx.send(
            sender,
            UAgentResponse(
                message="Failed to load documents from the provided url",
                type=UAgentResponseType.FINAL,
            ),
        )
        return
    compressed_docs = retriever.get_relevant_documents(msg.question)
    context_text = "\n\n---\n\n".join([doc.page_content for doc in compressed_docs])
    prompt_template = ChatPromptTemplate.from_template(PROMPT_TEMPLATE)
    prompt = prompt_template.format(context=context_text, question=msg.question)
    model = ChatOpenAI(model="gpt-4o-mini")
    response = model.predict(prompt)
    ctx.logger.info(f"Response: {response}")
    await ctx.send(
        sender, UAgentResponse(message=response, type=UAgentResponseType.FINAL)
    )
# Attach the DocsBot protocol to the agent and publish its manifest so the
# handler is discoverable on Agentverse.
agent.include(docs_bot_protocol, publish_manifest=True)

if __name__ == "__main__":
    agent.run()