Skip to content

Commit 4b26424

Browse files
Update the views
1 parent 759808c commit 4b26424

2 files changed

Lines changed: 35 additions & 353 deletions

File tree

core/views.py

Lines changed: 22 additions & 201 deletions
Original file line numberDiff line numberDiff line change
@@ -1,79 +1,17 @@
1-
import json
2-
from django.shortcuts import render, redirect, get_object_or_404
3-
from django.views.decorators.csrf import csrf_exempt
4-
from django.http import JsonResponse
1+
from django.conf import settings
2+
from django.db import connection as django_connection
53

6-
from rest_framework import viewsets, status
7-
from rest_framework.decorators import action
4+
from rest_framework import status
85
from rest_framework.response import Response
96
from rest_framework.views import APIView
107
from PyPDF2 import PdfReader
118
from langchain_text_splitters import RecursiveCharacterTextSplitter
12-
13-
from .models import DatabaseConnection, SchemaDocument, QueryHistory
14-
from .serializers import (
15-
DatabaseConnectionSerializer,
16-
SchemaDocumentSerializer,
17-
QueryHistorySerializer,
18-
QueryRequestSerializer,
19-
SchemaSearchSerializer
20-
)
21-
from .services import DatabaseService, SQLGenerationService, EmbeddingService
22-
23-
24-
class DatabaseConnectionViewSet(viewsets.ModelViewSet):
25-
queryset = DatabaseConnection.objects.all()
26-
serializer_class = DatabaseConnectionSerializer
27-
28-
@action(detail=True, methods=['post'])
29-
def test(self, request, pk=None):
30-
connection = self.get_object()
31-
connection_data = {
32-
'host': connection.host,
33-
'port': connection.port,
34-
'database': connection.database,
35-
'username': connection.username,
36-
'password': connection.password
37-
}
38-
result = DatabaseService.test_connection(connection_data)
39-
return Response(result)
40-
41-
@action(detail=True, methods=['get'])
42-
def schema(self, request, pk=None):
43-
connection = self.get_object()
44-
connection_data = {
45-
'host': connection.host,
46-
'port': connection.port,
47-
'database': connection.database,
48-
'username': connection.username,
49-
'password': connection.password
50-
}
51-
result = DatabaseService.get_schema(connection_data)
52-
return Response(result)
53-
54-
55-
class SchemaDocumentViewSet(viewsets.ModelViewSet):
56-
queryset = SchemaDocument.objects.all()
57-
serializer_class = SchemaDocumentSerializer
58-
59-
def perform_create(self, serializer):
60-
instance = serializer.save()
61-
try:
62-
text_splitter = RecursiveCharacterTextSplitter(
63-
chunk_size=1000,
64-
chunk_overlap=200
65-
)
66-
chunks = text_splitter.split_text(instance.content)
67-
embeddings = EmbeddingService.embed_documents(chunks)
68-
instance.embeddings = json.dumps({'chunks': chunks, 'embeddings': embeddings})
69-
instance.save()
70-
except Exception:
71-
pass
9+
from pgvector.psycopg2 import register_vector
7210

7311

74-
class QueryHistoryViewSet(viewsets.ReadOnlyModelViewSet):
75-
queryset = QueryHistory.objects.all()
76-
serializer_class = QueryHistorySerializer
12+
from .models import QueryHistory
13+
from .serializers import QueryRequestSerializer, SchemaSearchSerializer
14+
from .services import DatabaseService, SQLGenerationService, EmbeddingService
7715

7816

7917
class GenerateSQLView(APIView):
@@ -83,20 +21,13 @@ def post(self, request):
8321
connection_id = serializer.validated_data['connection_id']
8422
natural_language = serializer.validated_data['natural_language']
8523

86-
try:
87-
connection = DatabaseConnection.objects.get(id=connection_id)
88-
except DatabaseConnection.DoesNotExist:
89-
return Response(
90-
{'error': 'Database connection not found'},
91-
status=status.HTTP_404_NOT_FOUND
92-
)
9324

9425
connection_data = {
95-
'host': connection.host,
96-
'port': connection.port,
97-
'database': connection.database,
98-
'username': connection.username,
99-
'password': connection.password
26+
'host': settings.DATABASE_HOST,
27+
'port': settings.DATABASE_PORT,
28+
'database': settings.DATABASE_NAME,
29+
'username': settings.DATABASE_USER,
30+
'password': settings.DATABASE_PASSWORD
10031
}
10132

10233
schema_result = DatabaseService.get_schema(connection_data)
@@ -135,20 +66,12 @@ def post(self, request):
13566
status=status.HTTP_400_BAD_REQUEST
13667
)
13768

