1- import json
2- from django .shortcuts import render , redirect , get_object_or_404
3- from django .views .decorators .csrf import csrf_exempt
4- from django .http import JsonResponse
1+ from django .conf import settings
2+ from django .db import connection as django_connection
53
6- from rest_framework import viewsets , status
7- from rest_framework .decorators import action
4+ from rest_framework import status
85from rest_framework .response import Response
96from rest_framework .views import APIView
107from PyPDF2 import PdfReader
118from langchain_text_splitters import RecursiveCharacterTextSplitter
12-
13- from .models import DatabaseConnection , SchemaDocument , QueryHistory
14- from .serializers import (
15- DatabaseConnectionSerializer ,
16- SchemaDocumentSerializer ,
17- QueryHistorySerializer ,
18- QueryRequestSerializer ,
19- SchemaSearchSerializer
20- )
21- from .services import DatabaseService , SQLGenerationService , EmbeddingService
22-
23-
24- class DatabaseConnectionViewSet (viewsets .ModelViewSet ):
25- queryset = DatabaseConnection .objects .all ()
26- serializer_class = DatabaseConnectionSerializer
27-
28- @action (detail = True , methods = ['post' ])
29- def test (self , request , pk = None ):
30- connection = self .get_object ()
31- connection_data = {
32- 'host' : connection .host ,
33- 'port' : connection .port ,
34- 'database' : connection .database ,
35- 'username' : connection .username ,
36- 'password' : connection .password
37- }
38- result = DatabaseService .test_connection (connection_data )
39- return Response (result )
40-
41- @action (detail = True , methods = ['get' ])
42- def schema (self , request , pk = None ):
43- connection = self .get_object ()
44- connection_data = {
45- 'host' : connection .host ,
46- 'port' : connection .port ,
47- 'database' : connection .database ,
48- 'username' : connection .username ,
49- 'password' : connection .password
50- }
51- result = DatabaseService .get_schema (connection_data )
52- return Response (result )
53-
54-
55- class SchemaDocumentViewSet (viewsets .ModelViewSet ):
56- queryset = SchemaDocument .objects .all ()
57- serializer_class = SchemaDocumentSerializer
58-
59- def perform_create (self , serializer ):
60- instance = serializer .save ()
61- try :
62- text_splitter = RecursiveCharacterTextSplitter (
63- chunk_size = 1000 ,
64- chunk_overlap = 200
65- )
66- chunks = text_splitter .split_text (instance .content )
67- embeddings = EmbeddingService .embed_documents (chunks )
68- instance .embeddings = json .dumps ({'chunks' : chunks , 'embeddings' : embeddings })
69- instance .save ()
70- except Exception :
71- pass
9+ from pgvector .psycopg2 import register_vector
7210
7311
74- class QueryHistoryViewSet ( viewsets . ReadOnlyModelViewSet ):
75- queryset = QueryHistory . objects . all ()
76- serializer_class = QueryHistorySerializer
12+ from . models import QueryHistory
13+ from . serializers import QueryRequestSerializer , SchemaSearchSerializer
14+ from . services import DatabaseService , SQLGenerationService , EmbeddingService
7715
7816
7917class GenerateSQLView (APIView ):
@@ -83,20 +21,13 @@ def post(self, request):
8321 connection_id = serializer .validated_data ['connection_id' ]
8422 natural_language = serializer .validated_data ['natural_language' ]
8523
86- try :
87- connection = DatabaseConnection .objects .get (id = connection_id )
88- except DatabaseConnection .DoesNotExist :
89- return Response (
90- {'error' : 'Database connection not found' },
91- status = status .HTTP_404_NOT_FOUND
92- )
9324
9425 connection_data = {
95- 'host' : connection . host ,
96- 'port' : connection . port ,
97- 'database' : connection . database ,
98- 'username' : connection . username ,
99- 'password' : connection . password
26+ 'host' : settings . DATABASE_HOST ,
27+ 'port' : settings . DATABASE_PORT ,
28+ 'database' : settings . DATABASE_NAME ,
29+ 'username' : settings . DATABASE_USER ,
30+ 'password' : settings . DATABASE_PASSWORD
10031 }
10132
10233 schema_result = DatabaseService .get_schema (connection_data )
@@ -135,20 +66,12 @@ def post(self, request):
13566 status = status .HTTP_400_BAD_REQUEST
13667 )
13768
138- try :
139- connection = DatabaseConnection .objects .get (id = connection_id )
140- except DatabaseConnection .DoesNotExist :
141- return Response (
142- {'error' : 'Database connection not found' },
143- status = status .HTTP_404_NOT_FOUND
144- )
145-
14669 connection_data = {
147- 'host' : connection . host ,
148- 'port' : connection . port ,
149- 'database' : connection . database ,
150- 'username' : connection . username ,
151- 'password' : connection . password
70+ 'host' : settings . DATABASE_HOST ,
71+ 'port' : settings . DATABASE_PORT ,
72+ 'database' : settings . DATABASE_NAME ,
73+ 'username' : settings . DATABASE_USER ,
74+ 'password' : settings . DATABASE_PASSWORD
15275 }
15376
15477 result = DatabaseService .execute_query (connection_data , sql )
@@ -176,20 +99,12 @@ def post(self, request):
17699 status = status .HTTP_400_BAD_REQUEST
177100 )
178101
179- try :
180- connection = DatabaseConnection .objects .get (id = connection_id )
181- except DatabaseConnection .DoesNotExist :
182- return Response (
183- {'error' : 'Database connection not found' },
184- status = status .HTTP_404_NOT_FOUND
185- )
186-
187102 connection_data = {
188- 'host' : connection . host ,
189- 'port' : connection . port ,
190- 'database' : connection . database ,
191- 'username' : connection . username ,
192- 'password' : connection . password
103+ 'host' : settings . DATABASE_HOST ,
104+ 'port' : settings . DATABASE_PORT ,
105+ 'database' : settings . DATABASE_NAME ,
106+ 'username' : settings . DATABASE_USER ,
107+ 'password' : settings . DATABASE_PASSWORD
193108 }
194109
195110 schema_result = DatabaseService .get_schema (connection_data )
@@ -224,97 +139,3 @@ def post(self, request):
224139 'sql' : sql ,
225140 'result' : result
226141 })
227-
228-
229- class SchemaSearchView (APIView ):
230- def post (self , request ):
231- serializer = SchemaSearchSerializer (data = request .data )
232- if serializer .is_valid ():
233- query = serializer .validated_data ['query' ]
234- connection_id = serializer .validated_data .get ('connection_id' )
235-
236- query_embedding = EmbeddingService .embed_text (query )
237-
238- if connection_id :
239- docs = SchemaDocument .objects .filter (connection_id = connection_id )
240- else :
241- docs = SchemaDocument .objects .all ()
242-
243- results = []
244- for doc in docs :
245- if doc .embeddings :
246- doc_data = json .loads (doc .embeddings )
247- if 'embeddings' in doc_data :
248- doc_embeddings = doc_data ['embeddings' ]
249- chunks = doc_data .get ('chunks' , [])
250-
251- similarities = []
252- for i , emb in enumerate (doc_embeddings ):
253- similarity = sum ([a * b for a , b in zip (query_embedding , emb )])
254- similarities .append ((i , similarity ))
255-
256- similarities .sort (key = lambda x : x [1 ], reverse = True )
257-
258- for idx , sim in similarities [:3 ]:
259- if sim > 0.3 :
260- results .append ({
261- 'document' : doc .name ,
262- 'content' : chunks [idx ] if idx < len (chunks ) else '' ,
263- 'similarity' : sim
264- })
265-
266- return Response ({'results' : results })
267-
268- return Response (serializer .errors , status = status .HTTP_400_BAD_REQUEST )
269-
270-
271- class UploadDocumentView (APIView ):
272- def post (self , request ):
273- connection_id = request .data .get ('connection_id' )
274- file = request .FILES .get ('file' )
275-
276- if not connection_id or not file :
277- return Response (
278- {'error' : 'connection_id and file are required' },
279- status = status .HTTP_400_BAD_REQUEST
280- )
281-
282- try :
283- connection = DatabaseConnection .objects .get (id = connection_id )
284- except DatabaseConnection .DoesNotExist :
285- return Response (
286- {'error' : 'Database connection not found' },
287- status = status .HTTP_404_NOT_FOUND
288- )
289-
290- if file .name .endswith ('.pdf' ):
291- pdf_reader = PdfReader (file )
292- content = ''
293- for page in pdf_reader .pages :
294- content += page .extract_text ()
295- else :
296- content = file .read ().decode ('utf-8' , errors = 'ignore' )
297-
298- doc = SchemaDocument .objects .create (
299- connection = connection ,
300- name = file .name ,
301- content = content
302- )
303-
304- try :
305- text_splitter = RecursiveCharacterTextSplitter (
306- chunk_size = 1000 ,
307- chunk_overlap = 200
308- )
309- chunks = text_splitter .split_text (content )
310- embeddings = EmbeddingService .embed_documents (chunks )
311- doc .embeddings = json .dumps ({'chunks' : chunks , 'embeddings' : embeddings })
312- doc .save ()
313- except Exception :
314- pass
315-
316- return Response ({
317- 'id' : str (doc .id ),
318- 'name' : doc .name ,
319- 'message' : 'Document uploaded successfully'
320- }, status = status .HTTP_201_CREATED )
0 commit comments