diff --git a/python/pyproject.toml b/python/pyproject.toml index e999808..3395d25 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -37,6 +37,8 @@ requires = ["maturin>=1.4"] build-backend = "maturin" [project.optional-dependencies] +openai = ["openai>=1.0"] +sentence-transformers = ["sentence-transformers>=2.0"] graph = ["lance-graph==0.5.4"] lance-python = [ "pylance>=7.0.0,<8", diff --git a/python/python/lance_context/__init__.py b/python/python/lance_context/__init__.py index a73a265..b6a356c 100644 --- a/python/python/lance_context/__init__.py +++ b/python/python/lance_context/__init__.py @@ -3,7 +3,8 @@ from .api import ( # pyright: ignore[reportMissingImports] AsyncContext, Context, + EmbeddingProvider, __version__, ) -__all__ = ["AsyncContext", "Context", "__version__"] +__all__ = ["AsyncContext", "Context", "EmbeddingProvider", "__version__"] diff --git a/python/python/lance_context/api.py b/python/python/lance_context/api.py index 957d6a5..7f0e1c0 100644 --- a/python/python/lance_context/api.py +++ b/python/python/lance_context/api.py @@ -10,8 +10,9 @@ from ._internal import Context as _Context # pyright: ignore[reportMissingImports] from ._internal import version as _version # pyright: ignore[reportMissingImports] +from .embeddings import EmbeddingProvider, _build_provider -__all__ = ["AsyncContext", "Context", "__version__"] +__all__ = ["AsyncContext", "Context", "EmbeddingProvider", "__version__"] __version__ = _version() @@ -313,6 +314,8 @@ def __init__( id_index_type: str | None = None, embedding_dim: int | None = None, distance_metric: str | None = None, + # --- Embedding provider. 
def _auto_embed_batch(self, records: list[dict[str, Any]]) -> None:
    """Fill in missing embeddings for text records with one provider call.

    Records are mutated in place. A record is embedded only when its
    ``"embedding"`` is missing/``None`` and its ``"content"`` is a ``str``;
    records with a caller-supplied embedding or non-text content are left
    untouched. Does nothing when no embedding provider is configured.

    Args:
        records: Normalized record dicts about to be handed to the inner
            context's ``add_many``.

    Raises:
        ValueError: If the provider does not return exactly one vector per
            text it was given.
    """
    # getattr: instances built via ``__new__`` may never have had the
    # attribute assigned.
    provider = getattr(self, "_embedding_provider", None)
    if provider is None:
        return

    pending = [
        record
        for record in records
        if record.get("embedding") is None
        and isinstance(record.get("content"), str)
    ]
    if not pending:
        return

    # Materialize so generator-returning providers can be length-checked.
    vectors = list(provider.embed_texts([record["content"] for record in pending]))

    # zip() would silently drop the tail on a short result, storing some
    # records without an embedding; fail loudly instead.
    if len(vectors) != len(pending):
        raise ValueError(
            f"embedding provider returned {len(vectors)} vectors "
            f"for {len(pending)} texts"
        )

    for record, vector in zip(pending, vectors):
        record["embedding"] = vector
def _build_provider(config: dict[str, Any]) -> EmbeddingProvider:
    """Instantiate a built-in provider from a config dict.

    Config keys:
        provider: "openai" | "sentence-transformers"
        model: model name / ID (optional, uses provider default)
        **kwargs: forwarded to the provider constructor

    Raises:
        ValueError: If ``provider`` is missing/empty, or is not the name of
            a registered provider.
    """
    name = config.get("provider")
    if not name:
        raise ValueError("embedding config must include a 'provider' key")
    # Only look up strings: an unhashable value (e.g. a list) would otherwise
    # escape as a TypeError rather than the documented ValueError.
    provider_cls = _REGISTRY.get(name) if isinstance(name, str) else None
    if provider_cls is None:
        raise ValueError(
            f"Unknown embedding provider {name!r}. Available: {sorted(_REGISTRY)}"
        )
    kwargs = {k: v for k, v in config.items() if k != "provider"}
    return provider_cls(**kwargs)


class OpenAIProvider:
    """Embedding provider backed by the OpenAI embeddings API.

    Requires ``pip install lance-context[openai]``.
    """

    # Largest number of inputs sent in one embeddings request; bigger batches
    # are split into several requests and the results concatenated in order.
    _MAX_INPUTS_PER_REQUEST = 2048

    def __init__(
        self, model: str = "text-embedding-3-small", **client_kwargs: Any
    ) -> None:
        """Create the OpenAI client.

        Args:
            model: Embedding model name passed on every request.
            **client_kwargs: Forwarded unchanged to ``openai.OpenAI``.

        Raises:
            ImportError: If the ``openai`` package is not installed.
        """
        # Imported lazily so the package stays usable without the extra.
        try:
            from openai import OpenAI  # pyright: ignore[reportMissingImports]
        except ImportError as exc:
            raise ImportError(
                "openai is required for the OpenAI embedding provider. "
                "Install it with: pip install lance-context[openai]"
            ) from exc
        self._client = OpenAI(**client_kwargs)
        self._model = model
        self._dims: int | None = None

    @property
    def dims(self) -> int:
        """Vector dimensionality, cached after the first successful lookup."""
        if self._dims is None:
            # Probe dims with a dummy call the first time they're requested.
            result = self._client.embeddings.create(input=["probe"], model=self._model)
            self._dims = len(result.data[0].embedding)
        return self._dims

    def embed_texts(self, texts: list[str]) -> list[list[float]]:
        """Return one embedding vector per input text, in input order.

        An empty input returns ``[]`` without contacting the API. Inputs
        longer than ``_MAX_INPUTS_PER_REQUEST`` are sent in several requests.
        """
        items = list(texts)
        if not items:
            return []

        step = self._MAX_INPUTS_PER_REQUEST
        vecs: list[list[float]] = []
        for start in range(0, len(items), step):
            result = self._client.embeddings.create(
                input=items[start : start + step], model=self._model
            )
            # Order by the per-request index rather than trusting list order.
            ordered = sorted(result.data, key=lambda d: d.index)
            vecs.extend(item.embedding for item in ordered)

        if self._dims is None and vecs:
            self._dims = len(vecs[0])
        return vecs


class SentenceTransformersProvider:
    """Embedding provider backed by sentence-transformers (local / offline).

    Requires ``pip install lance-context[sentence-transformers]``.
    """

    def __init__(self, model: str = "all-MiniLM-L6-v2", **model_kwargs: Any) -> None:
        """Load the model.

        Args:
            model: Model name or path passed to ``SentenceTransformer``.
            **model_kwargs: Forwarded unchanged to ``SentenceTransformer``.

        Raises:
            ImportError: If ``sentence-transformers`` is not installed.
        """
        # Imported lazily so the package stays usable without the extra.
        try:
            from sentence_transformers import (  # pyright: ignore[reportMissingImports]
                SentenceTransformer,
            )
        except ImportError as exc:
            raise ImportError(
                "sentence-transformers is required for the "
                "SentenceTransformers provider. "
                "Install it with: pip install lance-context[sentence-transformers]"
            ) from exc
        self._model = SentenceTransformer(model, **model_kwargs)

    @property
    def dims(self) -> int:
        """Vector dimensionality as reported by the loaded model."""
        return int(self._model.get_sentence_embedding_dimension())

    def embed_texts(self, texts: list[str]) -> list[list[float]]:
        """Encode ``texts`` and return one plain-list vector per text."""
        vectors = self._model.encode(texts, convert_to_numpy=True)
        return [v.tolist() for v in vectors]


# Maps the ``provider`` config value to the class that implements it.
_REGISTRY: dict[str, type] = {
    "openai": OpenAIProvider,
    "sentence-transformers": SentenceTransformersProvider,
}
# ---------------------------------------------------------------------------
# Stub provider
# ---------------------------------------------------------------------------


class StubProvider:
    """Deterministic fake that returns [index, 0.0] per text."""

    def __init__(self) -> None:
        # Every embed_texts() call is recorded here as a copy of its input,
        # so tests can assert how the provider was invoked.
        self.calls: list[list[str]] = []

    @property
    def dims(self) -> int:
        """Fixed dimensionality matching the two-element vectors below."""
        return 2

    def embed_texts(self, texts: list[str]) -> list[list[float]]:
        """Record the call and return ``[position, 0.0]`` for each text."""
        self.calls.append(list(texts))
        return [[float(i), 0.0] for i in range(len(texts))]


def _ctx_with_provider(provider: StubProvider) -> Context:
    """Build a Context without running ``__init__``, wired to stubs."""
    ctx = Context.__new__(Context)
    ctx._inner = _DummyInner()  # type: ignore[attr-defined]
    ctx._embedding_provider = provider
    return ctx


def _ctx_no_provider() -> Context:
    """Build a Context without running ``__init__`` and with no provider."""
    ctx = Context.__new__(Context)
    ctx._inner = _DummyInner()  # type: ignore[attr-defined]
    ctx._embedding_provider = None
    return ctx


class _DummyInner:
    """Minimal inner stub that records calls."""

    def __init__(self) -> None:
        # One list per recorded method; tests inspect these after the call.
        self.add_calls: list[dict[str, Any]] = []
        self.add_many_calls: list[list[dict[str, Any]]] = []
        self.search_calls: list[tuple[Any, ...]] = []
        self.upsert_calls: list[dict[str, Any]] = []

    def add(  # noqa: PLR0913
        self,
        role: str,
        content: Any,
        data_type: Any,
        embedding: Any,
        bot_id: Any,
        session_id: Any,
        external_id: Any,
        metadata_json: Any,
        expires_at: Any = None,
        retention_policy: Any = None,
        lifecycle_status: Any = None,
        retired_at: Any = None,
        retired_reason: Any = None,
        supersedes_id: Any = None,
        superseded_by_id: Any = None,
        relationships_json: Any = None,
    ) -> None:
        """Record only the role, content and embedding of the call."""
        self.add_calls.append(
            {"role": role, "content": content, "embedding": embedding}
        )

    def add_many(self, records: list[dict[str, Any]]) -> None:
        """Record the list of records exactly as received."""
        self.add_many_calls.append(records)

    def search(
        self,
        vector: list[float],
        limit: Any,
        filters_json: Any,
        include_expired: bool = False,
        include_retired: bool = False,
        include_relationships: bool = False,
    ) -> list[Any]:
        """Record all arguments as a tuple and return no results."""
        self.search_calls.append(
            (
                vector,
                limit,
                filters_json,
                include_expired,
                include_retired,
                include_relationships,
            )
        )
        return []

    def upsert(  # noqa: PLR0913
        self,
        role: str,
        content: Any,
        data_type: Any,
        embedding: Any,
        bot_id: Any,
        session_id: Any,
        external_id: Any,
        metadata_json: Any,
        expires_at: Any = None,
        retention_policy: Any = None,
        lifecycle_status: Any = None,
        retired_at: Any = None,
        retired_reason: Any = None,
        relationships_json: Any = None,
        key: str = "external_id",
    ) -> dict[str, Any]:
        """Record role/content/embedding and return a canned result dict."""
        self.upsert_calls.append(
            {"role": role, "content": content, "embedding": embedding}
        )
        # NOTE(review): presumably shaped like the real inner upsert result so
        # Context.upsert can post-process it -- confirm against api.py.
        return {
            "inserted": True,
            "replaced_id": None,
            "version": 1,
            "record": {
                "id": "x",
                "external_id": external_id,
                "run_id": "r",
                "bot_id": None,
                "session_id": None,
                "role": role,
                "content_type": data_type,
                "text_payload": content,
                "binary_payload": None,
                "embedding": embedding,
                "created_at": "2024-01-01T00:00:00Z",
                "state_metadata": None,
                "metadata": None,
                "relationships": [],
                "expires_at": None,
                "retention_policy": None,
                "lifecycle_status": "active",
                "retired_at": None,
                "retired_reason": None,
                "supersedes_id": None,
                "superseded_by_id": None,
            },
        }


# ---------------------------------------------------------------------------
# Protocol conformance
# ---------------------------------------------------------------------------


def test_stub_satisfies_protocol():
    """The stub passes the runtime isinstance check for EmbeddingProvider."""
    assert isinstance(StubProvider(), EmbeddingProvider)


# ---------------------------------------------------------------------------
# add()
# ---------------------------------------------------------------------------


def test_add_auto_embeds_text():
    """Text added without an embedding is embedded via the provider."""
    provider = StubProvider()
    ctx = _ctx_with_provider(provider)
    ctx.add("user", "hello world")
    assert provider.calls == [["hello world"]]
    assert ctx._inner.add_calls[0]["embedding"] == [0.0, 0.0]  # type: ignore[attr-defined]


def test_add_manual_embedding_takes_precedence():
    """A caller-supplied embedding is passed through; provider is not called."""
    provider = StubProvider()
    ctx = _ctx_with_provider(provider)
    ctx.add("user", "hello", embedding=[9.0, 9.0])
    assert provider.calls == []
    assert ctx._inner.add_calls[0]["embedding"] == [9.0, 9.0]  # type: ignore[attr-defined]


def test_add_no_provider_leaves_embedding_none():
    """Without a provider the embedding reaches the inner context as None."""
    ctx = _ctx_no_provider()
    ctx.add("user", "hello")
    assert ctx._inner.add_calls[0]["embedding"] is None  # type: ignore[attr-defined]


def test_add_binary_content_not_auto_embedded():
    """Bytes content is not sent to the provider."""
    provider = StubProvider()
    ctx = _ctx_with_provider(provider)
    ctx.add("user", b"\x00\x01\x02", content_type="application/octet-stream")
    assert provider.calls == []
    assert ctx._inner.add_calls[0]["embedding"] is None  # type: ignore[attr-defined]


# ---------------------------------------------------------------------------
# add_many()
# ---------------------------------------------------------------------------


def test_add_many_batch_embeds_text_records():
    """All text records are embedded in a single provider call."""
    provider = StubProvider()
    ctx = _ctx_with_provider(provider)
    ctx.add_many(
        [
            {"role": "user", "content": "first"},
            {"role": "assistant", "content": "second"},
        ]
    )
    # Both texts sent in one provider call.
    assert provider.calls == [["first", "second"]]
    records = ctx._inner.add_many_calls[0]  # type: ignore[attr-defined]
    assert records[0]["embedding"] == [0.0, 0.0]
    assert records[1]["embedding"] == [1.0, 0.0]


def test_add_many_skips_records_with_manual_embedding():
    """Records that already carry an embedding are left out of the batch."""
    provider = StubProvider()
    ctx = _ctx_with_provider(provider)
    ctx.add_many(
        [
            {"role": "user", "content": "first", "embedding": [5.0, 5.0]},
            {"role": "assistant", "content": "second"},
        ]
    )
    # Only the second record is sent for embedding.
    assert provider.calls == [["second"]]
    records = ctx._inner.add_many_calls[0]  # type: ignore[attr-defined]
    assert records[0]["embedding"] == [5.0, 5.0]
    assert records[1]["embedding"] == [0.0, 0.0]


def test_add_many_no_provider_leaves_embeddings_unchanged():
    """Without a provider the record's embedding stays None."""
    ctx = _ctx_no_provider()
    ctx.add_many([{"role": "user", "content": "hello"}])
    records = ctx._inner.add_many_calls[0]  # type: ignore[attr-defined]
    assert records[0]["embedding"] is None


# ---------------------------------------------------------------------------
# search()
# ---------------------------------------------------------------------------


def test_search_auto_embeds_string_query():
    """A string query is embedded and the vector is what gets searched."""
    provider = StubProvider()
    ctx = _ctx_with_provider(provider)
    ctx.search("spring travel")
    assert provider.calls == [["spring travel"]]
    vector_passed = ctx._inner.search_calls[0][0]  # type: ignore[attr-defined]
    assert vector_passed == [0.0, 0.0]


def test_search_vector_query_bypasses_provider():
    """A vector query is forwarded as-is; the provider is not called."""
    provider = StubProvider()
    ctx = _ctx_with_provider(provider)
    ctx.search([0.1, 0.2])
    assert provider.calls == []
    vector_passed = ctx._inner.search_calls[0][0]  # type: ignore[attr-defined]
    assert vector_passed == [0.1, 0.2]


def test_search_string_query_no_provider_raises():
    """A string query with no provider configured raises TypeError."""
    ctx = _ctx_no_provider()
    with pytest.raises(TypeError):
        ctx.search("spring travel")


# ---------------------------------------------------------------------------
# upsert()
# ---------------------------------------------------------------------------


def test_upsert_auto_embeds_text():
    """Text upserted without an embedding is embedded via the provider."""
    provider = StubProvider()
    ctx = _ctx_with_provider(provider)
    ctx.upsert("user", "updated content", external_id="doc-1")
    assert provider.calls == [["updated content"]]
    assert ctx._inner.upsert_calls[0]["embedding"] == [0.0, 0.0]  # type: ignore[attr-defined]


def test_upsert_manual_embedding_takes_precedence():
    """A caller-supplied embedding is passed through; provider is not called."""
    provider = StubProvider()
    ctx = _ctx_with_provider(provider)
    ctx.upsert("user", "updated content", external_id="doc-1", embedding=[7.0, 7.0])
    assert provider.calls == []
    assert ctx._inner.upsert_calls[0]["embedding"] == [7.0, 7.0]  # type: ignore[attr-defined]


# ---------------------------------------------------------------------------
# _build_provider registry
# ---------------------------------------------------------------------------


def test_build_provider_unknown_raises():
    """An unregistered provider name raises ValueError."""
    with pytest.raises(ValueError, match="Unknown embedding provider"):
        _build_provider({"provider": "does-not-exist"})


def test_build_provider_missing_key_raises():
    """A config without a 'provider' key raises ValueError."""
    with pytest.raises(ValueError, match="'provider' key"):
        _build_provider({})