Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/sqlite_rag/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,12 +439,12 @@ def reset(
def search(
ctx: typer.Context,
query: str,
limit: int = typer.Option(10, help="Number of results to return"),
limit: int = typer.Option(5, help="Number of results to return"),
debug: bool = typer.Option(
False,
"-d",
"--debug",
help="Print extra debug information with modern formatting",
help="Print extra debug information with sentence-level details",
),
peek: bool = typer.Option(
False, "--peek", help="Print debug information using compact table format"
Expand Down
28 changes: 25 additions & 3 deletions src/sqlite_rag/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,16 +88,38 @@ def _create_schema(conn: sqlite3.Connection, settings: Settings):
"""
)

# TODO: remove sequence
Comment thread
danielebriggi marked this conversation as resolved.
Outdated
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS sentences (
id TEXT PRIMARY KEY,
chunk_id INTEGER,
content TEXT,
embedding BLOB,
start_offset INTEGER,
end_offset INTEGER
)
"""
)

cursor.execute(
"""
CREATE VIRTUAL TABLE IF NOT EXISTS chunks_fts USING fts5(content, content='chunks', content_rowid='id');
"""
)

cursor.execute(
f"""
SELECT vector_init('chunks', 'embedding', 'type={settings.vector_type},dimension={settings.embedding_dim},{settings.other_vector_options}');
"""
"""
SELECT vector_init('chunks', 'embedding', ?);
""",
(settings.get_vector_init_options(),),
)
# TODO: same configuration as chunks (or different options?)
cursor.execute(
"""
SELECT vector_init('sentences', 'embedding', ?);
""",
(settings.get_vector_init_options(),),
)

conn.commit()
111 changes: 96 additions & 15 deletions src/sqlite_rag/engine.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import json
import re
import sqlite3
from pathlib import Path
from typing import List

from sqlite_rag.logger import Logger
from sqlite_rag.models.document_result import DocumentResult
from sqlite_rag.models.sentence_result import SentenceResult
from sqlite_rag.sentence_splitter import SentenceSplitter

from .chunker import Chunker
from .models.document import Document
Expand All @@ -15,10 +17,17 @@ class Engine:
# Considered a good default to normalize the score for RRF
DEFAULT_RRF_K = 60

def __init__(self, conn: sqlite3.Connection, settings: Settings, chunker: Chunker):
def __init__(
self,
conn: sqlite3.Connection,
settings: Settings,
chunker: Chunker,
sentence_splitter: SentenceSplitter,
):
self._conn = conn
self._settings = settings
self._chunker = chunker
self._sentence_splitter = sentence_splitter
self._logger = Logger()

def load_model(self):
Expand All @@ -30,7 +39,7 @@ def load_model(self):

self._conn.execute(
"SELECT llm_model_load(?, ?);",
(self._settings.model_path, self._settings.model_options),
(self._settings.model_path, self._settings.other_model_options),
)

def process(self, document: Document) -> Document:
Expand All @@ -46,6 +55,11 @@ def process(self, document: Document) -> Document:
chunk.title = document.get_title()
chunk.embedding = self.generate_embedding(chunk.get_embedding_text())

sentences = self._sentence_splitter.split(chunk)
for sentence in sentences:
sentence.embedding = self.generate_embedding(sentence.content)
chunk.sentences = sentences

document.chunks = chunks

return document
Expand All @@ -72,6 +86,7 @@ def quantize(self) -> None:
cursor = self._conn.cursor()

cursor.execute("SELECT vector_quantize('chunks', 'embedding');")
cursor.execute("SELECT vector_quantize('sentences', 'embedding');")

self._conn.commit()
self._logger.debug("Quantization completed.")
Expand All @@ -81,21 +96,25 @@ def quantize_preload(self) -> None:
cursor = self._conn.cursor()

cursor.execute("SELECT vector_quantize_preload('chunks', 'embedding');")
cursor.execute("SELECT vector_quantize_preload('sentences', 'embedding');")

def quantize_cleanup(self) -> None:
"""Clean up internal structures related to a previously quantized table/column."""
cursor = self._conn.cursor()

cursor.execute("SELECT vector_quantize_cleanup('chunks', 'embedding');")
cursor.execute("SELECT vector_quantize_cleanup('sentences', 'embedding');")

self._conn.commit()

def create_new_context(self) -> None:
""""""
"""Create a new LLM context with optional runtime overrides."""
cursor = self._conn.cursor()
context_options = self._settings.get_embeddings_context_options()

cursor.execute(
"SELECT llm_context_create(?);", (self._settings.model_context_options,)
"SELECT llm_context_create(?);",
(context_options,),
)

def free_context(self) -> None:
Expand All @@ -104,13 +123,11 @@ def free_context(self) -> None:

cursor.execute("SELECT llm_context_free();")

def search(self, query: str, top_k: int = 10) -> list[DocumentResult]:
def search(
self, semantic_query: str, fts_query, top_k: int = 10
) -> list[DocumentResult]:
"""Semantic search and full-text search sorted with Reciprocal Rank Fusion."""
query_embedding = self.generate_embedding(query)

# Clean up and split into words
# '*' is used to match while typing
query = " ".join(re.findall(r"\b\w+\b", query.lower())) + "*"
query_embedding = self.generate_embedding(semantic_query)

vector_scan_type = (
"vector_quantize_scan"
Expand All @@ -119,8 +136,7 @@ def search(self, query: str, top_k: int = 10) -> list[DocumentResult]:
)

cursor = self._conn.cursor()
# TODO: understand how to sort results depending on the distance metric
# Eg, for cosine distance, higher is better (closer to 1)

cursor.execute(
f"""
-- sqlite-vector KNN vector search results
Expand Down Expand Up @@ -163,6 +179,7 @@ def search(self, query: str, top_k: int = 10) -> list[DocumentResult]:
documents.uri,
documents.content as document_content,
documents.metadata,
chunks.id AS chunk_id,
chunks.content AS snippet,
vec_rank,
fts_rank,
Expand All @@ -176,7 +193,7 @@ def search(self, query: str, top_k: int = 10) -> list[DocumentResult]:
;
""", # nosec B608
{
"query": query,
"query": fts_query,
"query_embedding": query_embedding,
"k": top_k,
"rrf_k": Engine.DEFAULT_RRF_K,
Expand All @@ -186,14 +203,15 @@ def search(self, query: str, top_k: int = 10) -> list[DocumentResult]:
)

rows = cursor.fetchall()
return [
results = [
DocumentResult(
document=Document(
id=row["id"],
uri=row["uri"],
content=row["document_content"],
metadata=json.loads(row["metadata"]) if row["metadata"] else {},
),
chunk_id=row["chunk_id"],
snippet=row["snippet"],
vec_rank=row["vec_rank"],
fts_rank=row["fts_rank"],
Expand All @@ -204,6 +222,69 @@ def search(self, query: str, top_k: int = 10) -> list[DocumentResult]:
for row in rows
]

return results

def search_sentences(
self, query: str, chunk_id: int, top_k: int
) -> List[SentenceResult]:
query_embedding = self.generate_embedding(query)

vector_scan_type = (
"vector_quantize_scan_stream"
if self._settings.quantize_scan
else "vector_full_scan_stream"
)

cursor = self._conn.cursor()

cursor.execute(
f"""
WITH vec_matches AS (
SELECT
v.rowid AS sentence_id,
row_number() OVER (ORDER BY v.distance) AS rank_number,
v.distance,
sentences.content as sentence_content,
sentences.start_offset as sentence_start_offset,
sentences.end_offset as sentence_end_offset
FROM {vector_scan_type}('sentences', 'embedding', :query_embedding) AS v
JOIN sentences ON sentences.rowid = v.rowid
WHERE sentences.chunk_id = :chunk_id
ORDER BY rank_number ASC
LIMIT :top_k
)
SELECT
sentence_id,
sentence_content,
sentence_start_offset,
sentence_end_offset,
rank_number,
distance
FROM vec_matches
""", # nosec B608
{
"query_embedding": query_embedding,
"top_k": top_k,
"chunk_id": chunk_id,
},
)

rows = cursor.fetchall()
sentences = []
for row in rows:
sentences.append(
SentenceResult(
id=row["sentence_id"],
chunk_id=chunk_id,
rank=row["rank_number"],
distance=row["distance"],
start_offset=row["sentence_start_offset"],
end_offset=row["sentence_end_offset"],
)
)

return sentences[:top_k]

def versions(self) -> dict:
"""Get versions of the loaded extensions."""
cursor = self._conn.cursor()
Expand Down
Loading
Loading