-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest.py
More file actions
161 lines (122 loc) · 5.33 KB
/
test.py
File metadata and controls
161 lines (122 loc) · 5.33 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
import streamlit as st
import mysql.connector
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser
from langchain_community.utilities.sql_database import SQLDatabase
from langchain_groq import ChatGroq
from dotenv import load_dotenv
load_dotenv()
def connect_database(hostname: str, port: str, username: str, password: str, database: str) -> SQLDatabase:
    """Open a LangChain ``SQLDatabase`` on a MySQL server via mysql-connector.

    Builds a SQLAlchemy URI of the form
    ``mysql+mysqlconnector://<user>:<password>@<host>:<port>/<database>`` and
    passes it to ``SQLDatabase.from_uri``. Any exception raised while creating
    the database object propagates to the caller.

    Args:
        hostname: Server host name or IP address.
        port: Server port as entered in the UI (e.g. ``"3306"``).
        username: MySQL user name; may contain URI-reserved characters.
        password: MySQL password; may contain URI-reserved characters.
        database: Name of the database (schema) to connect to.

    Returns:
        The ``SQLDatabase`` returned by ``SQLDatabase.from_uri``.
    """
    # Function-scope import keeps this fix self-contained.
    from urllib.parse import quote

    # Percent-encode the credentials. Characters such as "@", ":", "/", "#" or
    # a space in a password would otherwise be read as URI delimiters and
    # point the engine at the wrong user/host. safe="" encodes "/" as well,
    # and "%20" (rather than "+") round-trips spaces correctly.
    user_part = quote(username, safe="")
    password_part = quote(password, safe="")
    # uniform resource identifier
    db_uri = f"mysql+mysqlconnector://{user_part}:{password_part}@{hostname}:{port}/{database}"
    return SQLDatabase.from_uri(db_uri)
def get_sql_chain(db):
    """Build a runnable that writes a SQL query for a question.

    The returned chain expects a mapping with ``question`` and
    ``conversation_history`` keys. It:
    - adds ``schema`` (from ``db.get_table_info()``);
    - fills the prompt below;
    - calls the Groq chat model;
    - returns the reply as a string cleaned down to the SQL text.

    Callers such as ``get_response`` execute that string verbatim with
    ``db.run``.

    Args:
        db: A LangChain ``SQLDatabase``; only ``get_table_info()`` is used here.

    Returns:
        A LangChain runnable producing a SQL query string.
    """
    import re

    prompt_template = """
You are a data analysis chatbot.
Based on the table schema provided below, write a SQL query that answers the question.
Consider the conversation history.
```<SCHEMA> {schema} </SCHEMA>```
Conversation History: {conversation_history}
Write only the SQL query without any additional text.
For example:
Question: Who are the top 3 artists with the most tracks?
SQL Query: SELECT ArtistId, COUNT(*) as track_count FROM Track GROUP BY ArtistId ORDER BY track_count DESC LIMIT 3;
Response Format:
Question: {question}
SQL Query:
"""
    prompt = ChatPromptTemplate.from_template(template=prompt_template)
    # NOTE(review): confirm this model id is still served by Groq.
    llm = ChatGroq(model="Mixtral-8x7b-32768", temperature=0)

    # A Markdown code fence (optionally tagged sql/mysql) anywhere in the
    # reply; group 1 is the fenced body.
    fenced_sql = re.compile(r"```(?:sql|mysql)?\s*(.*?)\s*```", re.DOTALL | re.IGNORECASE)
    # The prompt ends with "SQL Query:", which models sometimes echo back.
    echoed_label = re.compile(r"^\s*SQL Query:\s*", re.IGNORECASE)

    def get_schema(_):
        """Return the table/schema description of ``db`` (chain input is ignored)."""
        return db.get_table_info()

    def clean_sql(text: str) -> str:
        """Reduce the model's reply to bare SQL.

        Despite the "only the SQL query" instruction, chat models often wrap
        the query in code fences or prefix it with a label. Either would make
        the string unexecutable when passed straight to ``db.run``.
        """
        fenced = fenced_sql.search(text)
        sql = fenced.group(1) if fenced else text
        return echoed_label.sub("", sql).strip()

    # A plain function piped after a runnable is coerced to a RunnableLambda.
    return (
        RunnablePassthrough.assign(schema=get_schema)
        | prompt
        | llm
        | StrOutputParser()
        | clean_sql
    )
