-
-
Notifications
You must be signed in to change notification settings - Fork 16
[#441] [IMPROVE] Preload embedding model at startup #461
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 7 commits
59b40f0
3ffb74a
12b09a7
da9afaa
5ce7782
50a8bd3
795f218
d498a00
6d3d8d1
3824d81
46e9969
64a19ef
dec3c12
f9e890a
67176a8
d273921
8198574
5d8c8b3
31498dc
a39d33c
fe1eeca
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -4,3 +4,7 @@ | |
| class ApiConfig(AppConfig): | ||
| default_auto_field = 'django.db.models.BigAutoField' | ||
| name = 'api' | ||
|
|
||
| def ready(self): | ||
| from .services.sentencetTransformer_model import TransformerModel | ||
| TransformerModel.get_instance() | ||
|
||
| Original file line number | Diff line number | Diff line change | ||||||||
|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -11,18 +11,17 @@ | |||||||||
|
|
||||||||||
| logger = logging.getLogger(__name__) | ||||||||||
|
|
||||||||||
| def get_closest_embeddings( | ||||||||||
| user, message_data, document_name=None, guid=None, num_results=10 | ||||||||||
| ): | ||||||||||
|
|
||||||||||
| def build_query(user, embedding_vector, document_name=None, guid=None, num_results=10): | ||||||||||
| """ | ||||||||||
| Find the closest embeddings to a given message for a specific user. | ||||||||||
| Build an unevaluated QuerySet for the closest embeddings. | ||||||||||
|
|
||||||||||
| Parameters | ||||||||||
| ---------- | ||||||||||
| user : User | ||||||||||
| The user whose uploaded documents will be searched | ||||||||||
| message_data : str | ||||||||||
| The input message to find similar embeddings for | ||||||||||
| embedding_vector : array-like | ||||||||||
| Pre-computed embedding vector to compare against | ||||||||||
| document_name : str, optional | ||||||||||
| Filter results to a specific document name | ||||||||||
| guid : str, optional | ||||||||||
|
|
@@ -32,59 +31,52 @@ def get_closest_embeddings( | |||||||||
|
|
||||||||||
| Returns | ||||||||||
| ------- | ||||||||||
| list[dict] | ||||||||||
| List of dictionaries containing embedding results with keys: | ||||||||||
| - name: document name | ||||||||||
| - text: embedded text content | ||||||||||
| - page_number: page number in source document | ||||||||||
| - chunk_number: chunk number within the document | ||||||||||
| - distance: L2 distance from query embedding | ||||||||||
| - file_id: GUID of the source file | ||||||||||
| QuerySet | ||||||||||
| Unevaluated Django QuerySet ordered by L2 distance, sliced to num_results | ||||||||||
| """ | ||||||||||
|
|
||||||||||
| encoding_start = time.time() | ||||||||||
| transformerModel = TransformerModel.get_instance().model | ||||||||||
| embedding_message = transformerModel.encode(message_data) | ||||||||||
| encoding_time = time.time() - encoding_start | ||||||||||
|
|
||||||||||
| db_query_start = time.time() | ||||||||||
|
|
||||||||||
| # Django QuerySets are lazily evaluated | ||||||||||
| if user.is_authenticated: | ||||||||||
| # User sees their own files + files uploaded by superusers | ||||||||||
| closest_embeddings_query = ( | ||||||||||
| Embeddings.objects.filter( | ||||||||||
| Q(upload_file__uploaded_by=user) | Q(upload_file__uploaded_by__is_superuser=True) | ||||||||||
| ) | ||||||||||
| .annotate( | ||||||||||
| distance=L2Distance("embedding_sentence_transformers", embedding_message) | ||||||||||
| ) | ||||||||||
| .order_by("distance") | ||||||||||
| queryset = Embeddings.objects.filter( | ||||||||||
| Q(upload_file__uploaded_by=user) | Q(upload_file__uploaded_by__is_superuser=True) | ||||||||||
| ) | ||||||||||
| else: | ||||||||||
| # Unauthenticated users only see superuser-uploaded files | ||||||||||
| closest_embeddings_query = ( | ||||||||||
| Embeddings.objects.filter(upload_file__uploaded_by__is_superuser=True) | ||||||||||
| .annotate( | ||||||||||
| distance=L2Distance("embedding_sentence_transformers", embedding_message) | ||||||||||
| ) | ||||||||||
| .order_by("distance") | ||||||||||
| ) | ||||||||||
| queryset = Embeddings.objects.filter(upload_file__uploaded_by__is_superuser=True) | ||||||||||
|
|
||||||||||
| queryset = ( | ||||||||||
| queryset | ||||||||||
| .annotate(distance=L2Distance("embedding_sentence_transformers", embedding_vector)) | ||||||||||
| .order_by("distance") | ||||||||||
| ) | ||||||||||
|
|
||||||||||
| # Filtering to a document GUID takes precedence over a document name | ||||||||||
| if guid: | ||||||||||
| closest_embeddings_query = closest_embeddings_query.filter( | ||||||||||
| upload_file__guid=guid | ||||||||||
| ) | ||||||||||
| queryset = queryset.filter(upload_file__guid=guid) | ||||||||||
| elif document_name: | ||||||||||
| closest_embeddings_query = closest_embeddings_query.filter(name=document_name) | ||||||||||
| queryset = queryset.filter(name=document_name) | ||||||||||
|
|
||||||||||
| # Slicing is equivalent to SQL's LIMIT clause | ||||||||||
| closest_embeddings_query = closest_embeddings_query[:num_results] | ||||||||||
| return queryset[:num_results] | ||||||||||
|
Comment on lines
16
to
+61
|
||||||||||
|
|
||||||||||
|
|
||||||||||
| def evaluate_query(queryset): | ||||||||||
| """ | ||||||||||
| Evaluate a QuerySet and return a list of result dicts. | ||||||||||
|
|
||||||||||
| Parameters | ||||||||||
| ---------- | ||||||||||
| queryset : iterable | ||||||||||
| Iterable of Embeddings objects (or any objects with the expected attributes) | ||||||||||
|
|
||||||||||
| Returns | ||||||||||
| ------- | ||||||||||
| list[dict] | ||||||||||
| List of dicts with keys: name, text, page_number, chunk_number, distance, file_id | ||||||||||
| """ | ||||||||||
| # Iterating evaluates the QuerySet and hits the database | ||||||||||
| # TODO: Research improving the query evaluation performance | ||||||||||
| results = [ | ||||||||||
| return [ | ||||||||||
| { | ||||||||||
| "name": obj.name, | ||||||||||
| "text": obj.text, | ||||||||||
|
|
@@ -93,13 +85,36 @@ def get_closest_embeddings( | |||||||||
| "distance": obj.distance, | ||||||||||
| "file_id": obj.upload_file.guid if obj.upload_file else None, | ||||||||||
| } | ||||||||||
| for obj in closest_embeddings_query | ||||||||||
| for obj in queryset | ||||||||||
| ] | ||||||||||
|
|
||||||||||
| db_query_time = time.time() - db_query_start | ||||||||||
|
|
||||||||||
| def log_usage( | ||||||||||
| results, message_data, user, guid, document_name, num_results, encoding_time, db_query_time | ||||||||||
| ): | ||||||||||
| """ | ||||||||||
| Create a SemanticSearchUsage record. Swallows exceptions so search isn't interrupted. | ||||||||||
|
|
||||||||||
| Parameters | ||||||||||
| ---------- | ||||||||||
| results : list[dict] | ||||||||||
| The search results, each containing a "distance" key | ||||||||||
| message_data : str | ||||||||||
| The original search query text | ||||||||||
| user : User | ||||||||||
| The user who performed the search | ||||||||||
| guid : str or None | ||||||||||
| Document GUID filter used in the search | ||||||||||
| document_name : str or None | ||||||||||
| Document name filter used in the search | ||||||||||
| num_results : int | ||||||||||
| Number of results requested | ||||||||||
| encoding_time : float | ||||||||||
| Time in seconds to encode the query | ||||||||||
| db_query_time : float | ||||||||||
| Time in seconds for the database query | ||||||||||
| """ | ||||||||||
| try: | ||||||||||
| # Handle user having no uploaded docs or doc filtering returning no matches | ||||||||||
| if results: | ||||||||||
| distances = [r["distance"] for r in results] | ||||||||||
| SemanticSearchUsage.objects.create( | ||||||||||
|
|
@@ -113,11 +128,10 @@ def get_closest_embeddings( | |||||||||
| num_results_returned=len(results), | ||||||||||
| max_distance=max(distances), | ||||||||||
| median_distance=median(distances), | ||||||||||
| min_distance=min(distances) | ||||||||||
| min_distance=min(distances), | ||||||||||
| ) | ||||||||||
| else: | ||||||||||
| logger.warning("Semantic search returned no results") | ||||||||||
|
|
||||||||||
| SemanticSearchUsage.objects.create( | ||||||||||
| query_text=message_data, | ||||||||||
| user=user if (user and user.is_authenticated) else None, | ||||||||||
|
|
@@ -129,9 +143,58 @@ def get_closest_embeddings( | |||||||||
| num_results_returned=0, | ||||||||||
| max_distance=None, | ||||||||||
| median_distance=None, | ||||||||||
| min_distance=None | ||||||||||
| min_distance=None, | ||||||||||
| ) | ||||||||||
| except Exception as e: | ||||||||||
| logger.error(f"Failed to create semantic search usage database record: {e}") | ||||||||||
|
||||||||||
| except Exception as e: | |
| logger.error(f"Failed to create semantic search usage database record: {e}") | |
| except Exception: | |
| logger.exception("Failed to create semantic search usage database record") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Updated the logging to logger.exception
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,85 @@ | ||
| from unittest.mock import MagicMock, patch | ||
|
|
||
| from api.services.embedding_services import evaluate_query, log_usage | ||
|
|
||
|
|
||
| def test_evaluate_query_maps_fields(): | ||
| obj = MagicMock() | ||
| obj.name = "doc.pdf" | ||
| obj.text = "some text" | ||
| obj.page_num = 3 | ||
| obj.chunk_number = 1 | ||
| obj.distance = 0.42 | ||
| obj.upload_file.guid = "abc-123" | ||
|
|
||
| results = evaluate_query([obj]) | ||
|
|
||
| assert results == [ | ||
| { | ||
| "name": "doc.pdf", | ||
| "text": "some text", | ||
| "page_number": 3, | ||
| "chunk_number": 1, | ||
| "distance": 0.42, | ||
| "file_id": "abc-123", | ||
| } | ||
| ] | ||
|
|
||
|
|
||
| def test_evaluate_query_none_upload_file(): | ||
| obj = MagicMock() | ||
| obj.name = "doc.pdf" | ||
| obj.text = "some text" | ||
| obj.page_num = 1 | ||
| obj.chunk_number = 0 | ||
| obj.distance = 1.0 | ||
| obj.upload_file = None | ||
|
|
||
| results = evaluate_query([obj]) | ||
|
|
||
| assert results[0]["file_id"] is None | ||
|
|
||
|
|
||
| @patch("api.services.embedding_services.SemanticSearchUsage.objects.create") | ||
| def test_log_usage_computes_distance_stats(mock_create): | ||
| results = [{"distance": 1.0}, {"distance": 3.0}, {"distance": 2.0}] | ||
| user = MagicMock(is_authenticated=True) | ||
|
|
||
| log_usage( | ||
| results, | ||
| message_data="test query", | ||
| user=user, | ||
| guid=None, | ||
| document_name=None, | ||
| num_results=10, | ||
| encoding_time=0.1, | ||
| db_query_time=0.2, | ||
| ) | ||
|
|
||
| mock_create.assert_called_once() | ||
| kwargs = mock_create.call_args.kwargs | ||
| assert kwargs["min_distance"] == 1.0 | ||
| assert kwargs["max_distance"] == 3.0 | ||
| assert kwargs["median_distance"] == 2.0 | ||
| assert kwargs["num_results_returned"] == 3 | ||
|
|
||
|
|
||
| @patch( | ||
| "api.services.embedding_services.SemanticSearchUsage.objects.create", | ||
| side_effect=Exception("DB error"), | ||
| ) | ||
| def test_log_usage_swallows_exceptions(mock_create): | ||
| results = [{"distance": 1.0}] | ||
| user = MagicMock(is_authenticated=True) | ||
|
|
||
| # pytest fails the test if it catches unhandled Exception | ||
| log_usage( | ||
| results, | ||
| message_data="test query", | ||
| user=user, | ||
| guid=None, | ||
| document_name=None, | ||
| num_results=10, | ||
| encoding_time=0.1, | ||
| db_query_time=0.2, | ||
| ) |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,3 @@ | ||
| [pytest] | ||
| DJANGO_SETTINGS_MODULE = balancer_backend.settings | ||
| pythonpath = . |
| Original file line number | Diff line number | Diff line change | ||||||||
|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -19,4 +19,6 @@ PyMuPDF==1.24.0 | |||||||||
| Pillow | ||||||||||
| pytesseract | ||||||||||
| anthropic | ||||||||||
| drf-spectacular | ||||||||||
| pytest | ||||||||||
| pytest-django | ||||||||||
|
Comment on lines
+22
to
+23
|
||||||||||
| pytest | |
| pytest-django | |
| pytest>=8.0.0,<9.0.0 | |
| pytest-django>=4.8.0,<5.0.0 |
Uh oh!
There was an error while loading. Please reload this page.