Merged

21 commits
59b40f0
REFACTOR Pull apart get_closest_embeddings to make testing easier
sahilds1 Feb 13, 2026
3ffb74a
ADD Add infra required to run pytest
sahilds1 Feb 13, 2026
12b09a7
ADD Start adding tests for embedding_services
sahilds1 Feb 13, 2026
da9afaa
DOC Add a note about running pytest in the README
sahilds1 Feb 17, 2026
5ce7782
Preload SentenceTransformer model at Django startup before traffic is…
sahilds1 Feb 27, 2026
50a8bd3
Merge branch 'develop' into 441-embedding-models
sahilds1 Mar 11, 2026
795f218
Run python-app workflow on pushes and PRs to develop branch
sahilds1 Mar 11, 2026
d498a00
Pytest won’t automatically discover config files in subdirectories
sahilds1 Mar 19, 2026
6d3d8d1
Merge branch 'develop' into 441-embedding-models
sahilds1 Mar 19, 2026
3824d81
Suppress E402 import violations
sahilds1 Mar 19, 2026
46e9969
Add build_query tests and document coverage gaps in embedding_services
sahilds1 Mar 20, 2026
64a19ef
Fill test gaps in test_embedding_services
sahilds1 Mar 20, 2026
dec3c12
Fix incorrect build_query test assertions
sahilds1 Mar 20, 2026
f9e890a
Guard TransformerModel preload to runserver processes only
sahilds1 Mar 23, 2026
67176a8
Revert GitHub Workflow changes
sahilds1 Mar 25, 2026
d273921
Add section header comments to all four test groups in test_embedding…
sahilds1 Mar 26, 2026
8198574
Document why tests are split by responsibility
sahilds1 Mar 26, 2026
5d8c8b3
Improve logging and comments
sahilds1 Mar 31, 2026
31498dc
Fall back to lazy load using try except block
sahilds1 Mar 31, 2026
a39d33c
Revert settings.py to develop state
sahilds1 Mar 31, 2026
fe1eeca
Manually test fall back to lazy loading
sahilds1 Mar 31, 2026
9 changes: 7 additions & 2 deletions .github/workflows/python-app.yml
Original file line number Diff line number Diff line change
@@ -5,9 +5,9 @@ name: Python application

on:
push:
branches: [ "listOfMed" ]
branches: [ "develop" ]
pull_request:
branches: [ "listOfMed" ]
branches: [ "develop" ]

permissions:
contents: read
@@ -27,3 +27,8 @@ jobs:
run: pipx install ruff
- name: Lint code with Ruff
run: ruff check --output-format=github --target-version=py39
- name: Install test dependencies
run: pip install -r server/requirements.txt
# Discover and run all files matching test_*.py or *_test.py under server/
- name: Run tests
run: pytest server/ -v
sahilds1 marked this conversation as resolved.
Outdated
5 changes: 5 additions & 0 deletions README.md
@@ -73,6 +73,11 @@ df = pd.read_sql(query, engine)

#### Django REST
- The email and password are set in `server/api/management/commands/createsu.py`
- Backend tests can be run with `pytest` by executing the command below inside the running backend container:

```
docker compose exec backend pytest api/ -v
```

## API Documentation

4 changes: 4 additions & 0 deletions server/api/apps.py
@@ -4,3 +4,7 @@
class ApiConfig(AppConfig):
default_auto_field = 'django.db.models.BigAutoField'
name = 'api'

def ready(self):
from .services.sentencetTransformer_model import TransformerModel
TransformerModel.get_instance()
Copilot AI Mar 11, 2026
ApiConfig.ready() will only run if this AppConfig is actually used by Django. Right now INSTALLED_APPS appears to include just "api" (not "api.apps.ApiConfig"), and api/__init__.py doesn’t set a default config, so this preload hook may never execute. Consider updating INSTALLED_APPS to reference api.apps.ApiConfig (or otherwise ensuring this config is selected) so the model is preloaded as intended.

Collaborator Author
The model preloads as intended because Django ≥ 3.2 auto-discovers AppConfig subclasses in apps.py

Copilot AI Mar 11, 2026
Calling TransformerModel.get_instance() unconditionally in ready() will run for every Django startup context (tests, migrations, management commands, autoreload) and can trigger a large model download/init even when no web traffic will be served. Consider gating this preload behind an explicit env flag (or limiting it to the web server entrypoint) to avoid slowing/fragilizing CI and one-off management commands.

Collaborator Author
I added a guard to only preload the model when we're actually going to serve requests
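Such a guard can be sketched as a small helper. The `is_serving_process` name and the `runserver` check below are illustrative assumptions, not the exact code in this PR:

```python
import sys


def is_serving_process(argv=None):
    """Return True only when this process is about to serve web traffic.

    Hypothetical guard: preload under `manage.py runserver`, while
    skipping migrations, tests, and other one-off management commands.
    """
    argv = sys.argv if argv is None else argv
    return "runserver" in argv


# In ApiConfig.ready(), the preload would then be gated as:
#     if is_serving_process():
#         TransformerModel.get_instance()
```

A production WSGI/ASGI server never runs `runserver`, so a real deployment would also need an explicit opt-in (for example an environment variable) on that entrypoint.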

161 changes: 112 additions & 49 deletions server/api/services/embedding_services.py
@@ -11,18 +11,17 @@

logger = logging.getLogger(__name__)

def get_closest_embeddings(
user, message_data, document_name=None, guid=None, num_results=10
):

def build_query(user, embedding_vector, document_name=None, guid=None, num_results=10):
"""
Find the closest embeddings to a given message for a specific user.
Build an unevaluated QuerySet for the closest embeddings.

Parameters
----------
user : User
The user whose uploaded documents will be searched
message_data : str
The input message to find similar embeddings for
embedding_vector : array-like
Pre-computed embedding vector to compare against
document_name : str, optional
Filter results to a specific document name
guid : str, optional
@@ -32,59 +32,52 @@ def get_closest_embeddings(

Returns
-------
list[dict]
List of dictionaries containing embedding results with keys:
- name: document name
- text: embedded text content
- page_number: page number in source document
- chunk_number: chunk number within the document
- distance: L2 distance from query embedding
- file_id: GUID of the source file
QuerySet
Unevaluated Django QuerySet ordered by L2 distance, sliced to num_results
"""

encoding_start = time.time()
transformerModel = TransformerModel.get_instance().model
embedding_message = transformerModel.encode(message_data)
encoding_time = time.time() - encoding_start

db_query_start = time.time()

# Django QuerySets are lazily evaluated
if user.is_authenticated:
# User sees their own files + files uploaded by superusers
closest_embeddings_query = (
Embeddings.objects.filter(
Q(upload_file__uploaded_by=user) | Q(upload_file__uploaded_by__is_superuser=True)
)
.annotate(
distance=L2Distance("embedding_sentence_transformers", embedding_message)
)
.order_by("distance")
queryset = Embeddings.objects.filter(
Q(upload_file__uploaded_by=user) | Q(upload_file__uploaded_by__is_superuser=True)
)
else:
# Unauthenticated users only see superuser-uploaded files
closest_embeddings_query = (
Embeddings.objects.filter(upload_file__uploaded_by__is_superuser=True)
.annotate(
distance=L2Distance("embedding_sentence_transformers", embedding_message)
)
.order_by("distance")
)
queryset = Embeddings.objects.filter(upload_file__uploaded_by__is_superuser=True)

queryset = (
queryset
.annotate(distance=L2Distance("embedding_sentence_transformers", embedding_vector))
.order_by("distance")
)

# Filtering to a document GUID takes precedence over a document name
if guid:
closest_embeddings_query = closest_embeddings_query.filter(
upload_file__guid=guid
)
queryset = queryset.filter(upload_file__guid=guid)
elif document_name:
closest_embeddings_query = closest_embeddings_query.filter(name=document_name)
queryset = queryset.filter(name=document_name)

# Slicing is equivalent to SQL's LIMIT clause
closest_embeddings_query = closest_embeddings_query[:num_results]
return queryset[:num_results]
Comment on lines 16 to +61
Copilot AI Mar 11, 2026
build_query() introduces/relocates important filtering + precedence logic (authenticated vs unauthenticated visibility; guid-over-document_name; LIMIT slicing), but the new tests only cover evaluate_query and log_usage. Add unit/integration tests covering build_query behavior (e.g., guid precedence and the authenticated/unauthenticated queryset filters) to prevent regressions in access control and filtering.

Collaborator
Building on Copilot's comment, the specifics of the QuerySet object's structure aren't publicly documented. To inspect the QuerySets, we should actually execute them.

There are a couple of ways to handle DB access for these tests. We could use [pytest-django's `@pytest.mark.django_db`](https://pytest-django.readthedocs.io/en/latest/database.html), which wraps the test in a transaction that rolls back automatically afterwards. Django also has a built-in `django.test.TestCase`, which does a similar thing.

Collaborator Author
@sahilds1 sahilds1 Mar 25, 2026
Thanks for sharing the docs references -- I added tests for build_query and didn't have to access the database because I was able to inspect which methods and arguments were called on the model ("Embeddings")
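That mock-inspection approach can be illustrated with a simplified stand-in for `build_query` (the function body below is a sketch mirroring the chain in the diff, not the PR's exact code). `MagicMock` records the whole `filter → annotate → order_by → filter → [:n]` chain, so the assertions run without touching a database:

```python
from unittest.mock import MagicMock


def build_query(objects, vector, guid=None, num_results=10):
    # Simplified stand-in for embedding_services.build_query's chain
    qs = objects.filter(visible=True)
    qs = qs.annotate(distance=vector).order_by("distance")
    if guid:
        # GUID filter takes precedence over a document-name filter
        qs = qs.filter(guid=guid)
    return qs[:num_results]


def test_guid_filter_and_limit_applied():
    objects = MagicMock()
    build_query(objects, vector=[0.1, 0.2], guid="abc-123", num_results=5)
    # Every call on a MagicMock is recorded, so the chained calls can be
    # inspected after the fact without evaluating any real QuerySet.
    ordered = objects.filter.return_value.annotate.return_value.order_by.return_value
    ordered.filter.assert_called_once_with(guid="abc-123")
    # Slicing shows up as __getitem__ with a slice object (SQL LIMIT)
    ordered.filter.return_value.__getitem__.assert_called_once_with(slice(None, 5))
```

The trade-off noted in the thread still holds: this verifies the calls made on the manager, not the SQL Django would generate, so DB-backed tests remain the stronger check for access control.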



def evaluate_query(queryset):
"""
Evaluate a QuerySet and return a list of result dicts.

Parameters
----------
queryset : iterable
Iterable of Embeddings objects (or any objects with the expected attributes)

Returns
-------
list[dict]
List of dicts with keys: name, text, page_number, chunk_number, distance, file_id
"""
# Iterating evaluates the QuerySet and hits the database
# TODO: Research improving the query evaluation performance
results = [
return [
{
"name": obj.name,
"text": obj.text,
@@ -93,13 +85,36 @@ def get_closest_embeddings(
"distance": obj.distance,
"file_id": obj.upload_file.guid if obj.upload_file else None,
}
for obj in closest_embeddings_query
for obj in queryset
]

db_query_time = time.time() - db_query_start

def log_usage(
results, message_data, user, guid, document_name, num_results, encoding_time, db_query_time
):
"""
Create a SemanticSearchUsage record. Swallows exceptions so search isn't interrupted.

Parameters
----------
results : list[dict]
The search results, each containing a "distance" key
message_data : str
The original search query text
user : User
The user who performed the search
guid : str or None
Document GUID filter used in the search
document_name : str or None
Document name filter used in the search
num_results : int
Number of results requested
encoding_time : float
Time in seconds to encode the query
db_query_time : float
Time in seconds for the database query
"""
try:
# Handle user having no uploaded docs or doc filtering returning no matches
if results:
distances = [r["distance"] for r in results]
SemanticSearchUsage.objects.create(
@@ -113,11 +128,10 @@ def get_closest_embeddings(
num_results_returned=len(results),
max_distance=max(distances),
median_distance=median(distances),
min_distance=min(distances)
min_distance=min(distances),
)
else:
logger.warning("Semantic search returned no results")

SemanticSearchUsage.objects.create(
query_text=message_data,
user=user if (user and user.is_authenticated) else None,
@@ -129,9 +143,58 @@
num_results_returned=0,
max_distance=None,
median_distance=None,
min_distance=None
min_distance=None,
)
except Exception as e:
logger.error(f"Failed to create semantic search usage database record: {e}")
Copilot AI Mar 27, 2026
log_usage() swallows all exceptions, but logger.error(f"... {e}") drops the traceback, making it much harder to debug production failures when SemanticSearchUsage writes fail. Prefer logger.exception(...) (or logger.error(..., exc_info=True)) so the stack trace is captured while still not interrupting the request.

Suggested change
except Exception as e:
logger.error(f"Failed to create semantic search usage database record: {e}")
except Exception:
logger.exception("Failed to create semantic search usage database record")

Collaborator Author
Updated the logging to logger.exception
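The difference is straightforward to verify without a real database. The sketch below uses a stripped-down stand-in for `log_usage`'s error handling (the names are illustrative, not the PR's exact code):

```python
import logging
from unittest.mock import MagicMock, patch

logger = logging.getLogger("embedding_services_sketch")


def log_usage(create_record):
    # Stand-in for the real function's try/except around the ORM write
    try:
        create_record()
    except Exception:
        # logger.exception logs at ERROR level and attaches the active
        # traceback, which logger.error(f"... {e}") would drop
        logger.exception("Failed to create semantic search usage database record")


def test_exception_is_swallowed_and_traceback_logged():
    failing_create = MagicMock(side_effect=Exception("DB error"))
    with patch.object(logger, "exception") as mock_exception:
        log_usage(failing_create)  # must not raise
    mock_exception.assert_called_once()
```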



def get_closest_embeddings(
user, message_data, document_name=None, guid=None, num_results=10
):
"""
Find the closest embeddings to a given message for a specific user.

Parameters
----------
user : User
The user whose uploaded documents will be searched
message_data : str
The input message to find similar embeddings for
document_name : str, optional
Filter results to a specific document name
guid : str, optional
Filter results to a specific document GUID (takes precedence over document_name)
num_results : int, default 10
Maximum number of results to return

Returns
-------
list[dict]
List of dictionaries containing embedding results with keys:
- name: document name
- text: embedded text content
- page_number: page number in source document
- chunk_number: chunk number within the document
- distance: L2 distance from query embedding
- file_id: GUID of the source file

Notes
-----
Creates a SemanticSearchUsage record. Swallows exceptions so search isn't interrupted.
"""
encoding_start = time.time()
model = TransformerModel.get_instance().model
embedding_vector = model.encode(message_data)
encoding_time = time.time() - encoding_start

db_query_start = time.time()
queryset = build_query(user, embedding_vector, document_name, guid, num_results)
results = evaluate_query(queryset)
db_query_time = time.time() - db_query_start

log_usage(
results, message_data, user, guid, document_name, num_results, encoding_time, db_query_time
)

return results
85 changes: 85 additions & 0 deletions server/api/services/test_embedding_services.py
@@ -0,0 +1,85 @@
from unittest.mock import MagicMock, patch

from api.services.embedding_services import evaluate_query, log_usage


def test_evaluate_query_maps_fields():
obj = MagicMock()
obj.name = "doc.pdf"
obj.text = "some text"
obj.page_num = 3
obj.chunk_number = 1
obj.distance = 0.42
obj.upload_file.guid = "abc-123"

results = evaluate_query([obj])

assert results == [
{
"name": "doc.pdf",
"text": "some text",
"page_number": 3,
"chunk_number": 1,
"distance": 0.42,
"file_id": "abc-123",
}
]


def test_evaluate_query_none_upload_file():
obj = MagicMock()
obj.name = "doc.pdf"
obj.text = "some text"
obj.page_num = 1
obj.chunk_number = 0
obj.distance = 1.0
obj.upload_file = None

results = evaluate_query([obj])

assert results[0]["file_id"] is None


@patch("api.services.embedding_services.SemanticSearchUsage.objects.create")
def test_log_usage_computes_distance_stats(mock_create):
results = [{"distance": 1.0}, {"distance": 3.0}, {"distance": 2.0}]
user = MagicMock(is_authenticated=True)

log_usage(
results,
message_data="test query",
user=user,
guid=None,
document_name=None,
num_results=10,
encoding_time=0.1,
db_query_time=0.2,
)

mock_create.assert_called_once()
kwargs = mock_create.call_args.kwargs
assert kwargs["min_distance"] == 1.0
assert kwargs["max_distance"] == 3.0
assert kwargs["median_distance"] == 2.0
assert kwargs["num_results_returned"] == 3


@patch(
"api.services.embedding_services.SemanticSearchUsage.objects.create",
side_effect=Exception("DB error"),
)
def test_log_usage_swallows_exceptions(mock_create):
results = [{"distance": 1.0}]
user = MagicMock(is_authenticated=True)

# pytest fails the test if an unhandled exception propagates
log_usage(
results,
message_data="test query",
user=user,
guid=None,
document_name=None,
num_results=10,
encoding_time=0.1,
db_query_time=0.2,
)
3 changes: 3 additions & 0 deletions server/pytest.ini
@@ -0,0 +1,3 @@
[pytest]
DJANGO_SETTINGS_MODULE = balancer_backend.settings
pythonpath = .
4 changes: 3 additions & 1 deletion server/requirements.txt
@@ -19,4 +19,6 @@ PyMuPDF==1.24.0
Pillow
pytesseract
anthropic
drf-spectacular
pytest
pytest-django
Comment on lines +22 to +23
Copilot AI Mar 27, 2026
pytest and pytest-django are added without version pins. Since these packages can introduce breaking behavior across major/minor releases, consider pinning them (or constraining with compatible ranges) to keep test runs reproducible across environments and container rebuilds.

Suggested change
pytest
pytest-django
pytest>=8.0.0,<9.0.0
pytest-django>=4.8.0,<5.0.0

drf-spectacular