# Function to convert SQL Query into Natural Language
def get_response(user_query: str, db: SQLDatabase, conversation_history: list):
    """Answer ``user_query`` in natural language by generating and running SQL.

    The work happens in three steps:
    1. The chain from ``get_sql_chain`` writes a SQL query.
    2. That query is executed with ``db.run``.
    3. A second LLM call turns the schema, history, question, query and raw
       result into a prose answer.

    Args:
        user_query: The user's question.
        db: Database the generated SQL is run against.
        conversation_history: Prior chat messages; interpolated into both prompts.

    Returns:
        The model's final reply as a string.
    """
    query_writer = get_sql_chain(db)

    answer_template = """
You are a data analysis chatbot.
Given the database schema details, question, SQL query, and SQL response,
write a natural language response for the SQL query.
<SCHEMA> {schema} </SCHEMA>
Conversation History: {conversation_history}
SQL Query: <SQL> {sql_query} </SQL>
Question: {question}
SQL Response: {response}
Response Format:
SQL Query:
Natural Language Response:
"""
    answer_prompt = ChatPromptTemplate.from_template(template=answer_template)
    answer_llm = ChatGroq(model="Mixtral-8x7b-32768", temperature=0)

    def run_generated_sql(inputs):
        # ``inputs["sql_query"]`` was filled in by the query-writing step.
        return db.run(inputs["sql_query"])

    # Step 1: add the generated SQL to the inputs; then, side by side, attach
    # the schema text and the result of executing that SQL.
    with_query = RunnablePassthrough.assign(sql_query=query_writer)
    with_context = with_query.assign(
        schema=lambda _: db.get_table_info(),
        response=run_generated_sql,
    )
    answer_pipeline = with_context | answer_prompt | answer_llm | StrOutputParser()

    chain_input = {
        "question": user_query,
        "conversation_history": conversation_history,
    }
    return answer_pipeline.invoke(chain_input)
# Page config. Issued before any other Streamlit call so it is accepted on
# every Streamlit version (older releases require it to come first).
st.set_page_config(page_title="AI-SQL ChatBot", page_icon=":speech_balloon:")
st.title("AI - SQL Chat")

# Initialize conversation_history once per browser session with a greeting.
if "conversation_history" not in st.session_state:
    st.session_state.conversation_history = [
        AIMessage(content="Hello! I'm an AI-SQL assistant. Ask me questions about your MySQL database.")
    ]

# Sidebar: connection settings. Each widget's value lands in
# st.session_state under its ``key``.
with st.sidebar:
    st.subheader("Settings")
    st.write("Connect your MYSQL database and chat with it!")
    st.text_input("Hostname", value="localhost", key="Host")
    st.text_input("Port", value="3306", key="Port")
    st.text_input("Username", value="root", key="Username")
    st.text_input("Password", type="password", key="Password")
    st.text_input("Database", key="Database")
    if st.button("Connect"):
        with st.spinner("Connecting to database..."):
            try:
                db = connect_database(
                    st.session_state["Host"],
                    st.session_state["Port"],
                    st.session_state["Username"],
                    st.session_state["Password"],
                    st.session_state["Database"],
                )
            # UI boundary: report any failure instead of crashing the page.
            # SQLDatabase.from_uri goes through SQLAlchemy, which wraps driver
            # errors in its own exception types (not mysql.connector.Error).
            # Bad input such as a non-numeric port can raise other errors too.
            except Exception as err:
                st.error(f"Error connecting to database: {err}")
            else:
                st.session_state.db = db
                st.success("Connected to Database!")

# Interactive chat interface: replay the stored conversation.
for message in st.session_state.conversation_history:
    if isinstance(message, AIMessage):
        with st.chat_message("AI"):
            st.markdown(message.content)
    elif isinstance(message, HumanMessage):
        with st.chat_message("Human"):
            st.markdown(message.content)

# User Query
user_query = st.chat_input("Question your database...")
if user_query:
    if "db" not in st.session_state:
        # Without this guard, reading st.session_state.db before "Connect"
        # succeeded raises AttributeError and Streamlit shows a traceback.
        st.warning("Please connect to a database from the sidebar before asking questions.")
    else:
        st.session_state.conversation_history.append(HumanMessage(content=user_query))
        with st.chat_message("Human"):
            st.markdown(user_query)
        with st.chat_message("AI"):
            try:
                response = get_response(user_query, st.session_state.db, st.session_state.conversation_history)
            # The LLM call can fail, and the generated SQL may not execute.
            # Surface the problem in the chat instead of aborting the script run.
            except Exception as err:
                st.error(f"Sorry, I couldn't answer that question: {err}")
            else:
                st.markdown(response)
                st.session_state.conversation_history.append(AIMessage(content=response))