138-
try:
139-
connection = DatabaseConnection.objects.get(id=connection_id)
140-
except DatabaseConnection.DoesNotExist:
141-
return Response(
142-
{'error': 'Database connection not found'},
143-
status=status.HTTP_404_NOT_FOUND
144-
)
145-
14669
connection_data = {
147-
'host': connection.host,
148-
'port': connection.port,
149-
'database': connection.database,
150-
'username': connection.username,
151-
'password': connection.password
70+
'host': settings.DATABASE_HOST,
71+
'port': settings.DATABASE_PORT,
72+
'database': settings.DATABASE_NAME,
73+
'username': settings.DATABASE_USER,
74+
'password': settings.DATABASE_PASSWORD
15275
}
15376

15477
result = DatabaseService.execute_query(connection_data, sql)
@@ -176,20 +99,12 @@ def post(self, request):
17699
status=status.HTTP_400_BAD_REQUEST
177100
)
178101

179-
try:
180-
connection = DatabaseConnection.objects.get(id=connection_id)
181-
except DatabaseConnection.DoesNotExist:
182-
return Response(
183-
{'error': 'Database connection not found'},
184-
status=status.HTTP_404_NOT_FOUND
185-
)
186-
187102
connection_data = {
188-
'host': connection.host,
189-
'port': connection.port,
190-
'database': connection.database,
191-
'username': connection.username,
192-
'password': connection.password
103+
'host': settings.DATABASE_HOST,
104+
'port': settings.DATABASE_PORT,
105+
'database': settings.DATABASE_NAME,
106+
'username': settings.DATABASE_USER,
107+
'password': settings.DATABASE_PASSWORD
193108
}
194109

195110
schema_result = DatabaseService.get_schema(connection_data)
@@ -224,97 +139,3 @@ def post(self, request):
224139
'sql': sql,
225140
'result': result
226141
})
227-
228-
229-
class SchemaSearchView(APIView):
230-
def post(self, request):
231-
serializer = SchemaSearchSerializer(data=request.data)
232-
if serializer.is_valid():
233-
query = serializer.validated_data['query']
234-
connection_id = serializer.validated_data.get('connection_id')
235-
236-
query_embedding = EmbeddingService.embed_text(query)
237-
238-
if connection_id:
239-
docs = SchemaDocument.objects.filter(connection_id=connection_id)
240-
else:
241-
docs = SchemaDocument.objects.all()
242-
243-
results = []
244-
for doc in docs:
245-
if doc.embeddings:
246-
doc_data = json.loads(doc.embeddings)
247-
if 'embeddings' in doc_data:
248-
doc_embeddings = doc_data['embeddings']
249-
chunks = doc_data.get('chunks', [])
250-
251-
similarities = []
252-
for i, emb in enumerate(doc_embeddings):
253-
similarity = sum([a * b for a, b in zip(query_embedding, emb)])
254-
similarities.append((i, similarity))
255-
256-
similarities.sort(key=lambda x: x[1], reverse=True)
257-
258-
for idx, sim in similarities[:3]:
259-
if sim > 0.3:
260-
results.append({
261-
'document': doc.name,
262-
'content': chunks[idx] if idx < len(chunks) else '',
263-
'similarity': sim
264-
})
265-
266-
return Response({'results': results})
267-
268-
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
269-
270-
271-
class UploadDocumentView(APIView):
272-
def post(self, request):
273-
connection_id = request.data.get('connection_id')
274-
file = request.FILES.get('file')
275-
276-
if not connection_id or not file:
277-
return Response(
278-
{'error': 'connection_id and file are required'},
279-
status=status.HTTP_400_BAD_REQUEST
280-
)
281-
282-
try:
283-
connection = DatabaseConnection.objects.get(id=connection_id)
284-
except DatabaseConnection.DoesNotExist:
285-
return Response(
286-
{'error': 'Database connection not found'},
287-
status=status.HTTP_404_NOT_FOUND
288-
)
289-
290-
if file.name.endswith('.pdf'):
291-
pdf_reader = PdfReader(file)
292-
content = ''
293-
for page in pdf_reader.pages:
294-
content += page.extract_text()
295-
else:
296-
content = file.read().decode('utf-8', errors='ignore')
297-
298-
doc = SchemaDocument.objects.create(
299-
connection=connection,
300-
name=file.name,
301-
content=content
302-
)
303-
304-
try:
305-
text_splitter = RecursiveCharacterTextSplitter(
306-
chunk_size=1000,
307-
chunk_overlap=200
308-
)
309-
chunks = text_splitter.split_text(content)
310-
embeddings = EmbeddingService.embed_documents(chunks)
311-
doc.embeddings = json.dumps({'chunks': chunks, 'embeddings': embeddings})
312-
doc.save()
313-
except Exception:
314-
pass
315-
316-
return Response({
317-
'id': str(doc.id),
318-
'name': doc.name,
319-
'message': 'Document uploaded successfully'
320-
}, status=status.HTTP_201_CREATED)

0 commit comments

Comments
 (0)