Skip to content

Commit ee64b35

Browse files
committed
refactor: core/rag docstore, datasource, embedding, rerank, retrieval index processors
1 parent 66fab87 commit ee64b35

24 files changed

Lines changed: 173 additions & 212 deletions

api/core/rag/datasource/keyword/jieba/jieba.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -97,13 +97,13 @@ def search(self, query: str, **kwargs: Any) -> list[Document]:
9797

9898
documents = []
9999

100-
segment_query_stmt = db.session.query(DocumentSegment).where(
100+
segment_query_stmt = select(DocumentSegment).where(
101101
DocumentSegment.dataset_id == self.dataset.id, DocumentSegment.index_node_id.in_(sorted_chunk_indices)
102102
)
103103
if document_ids_filter:
104104
segment_query_stmt = segment_query_stmt.where(DocumentSegment.document_id.in_(document_ids_filter))
105105

106-
segments = db.session.execute(segment_query_stmt).scalars().all()
106+
segments = db.session.scalars(segment_query_stmt).all()
107107
segment_map = {segment.index_node_id: segment for segment in segments}
108108
for chunk_index in sorted_chunk_indices:
109109
segment = segment_map.get(chunk_index)

api/core/rag/datasource/retrieval_service.py

Lines changed: 5 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -432,10 +432,11 @@ def format_retrieval_documents(cls, documents: list[Document]) -> list[Retrieval
432432
# Batch query dataset documents
433433
dataset_documents = {
434434
doc.id: doc
435-
for doc in db.session.query(DatasetDocument)
436-
.where(DatasetDocument.id.in_(document_ids))
437-
.options(load_only(DatasetDocument.id, DatasetDocument.doc_form, DatasetDocument.dataset_id))
438-
.all()
435+
for doc in db.session.scalars(
436+
select(DatasetDocument)
437+
.where(DatasetDocument.id.in_(document_ids))
438+
.options(load_only(DatasetDocument.id, DatasetDocument.doc_form, DatasetDocument.dataset_id))
439+
).all()
439440
}
440441

441442
valid_dataset_documents = {}

api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py

Lines changed: 2 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -426,11 +426,10 @@ def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings
426426
TIDB_ON_QDRANT_API_KEY = f"{tidb_auth_binding.account}:{tidb_auth_binding.password}"
427427

428428
else:
429-
idle_tidb_auth_binding = (
430-
db.session.query(TidbAuthBinding)
429+
idle_tidb_auth_binding = db.session.scalar(
430+
select(TidbAuthBinding)
431431
.where(TidbAuthBinding.active == False, TidbAuthBinding.status == "ACTIVE")
432432
.limit(1)
433-
.one_or_none()
434433
)
435434
if idle_tidb_auth_binding:
436435
idle_tidb_auth_binding.active = True

api/core/rag/datasource/vdb/vector_factory.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -277,7 +277,7 @@ def search_by_vector(self, query: str, **kwargs: Any) -> list[Document]:
277277
return self._vector_processor.search_by_vector(query_vector, **kwargs)
278278

279279
def search_by_file(self, file_id: str, **kwargs: Any) -> list[Document]:
280-
upload_file: UploadFile | None = db.session.query(UploadFile).where(UploadFile.id == file_id).first()
280+
upload_file: UploadFile | None = db.session.get(UploadFile, file_id)
281281

282282
if not upload_file:
283283
return []

api/core/rag/docstore/dataset_docstore.py

Lines changed: 11 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,7 @@
33
from collections.abc import Sequence
44
from typing import Any
55

6-
from sqlalchemy import func, select
6+
from sqlalchemy import delete, func, select
77

88
from core.model_manager import ModelManager
99
from core.rag.index_processor.constant.index_type import IndexTechniqueType
@@ -63,10 +63,8 @@ def docs(self) -> dict[str, Document]:
6363
return output
6464

6565
def add_documents(self, docs: Sequence[Document], allow_update: bool = True, save_child: bool = False):
66-
max_position = (
67-
db.session.query(func.max(DocumentSegment.position))
68-
.where(DocumentSegment.document_id == self._document_id)
69-
.scalar()
66+
max_position = db.session.scalar(
67+
select(func.max(DocumentSegment.position)).where(DocumentSegment.document_id == self._document_id)
7068
)
7169

7270
if max_position is None:
@@ -155,12 +153,14 @@ def add_documents(self, docs: Sequence[Document], allow_update: bool = True, sav
155153
)
156154
if save_child and doc.children:
157155
# delete the existing child chunks
158-
db.session.query(ChildChunk).where(
159-
ChildChunk.tenant_id == self._dataset.tenant_id,
160-
ChildChunk.dataset_id == self._dataset.id,
161-
ChildChunk.document_id == self._document_id,
162-
ChildChunk.segment_id == segment_document.id,
163-
).delete()
156+
db.session.execute(
157+
delete(ChildChunk).where(
158+
ChildChunk.tenant_id == self._dataset.tenant_id,
159+
ChildChunk.dataset_id == self._dataset.id,
160+
ChildChunk.document_id == self._document_id,
161+
ChildChunk.segment_id == segment_document.id,
162+
)
163+
)
164164
# add new child chunks
165165
for position, child in enumerate(doc.children, start=1):
166166
child_segment = ChildChunk(

api/core/rag/embedding/cached_embedding.py

Lines changed: 15 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,7 @@
44
from typing import Any, cast
55

66
import numpy as np
7+
from sqlalchemy import select
78
from sqlalchemy.exc import IntegrityError
89

910
from configs import dify_config
@@ -31,14 +32,14 @@ def embed_documents(self, texts: list[str]) -> list[list[float]]:
3132
embedding_queue_indices = []
3233
for i, text in enumerate(texts):
3334
hash = helper.generate_text_hash(text)
34-
embedding = (
35-
db.session.query(Embedding)
36-
.filter_by(
37-
model_name=self._model_instance.model_name,
38-
hash=hash,
39-
provider_name=self._model_instance.provider,
35+
embedding = db.session.scalar(
36+
select(Embedding)
37+
.where(
38+
Embedding.model_name == self._model_instance.model_name,
39+
Embedding.hash == hash,
40+
Embedding.provider_name == self._model_instance.provider,
4041
)
41-
.first()
42+
.limit(1)
4243
)
4344
if embedding:
4445
text_embeddings[i] = embedding.get_embedding()
@@ -112,14 +113,14 @@ def embed_multimodal_documents(self, multimodel_documents: list[dict]) -> list[l
112113
embedding_queue_indices = []
113114
for i, multimodel_document in enumerate(multimodel_documents):
114115
file_id = multimodel_document["file_id"]
115-
embedding = (
116-
db.session.query(Embedding)
117-
.filter_by(
118-
model_name=self._model_instance.model_name,
119-
hash=file_id,
120-
provider_name=self._model_instance.provider,
116+
embedding = db.session.scalar(
117+
select(Embedding)
118+
.where(
119+
Embedding.model_name == self._model_instance.model_name,
120+
Embedding.hash == file_id,
121+
Embedding.provider_name == self._model_instance.provider,
121122
)
122-
.first()
123+
.limit(1)
123124
)
124125
if embedding:
125126
multimodel_embeddings[i] = embedding.get_embedding()

api/core/rag/extractor/notion_extractor.py

Lines changed: 7 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -8,6 +8,8 @@
88
from configs import dify_config
99
from core.rag.extractor.extractor_base import BaseExtractor
1010
from core.rag.models.document import Document
11+
from sqlalchemy import update
12+
1113
from extensions.ext_database import db
1214
from models.dataset import Document as DocumentModel
1315
from services.datasource_provider_service import DatasourceProviderService
@@ -346,9 +348,11 @@ def update_last_edited_time(self, document_model: DocumentModel | None):
346348
if data_source_info:
347349
data_source_info["last_edited_time"] = last_edited_time
348350

349-
db.session.query(DocumentModel).filter_by(id=document_model.id).update(
350-
{DocumentModel.data_source_info: json.dumps(data_source_info)}
351-
) # type: ignore
351+
db.session.execute(
352+
update(DocumentModel)
353+
.where(DocumentModel.id == document_model.id)
354+
.values(data_source_info=json.dumps(data_source_info))
355+
)
352356
db.session.commit()
353357

354358
def get_notion_last_edited_time(self) -> str:

api/core/rag/index_processor/index_processor_base.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -26,6 +26,8 @@
2626
FixedRecursiveCharacterTextSplitter,
2727
)
2828
from core.rag.splitter.text_splitter import TextSplitter
29+
from sqlalchemy import select
30+
2931
from extensions.ext_database import db
3032
from extensions.ext_storage import storage
3133
from models import Account, ToolFile
@@ -200,7 +202,7 @@ def _get_content_files(self, document: Document, current_user: Account | None =
200202

201203
# Get unique IDs for database query
202204
unique_upload_file_ids = list(set(upload_file_id_list))
203-
upload_files = db.session.query(UploadFile).where(UploadFile.id.in_(unique_upload_file_ids)).all()
205+
upload_files = db.session.scalars(select(UploadFile).where(UploadFile.id.in_(unique_upload_file_ids))).all()
204206

205207
# Create a mapping from ID to UploadFile for quick lookup
206208
upload_file_map = {upload_file.id: upload_file for upload_file in upload_files}
@@ -312,7 +314,7 @@ def _download_tool_file(self, tool_file_id: str, current_user: Account) -> str |
312314
"""
313315
from services.file_service import FileService
314316

315-
tool_file = db.session.query(ToolFile).where(ToolFile.id == tool_file_id).first()
317+
tool_file = db.session.get(ToolFile, tool_file_id)
316318
if not tool_file:
317319
return None
318320
blob = storage.load_once(tool_file.file_key)

api/core/rag/index_processor/processor/paragraph_index_processor.py

Lines changed: 8 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -29,6 +29,8 @@
2929
from core.rag.retrieval.retrieval_methods import RetrievalMethod
3030
from core.tools.utils.text_processing_utils import remove_leading_symbols
3131
from core.workflow.file_reference import build_file_reference
32+
from sqlalchemy import select
33+
3234
from extensions.ext_database import db
3335
from factories.file_factory import build_from_mapping
3436
from graphon.file import File, FileTransferMethod, FileType, file_manager
@@ -144,14 +146,12 @@ def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: boo
144146
if delete_summaries:
145147
if node_ids:
146148
# Find segments by index_node_id
147-
segments = (
148-
db.session.query(DocumentSegment)
149-
.filter(
149+
segments = db.session.scalars(
150+
select(DocumentSegment).where(
150151
DocumentSegment.dataset_id == dataset.id,
151152
DocumentSegment.index_node_id.in_(node_ids),
152153
)
153-
.all()
154-
)
154+
).all()
155155
segment_ids = [segment.id for segment in segments]
156156
if segment_ids:
157157
SummaryIndexService.delete_summaries_for_segments(dataset, segment_ids)
@@ -536,11 +536,9 @@ def _extract_images_from_text(tenant_id: str, text: str) -> list[File]:
536536

537537
# Get unique IDs for database query
538538
unique_upload_file_ids = list(set(upload_file_id_list))
539-
upload_files = (
540-
db.session.query(UploadFile)
541-
.where(UploadFile.id.in_(unique_upload_file_ids), UploadFile.tenant_id == tenant_id)
542-
.all()
543-
)
539+
upload_files = db.session.scalars(
540+
select(UploadFile).where(UploadFile.id.in_(unique_upload_file_ids), UploadFile.tenant_id == tenant_id)
541+
).all()
544542

545543
# Create File objects from UploadFile records
546544
file_objects = []

api/core/rag/index_processor/processor/parent_child_index_processor.py

Lines changed: 16 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -22,6 +22,8 @@
2222
from core.rag.index_processor.index_processor_base import BaseIndexProcessor, SummaryIndexSettingDict
2323
from core.rag.models.document import AttachmentDocument, ChildDocument, Document, ParentChildStructureChunk
2424
from core.rag.retrieval.retrieval_methods import RetrievalMethod
25+
from sqlalchemy import delete, select
26+
2527
from extensions.ext_database import db
2628
from libs import helper
2729
from models import Account
@@ -177,36 +179,39 @@ def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: boo
177179
child_node_ids = precomputed_child_node_ids
178180
else:
179181
# Fallback to original query (may fail if segments are already deleted)
180-
child_node_ids = (
181-
db.session.query(ChildChunk.index_node_id)
182+
rows = db.session.execute(
183+
select(ChildChunk.index_node_id)
182184
.join(DocumentSegment, ChildChunk.segment_id == DocumentSegment.id)
183185
.where(
184186
DocumentSegment.dataset_id == dataset.id,
185187
DocumentSegment.index_node_id.in_(node_ids),
186188
ChildChunk.dataset_id == dataset.id,
187189
)
188-
.all()
189-
)
190-
child_node_ids = [child_node_id[0] for child_node_id in child_node_ids if child_node_id[0]]
190+
).all()
191+
child_node_ids = [row[0] for row in rows if row[0]]
191192

192193
# Delete from vector index
193194
if child_node_ids:
194195
vector.delete_by_ids(child_node_ids)
195196

196197
# Delete from database
197198
if delete_child_chunks and child_node_ids:
198-
db.session.query(ChildChunk).where(
199-
ChildChunk.dataset_id == dataset.id, ChildChunk.index_node_id.in_(child_node_ids)
200-
).delete(synchronize_session=False)
199+
db.session.execute(
200+
delete(ChildChunk).where(
201+
ChildChunk.dataset_id == dataset.id, ChildChunk.index_node_id.in_(child_node_ids)
202+
)
203+
)
201204
db.session.commit()
202205
else:
203206
vector.delete()
204207

205208
if delete_child_chunks:
206209
# Use existing compound index: (tenant_id, dataset_id, ...)
207-
db.session.query(ChildChunk).where(
208-
ChildChunk.tenant_id == dataset.tenant_id, ChildChunk.dataset_id == dataset.id
209-
).delete(synchronize_session=False)
210+
db.session.execute(
211+
delete(ChildChunk).where(
212+
ChildChunk.tenant_id == dataset.tenant_id, ChildChunk.dataset_id == dataset.id
213+
)
214+
)
210215
db.session.commit()
211216

212217
def retrieve(

0 commit comments

Comments
 (0)