Skip to content

Commit 44b6dfc

Browse files
author
Andrey Dolgolev
authored
Merge pull request #52 from bugout-dev/fix-search-endpoint
Fix search endpoint
2 parents 022ea3b + 2b8e0f8 commit 44b6dfc

3 files changed

Lines changed: 66 additions & 37 deletions

File tree

spire/journal/actions.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -612,6 +612,21 @@ async def get_journal_entries(
612612
return query.all()
613613

614614

615+
async def get_journal_entry(
616+
db_session: Session, journal_entry_id: UUID
617+
) -> Optional[JournalEntry]:
618+
"""
619+
Returns a journal entry by its id. Raises a JournalEntryNotFound error if no such entry is
620+
found in the database.
621+
"""
622+
journal_entry = (
623+
db_session.query(JournalEntry)
624+
.filter(JournalEntry.id == journal_entry_id)
625+
.one_or_none()
626+
)
627+
return journal_entry
628+
629+
615630
async def delete_journal_entry(
616631
db_session: Session,
617632
journal_spec: JournalSpec,

spire/journal/api.py

Lines changed: 17 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1280,21 +1280,12 @@ async def update_entry_content(
12801280
es_index = journal.search_index
12811281

12821282
try:
1283-
journal_entry_container = await actions.get_journal_entries(
1284-
db_session,
1285-
journal_spec,
1286-
entry_id,
1287-
request.state.user_group_id_list,
1283+
journal_entry = await actions.get_journal_entry(
1284+
db_session=db_session, journal_entry_id=entry_id
12881285
)
1289-
if len(journal_entry_container) == 0:
1286+
if journal_entry is None:
12901287
raise actions.EntryNotFound()
1291-
assert len(journal_entry_container) == 1
1292-
journal_entry = journal_entry_container[0]
1293-
except actions.JournalNotFound:
1294-
logger.error(
1295-
f"Journal not found with ID={journal_id} for user={request.state.user_id}"
1296-
)
1297-
raise HTTPException(status_code=404)
1288+
12981289
except actions.EntryNotFound:
12991290
logger.error(
13001291
f"Entry not found with ID={entry_id} in journal with ID={journal_id}"
@@ -1810,14 +1801,11 @@ async def create_tags(
18101801

18111802
if es_index is not None:
18121803
try:
1813-
entry_container = await actions.get_journal_entries(
1814-
db_session,
1815-
journal_spec,
1816-
entry_id,
1817-
user_group_id_list=request.state.user_group_id_list,
1804+
journal_entry = await actions.get_journal_entry(
1805+
db_session=db_session, journal_entry_id=entry_id
18181806
)
1819-
assert len(entry_container) == 1
1820-
entry = entry_container[0]
1807+
assert journal_entry != None
1808+
entry = journal_entry
18211809
all_tags = await actions.get_journal_entry_tags(
18221810
db_session,
18231811
journal_spec,
@@ -1963,14 +1951,11 @@ async def update_tags(
19631951

19641952
if es_index is not None:
19651953
try:
1966-
entry_container = await actions.get_journal_entries(
1967-
db_session,
1968-
journal_spec,
1969-
entry_id,
1970-
request.state.user_group_id_list,
1954+
journal_entry = await actions.get_journal_entry(
1955+
db_session=db_session, journal_entry_id=entry_id
19711956
)
1972-
assert len(entry_container) == 1
1973-
entry = entry_container[0]
1957+
assert journal_entry != None
1958+
entry = journal_entry
19741959
all_tags_str = [tag.tag for tag in tags]
19751960
search.new_entry(
19761961
es_client,
@@ -2062,14 +2047,11 @@ async def delete_tag(
20622047

20632048
if es_index is not None:
20642049
try:
2065-
entry_container = await actions.get_journal_entries(
2066-
db_session,
2067-
journal_spec,
2068-
entry_id,
2069-
user_group_id_list=request.state.user_group_id_list,
2050+
journal_entry = await actions.get_journal_entry(
2051+
db_session=db_session, journal_entry_id=entry_id
20702052
)
2071-
assert len(entry_container) == 1
2072-
entry = entry_container[0]
2053+
assert journal_entry != None
2054+
entry = journal_entry
20732055
all_tags = await actions.get_journal_entry_tags(
20742056
db_session,
20752057
journal_spec,
@@ -2165,15 +2147,14 @@ async def search_journal(
21652147
max_score: Optional[float] = 1.0
21662148

21672149
for entry in rows:
2168-
tags: List[str] = [tag.tag for tag in entry.tags]
21692150
entry_url = f"{journal_url}/entries/{str(entry.id)}"
21702151
content_url = f"{entry_url}/content"
21712152
result = JournalSearchResult(
21722153
entry_url=entry_url,
21732154
content_url=content_url,
21742155
title=entry.title,
21752156
content=entry.content,
2176-
tags=tags,
2157+
tags=entry.tags,
21772158
created_at=str(entry.created_at),
21782159
updated_at=str(entry.updated_at),
21792160
score=1.0,

spire/journal/search.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,11 @@
1414
import elasticsearch
1515
from elasticsearch.client import IndicesClient
1616
from elasticsearch.helpers import bulk
17-
from sqlalchemy import and_, or_, not_
17+
from sqlalchemy import and_, or_, not_, func
1818
from sqlalchemy.sql.elements import BooleanClauseList
1919
from sqlalchemy.orm import Session, Query
2020

21+
2122
from . import actions
2223
from .data import JournalSpec, JournalEntryResponse
2324
from ..db import yield_connection_from_env
@@ -734,6 +735,38 @@ def search_database(
734735
query = query.order_by(JournalEntry.created_at.desc())
735736
num_entries = query.count()
736737
query = query.limit(size).offset(start)
738+
739+
journal_entries_temp = query.cte(name="journal_entries_temp")
740+
741+
entries_ids_with_tags = (
742+
db_session.query(journal_entries_temp.c.id, JournalEntryTag.tag).join(
743+
JournalEntryTag,
744+
JournalEntryTag.journal_entry_id == journal_entries_temp.c.id,
745+
)
746+
).cte(name="entries_ids_with_tags")
747+
748+
aggregated_tags = (
749+
db_session.query(
750+
entries_ids_with_tags.c.id,
751+
func.array_agg(entries_ids_with_tags.c.tag).label("tags"),
752+
)
753+
.group_by(entries_ids_with_tags.c.id)
754+
.cte(name="aggregated_tags")
755+
)
756+
757+
query = db_session.query(
758+
journal_entries_temp.c.id.label("id"),
759+
aggregated_tags.c.tags.label("tags"),
760+
journal_entries_temp.c.title.label("title"),
761+
journal_entries_temp.c.content.label("content"),
762+
journal_entries_temp.c.context_id.label("context_id"),
763+
journal_entries_temp.c.context_url.label("context_url"),
764+
journal_entries_temp.c.context_type.label("context_type"),
765+
journal_entries_temp.c.version_id.label("version_id"),
766+
journal_entries_temp.c.created_at.label("created_at"),
767+
journal_entries_temp.c.updated_at.label("updated_at"),
768+
).join(aggregated_tags, journal_entries_temp.c.id == aggregated_tags.c.id)
769+
737770
rows = query.all()
738771

739772
return num_entries, rows

0 commit comments

Comments
 (0)