From 0a81b34bd47707478e867502194c26afc804ab68 Mon Sep 17 00:00:00 2001 From: dcfocus Date: Fri, 12 Jun 2026 15:48:22 +0000 Subject: [PATCH 1/2] feat: pluggable embedding provider registry (#85) Add an EmbeddingProvider protocol and registry so Context auto-embeds text at write time and string queries at search time, eliminating the need for each caller to maintain their own embedding pipeline. Built-in providers for OpenAI and sentence-transformers ship as optional extras (lance-context[openai] / lance-context[sentence-transformers]); the registry accepts any object satisfying the EmbeddingProvider protocol for custom backends. - EmbeddingProvider: runtime-checkable Protocol (dims, embed_texts) - Context.create/AsyncContext.create: new embedding_provider kwarg (instance or {"provider": "openai", "model": ...} dict) - add() / upsert(): auto-embed text payloads when no manual embedding given; manual embedding= always takes precedence - add_many(): batch-embeds all uneembedded text records in one call - search(): accepts a plain string query and auto-embeds it via the provider; existing vector queries are unaffected - fork() propagates the provider to the child context - EmbeddingProvider exported from the top-level package Closes #85 Co-Authored-By: Claude Sonnet 4.6 --- python/pyproject.toml | 2 + python/python/lance_context/__init__.py | 3 +- python/python/lance_context/api.py | 48 +++- python/python/lance_context/embeddings.py | 116 +++++++++ python/tests/test_embeddings.py | 294 ++++++++++++++++++++++ 5 files changed, 459 insertions(+), 4 deletions(-) create mode 100644 python/python/lance_context/embeddings.py create mode 100644 python/tests/test_embeddings.py diff --git a/python/pyproject.toml b/python/pyproject.toml index e999808..3395d25 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -37,6 +37,8 @@ requires = ["maturin>=1.4"] build-backend = "maturin" [project.optional-dependencies] +openai = ["openai>=1.0"] +sentence-transformers = ["sentence-transformers>=2.0"] graph = ["lance-graph==0.5.4"] lance-python = [ "pylance>=7.0.0,<8", diff --git a/python/python/lance_context/__init__.py b/python/python/lance_context/__init__.py index a73a265..b6a356c 100644 --- a/python/python/lance_context/__init__.py +++ b/python/python/lance_context/__init__.py @@ -3,7 +3,8 @@ from .api import ( # pyright: ignore[reportMissingImports] AsyncContext, Context, + EmbeddingProvider, __version__, ) -__all__ = ["AsyncContext", "Context", "__version__"] +__all__ = ["AsyncContext", "Context", "EmbeddingProvider", "__version__"] diff --git a/python/python/lance_context/api.py b/python/python/lance_context/api.py index 957d6a5..7f0e1c0 100644 --- a/python/python/lance_context/api.py +++ b/python/python/lance_context/api.py @@ -10,8 +10,9 @@ from ._internal import Context as _Context # pyright: ignore[reportMissingImports] from ._internal import version as _version # pyright: ignore[reportMissingImports] +from .embeddings import EmbeddingProvider, _build_provider -__all__ = ["AsyncContext", "Context", "__version__"] +__all__ = ["AsyncContext", "Context", "EmbeddingProvider", "__version__"] __version__ = _version() @@ -313,6 +314,8 @@ def __init__( id_index_type: str | None = None, embedding_dim: int | None = None, distance_metric: str | None = None, + # --- Embedding provider. --- + embedding_provider: EmbeddingProvider | dict[str, Any] | None = None, ) -> None: options = _merge_storage_options( storage_options, @@ -350,6 +353,13 @@ def __init__( else: self._inner = _Context.create(uri) + if isinstance(embedding_provider, dict): + self._embedding_provider: EmbeddingProvider | None = _build_provider( + embedding_provider + ) + else: + self._embedding_provider = embedding_provider + @classmethod def create( cls, @@ -370,6 +380,7 @@ def create( id_index_type: str | None = None, embedding_dim: int | None = None, distance_metric: str | None = None, + embedding_provider: EmbeddingProvider | dict[str, Any] | None = None, ) -> Context: return cls( uri, @@ -388,6 +399,7 @@ def create( id_index_type=id_index_type, embedding_dim=embedding_dim, distance_metric=distance_metric, + embedding_provider=embedding_provider, ) def uri(self) -> str: @@ -427,6 +439,9 @@ def add( if content_type is None: content_type = data_type payload, resolved_type = _normalize_content(content, content_type) + provider = getattr(self, "_embedding_provider", None) + if embedding is None and provider is not None and isinstance(payload, str): + embedding = provider.embed_texts([payload])[0] self._inner.add( role, payload, @@ -477,6 +492,9 @@ def upsert( content_type = data_type payload, resolved_type = _normalize_content(content, content_type) + provider = getattr(self, "_embedding_provider", None) + if embedding is None and provider is not None and isinstance(payload, str): + embedding = provider.embed_texts([payload])[0] result = self._inner.upsert( role, payload, @@ -610,14 +628,32 @@ def add_many(self, records: Iterable[Mapping[str, Any]]) -> None: } ) + self._auto_embed_batch(normalized) self._inner.add_many(normalized) + def _auto_embed_batch(self, records: list[dict[str, Any]]) -> None: + """Embed text records without an embedding in one provider call.""" + provider = getattr(self, "_embedding_provider", None) + if provider is None: + return + indices = [ + i + for i, r in enumerate(records) + if r.get("embedding") is None and isinstance(r.get("content"), str) + ] + if not indices: + return + texts = [records[i]["content"] for i in indices] + vectors = provider.embed_texts(texts) + for i, vec in zip(indices, vectors): + records[i]["embedding"] = vec + def snapshot(self, label: str | None = None) -> str: return self._inner.snapshot(label) def fork(self, branch_name: str) -> Context: inner = self._inner.fork(branch_name) - return self._from_inner(inner) + return self._from_inner(inner, self._embedding_provider) def checkout(self, version_id: int | str) -> None: self._inner.checkout(int(version_id)) @@ -632,6 +668,9 @@ def search( include_retired: bool = False, include_relationships: bool = False, ) -> list[dict[str, Any]]: + provider = getattr(self, "_embedding_provider", None) + if isinstance(query, str) and provider is not None: + query = provider.embed_texts([query])[0] vector = _coerce_vector(query) results = self._inner.search( vector, @@ -815,9 +854,12 @@ def __repr__(self) -> str: ) @classmethod - def _from_inner(cls, inner: _Context) -> Context: + def _from_inner( + cls, inner: _Context, embedding_provider: EmbeddingProvider | None = None + ) -> Context: obj = cls.__new__(cls) obj._inner = inner + obj._embedding_provider = embedding_provider return obj diff --git a/python/python/lance_context/embeddings.py b/python/python/lance_context/embeddings.py new file mode 100644 index 0000000..53feed8 --- /dev/null +++ b/python/python/lance_context/embeddings.py @@ -0,0 +1,116 @@ +from __future__ import annotations + +from typing import Any, Protocol, runtime_checkable + +__all__ = ["EmbeddingProvider", "OpenAIProvider", "SentenceTransformersProvider"] + + +@runtime_checkable +class EmbeddingProvider(Protocol): + """Minimal interface for pluggable embedding providers. + + Implement this protocol to plug in any embedding backend. Providers + receive lists so :meth:`Context.add_many` can embed in a single call. + """ + + @property + def dims(self) -> int: + """Dimensionality of the vectors produced by this provider.""" + ... + + def embed_texts(self, texts: list[str]) -> list[list[float]]: + """Return one embedding vector per input text.""" + ... + + +def _build_provider(config: dict[str, Any]) -> EmbeddingProvider: + """Instantiate a built-in provider from a config dict. + + Config keys: + provider: "openai" | "sentence-transformers" + model: model name / ID (optional, uses provider default) + **kwargs: forwarded to the provider constructor + """ + name = config.get("provider") + if not name: + raise ValueError("embedding config must include a 'provider' key") + kwargs = {k: v for k, v in config.items() if k != "provider"} + try: + cls = _REGISTRY[name] + except KeyError: + raise ValueError( + f"Unknown embedding provider {name!r}. " + f"Available: {sorted(_REGISTRY)}" + ) from None + return cls(**kwargs) + + +class OpenAIProvider: + """Embedding provider backed by the OpenAI embeddings API. + + Requires ``pip install lance-context[openai]``. + """ + + def __init__( + self, model: str = "text-embedding-3-small", **client_kwargs: Any + ) -> None: + try: + from openai import OpenAI # pyright: ignore[reportMissingImports] + except ImportError as exc: + raise ImportError( + "openai is required for the OpenAI embedding provider. " + "Install it with: pip install lance-context[openai]" + ) from exc + self._client = OpenAI(**client_kwargs) + self._model = model + self._dims: int | None = None + + @property + def dims(self) -> int: + if self._dims is None: + # Probe dims with a dummy call the first time they're requested. + result = self._client.embeddings.create(input=["probe"], model=self._model) + self._dims = len(result.data[0].embedding) + return self._dims + + def embed_texts(self, texts: list[str]) -> list[list[float]]: + result = self._client.embeddings.create(input=texts, model=self._model) + ordered = sorted(result.data, key=lambda d: d.index) + vecs = [item.embedding for item in ordered] + if self._dims is None and vecs: + self._dims = len(vecs[0]) + return vecs + + +class SentenceTransformersProvider: + """Embedding provider backed by sentence-transformers (local / offline). + + Requires ``pip install lance-context[sentence-transformers]``. + """ + + def __init__(self, model: str = "all-MiniLM-L6-v2", **model_kwargs: Any) -> None: + try: + from sentence_transformers import ( # pyright: ignore[reportMissingImports] + SentenceTransformer, + ) + except ImportError as exc: + raise ImportError( + "sentence-transformers is required for the " + "SentenceTransformers provider. " + "Install it with: pip install lance-context[sentence-transformers]" + ) from exc + self._model = SentenceTransformer(model, **model_kwargs) + + @property + def dims(self) -> int: + return int(self._model.get_sentence_embedding_dimension()) + + def embed_texts(self, texts: list[str]) -> list[list[float]]: + vectors = self._model.encode(texts, convert_to_numpy=True) + return [v.tolist() for v in vectors] + + +_REGISTRY: dict[str, type] = { + "openai": OpenAIProvider, + "sentence-transformers": SentenceTransformersProvider, +} diff --git a/python/tests/test_embeddings.py b/python/tests/test_embeddings.py new file mode 100644 index 0000000..f32402e --- /dev/null +++ b/python/tests/test_embeddings.py @@ -0,0 +1,294 @@ +"""Tests for the pluggable embedding provider registry. + +Uses a stub provider so no external dependencies are needed. +""" +from __future__ import annotations + +from typing import Any + +import pytest +from lance_context.api import Context +from lance_context.embeddings import EmbeddingProvider, _build_provider + +# --------------------------------------------------------------------------- +# Stub provider +# --------------------------------------------------------------------------- + + +class StubProvider: + """Deterministic fake that returns [index, 0.0] per text.""" + + def __init__(self) -> None: + self.calls: list[list[str]] = [] + + @property + def dims(self) -> int: + return 2 + + def embed_texts(self, texts: list[str]) -> list[list[float]]: + self.calls.append(list(texts)) + return [[float(i), 0.0] for i in range(len(texts))] + + +def _ctx_with_provider(provider: StubProvider) -> Context: + ctx = Context.__new__(Context) + ctx._inner = _DummyInner() # type: ignore[attr-defined] + ctx._embedding_provider = provider + return ctx + + +def _ctx_no_provider() -> Context: + ctx = Context.__new__(Context) + ctx._inner = _DummyInner() # type: ignore[attr-defined] + ctx._embedding_provider = None + return ctx + + +class _DummyInner: + """Minimal inner stub that records calls.""" + + def __init__(self) -> None: + self.add_calls: list[dict[str, Any]] = [] + self.add_many_calls: list[list[dict[str, Any]]] = [] + self.search_calls: list[tuple[Any, ...]] = [] + self.upsert_calls: list[dict[str, Any]] = [] + + def add( # noqa: PLR0913 + self, + role: str, + content: Any, + data_type: Any, + embedding: Any, + bot_id: Any, + session_id: Any, + external_id: Any, + metadata_json: Any, + expires_at: Any = None, + retention_policy: Any = None, + lifecycle_status: Any = None, + retired_at: Any = None, + retired_reason: Any = None, + supersedes_id: Any = None, + superseded_by_id: Any = None, + relationships_json: Any = None, + ) -> None: + self.add_calls.append( + {"role": role, "content": content, "embedding": embedding} + ) + + def add_many(self, records: list[dict[str, Any]]) -> None: + self.add_many_calls.append(records) + + def search( + self, + vector: list[float], + limit: Any, + filters_json: Any, + include_expired: bool = False, + include_retired: bool = False, + include_relationships: bool = False, + ) -> list[Any]: + self.search_calls.append( + (vector, limit, filters_json, + include_expired, include_retired, include_relationships) + ) + return [] + + def upsert( # noqa: PLR0913 + self, + role: str, + content: Any, + data_type: Any, + embedding: Any, + bot_id: Any, + session_id: Any, + external_id: Any, + metadata_json: Any, + expires_at: Any = None, + retention_policy: Any = None, + lifecycle_status: Any = None, + retired_at: Any = None, + retired_reason: Any = None, + relationships_json: Any = None, + key: str = "external_id", + ) -> dict[str, Any]: + self.upsert_calls.append( + {"role": role, "content": content, "embedding": embedding} + ) + return { + "inserted": True, + "replaced_id": None, + "version": 1, + "record": { + "id": "x", + "external_id": external_id, + "run_id": "r", + "bot_id": None, + "session_id": None, + "role": role, + "content_type": data_type, + "text_payload": content, + "binary_payload": None, + "embedding": embedding, + "created_at": "2024-01-01T00:00:00Z", + "state_metadata": None, + "metadata": None, + "relationships": [], + "expires_at": None, + "retention_policy": None, + "lifecycle_status": "active", + "retired_at": None, + "retired_reason": None, + "supersedes_id": None, + "superseded_by_id": None, + }, + } + + +# --------------------------------------------------------------------------- +# Protocol conformance +# --------------------------------------------------------------------------- + + +def test_stub_satisfies_protocol(): + assert isinstance(StubProvider(), EmbeddingProvider) + + +# --------------------------------------------------------------------------- +# add() +# --------------------------------------------------------------------------- + + +def test_add_auto_embeds_text(): + provider = StubProvider() + ctx = _ctx_with_provider(provider) + ctx.add("user", "hello world") + assert provider.calls == [["hello world"]] + assert ctx._inner.add_calls[0]["embedding"] == [0.0, 0.0] # type: ignore[attr-defined] + + +def test_add_manual_embedding_takes_precedence(): + provider = StubProvider() + ctx = _ctx_with_provider(provider) + ctx.add("user", "hello", embedding=[9.0, 9.0]) + assert provider.calls == [] + assert ctx._inner.add_calls[0]["embedding"] == [9.0, 9.0] # type: ignore[attr-defined] + + +def test_add_no_provider_leaves_embedding_none(): + ctx = _ctx_no_provider() + ctx.add("user", "hello") + assert ctx._inner.add_calls[0]["embedding"] is None # type: ignore[attr-defined] + + +def test_add_binary_content_not_auto_embedded(): + provider = StubProvider() + ctx = _ctx_with_provider(provider) + ctx.add("user", b"\x00\x01\x02", content_type="application/octet-stream") + assert provider.calls == [] + assert ctx._inner.add_calls[0]["embedding"] is None # type: ignore[attr-defined] + + +# --------------------------------------------------------------------------- +# add_many() +# --------------------------------------------------------------------------- + + +def test_add_many_batch_embeds_text_records(): + provider = StubProvider() + ctx = _ctx_with_provider(provider) + ctx.add_many([ + {"role": "user", "content": "first"}, + {"role": "assistant", "content": "second"}, + ]) + # Both texts sent in one provider call. + assert provider.calls == [["first", "second"]] + records = ctx._inner.add_many_calls[0] # type: ignore[attr-defined] + assert records[0]["embedding"] == [0.0, 0.0] + assert records[1]["embedding"] == [1.0, 0.0] + + +def test_add_many_skips_records_with_manual_embedding(): + provider = StubProvider() + ctx = _ctx_with_provider(provider) + ctx.add_many([ + {"role": "user", "content": "first", "embedding": [5.0, 5.0]}, + {"role": "assistant", "content": "second"}, + ]) + # Only the second record is sent for embedding. + assert provider.calls == [["second"]] + records = ctx._inner.add_many_calls[0] # type: ignore[attr-defined] + assert records[0]["embedding"] == [5.0, 5.0] + assert records[1]["embedding"] == [0.0, 0.0] + + +def test_add_many_no_provider_leaves_embeddings_unchanged(): + ctx = _ctx_no_provider() + ctx.add_many([{"role": "user", "content": "hello"}]) + records = ctx._inner.add_many_calls[0] # type: ignore[attr-defined] + assert records[0]["embedding"] is None + + +# --------------------------------------------------------------------------- +# search() +# --------------------------------------------------------------------------- + + +def test_search_auto_embeds_string_query(): + provider = StubProvider() + ctx = _ctx_with_provider(provider) + ctx.search("spring travel") + assert provider.calls == [["spring travel"]] + vector_passed = ctx._inner.search_calls[0][0] # type: ignore[attr-defined] + assert vector_passed == [0.0, 0.0] + + +def test_search_vector_query_bypasses_provider(): + provider = StubProvider() + ctx = _ctx_with_provider(provider) + ctx.search([0.1, 0.2]) + assert provider.calls == [] + vector_passed = ctx._inner.search_calls[0][0] # type: ignore[attr-defined] + assert vector_passed == [0.1, 0.2] + + +def test_search_string_query_no_provider_raises(): + ctx = _ctx_no_provider() + with pytest.raises(TypeError): + ctx.search("spring travel") + + +# --------------------------------------------------------------------------- +# upsert() +# --------------------------------------------------------------------------- + + +def test_upsert_auto_embeds_text(): + provider = StubProvider() + ctx = _ctx_with_provider(provider) + ctx.upsert("user", "updated content", external_id="doc-1") + assert provider.calls == [["updated content"]] + assert ctx._inner.upsert_calls[0]["embedding"] == [0.0, 0.0] # type: ignore[attr-defined] + + +def test_upsert_manual_embedding_takes_precedence(): + provider = StubProvider() + ctx = _ctx_with_provider(provider) + ctx.upsert("user", "updated content", external_id="doc-1", embedding=[7.0, 7.0]) + assert provider.calls == [] + assert ctx._inner.upsert_calls[0]["embedding"] == [7.0, 7.0] # type: ignore[attr-defined] + + +# --------------------------------------------------------------------------- +# _build_provider registry +# --------------------------------------------------------------------------- + + +def test_build_provider_unknown_raises(): + with pytest.raises(ValueError, match="Unknown embedding provider"): + _build_provider({"provider": "does-not-exist"}) + + +def test_build_provider_missing_key_raises(): + with pytest.raises(ValueError, match="'provider' key"): + _build_provider({}) From 175ee4eb9064d345fb85699339c6dd40c73954ba Mon Sep 17 00:00:00 2001 From: Allen Cheng Date: Fri, 12 Jun 2026 16:29:21 -0700 Subject: [PATCH 2/2] style: format embedding provider Python files --- python/python/lance_context/embeddings.py | 3 +-- python/tests/test_embeddings.py | 31 +++++++++++++++-------- 2 files changed, 22 insertions(+), 12 deletions(-) diff --git a/python/python/lance_context/embeddings.py b/python/python/lance_context/embeddings.py index 53feed8..8dc2321 100644 --- a/python/python/lance_context/embeddings.py +++ b/python/python/lance_context/embeddings.py @@ -39,8 +39,7 @@ def _build_provider(config: dict[str, Any]) -> EmbeddingProvider: cls = _REGISTRY[name] except KeyError: raise ValueError( - f"Unknown embedding provider {name!r}. " - f"Available: {sorted(_REGISTRY)}" + f"Unknown embedding provider {name!r}. Available: {sorted(_REGISTRY)}" ) from None return cls(**kwargs) diff --git a/python/tests/test_embeddings.py b/python/tests/test_embeddings.py index f32402e..2d35711 100644 --- a/python/tests/test_embeddings.py +++ b/python/tests/test_embeddings.py @@ -2,6 +2,7 @@ Uses a stub provider so no external dependencies are needed. """ + from __future__ import annotations from typing import Any @@ -89,8 +90,14 @@ def search( include_relationships: bool = False, ) -> list[Any]: self.search_calls.append( - (vector, limit, filters_json, - include_expired, include_retired, include_relationships) + ( + vector, + limit, + filters_json, + include_expired, + include_retired, + include_relationships, + ) ) return [] @@ -197,10 +204,12 @@ def test_add_binary_content_not_auto_embedded(): def test_add_many_batch_embeds_text_records(): provider = StubProvider() ctx = _ctx_with_provider(provider) - ctx.add_many([ - {"role": "user", "content": "first"}, - {"role": "assistant", "content": "second"}, - ]) + ctx.add_many( + [ + {"role": "user", "content": "first"}, + {"role": "assistant", "content": "second"}, + ] + ) # Both texts sent in one provider call. assert provider.calls == [["first", "second"]] records = ctx._inner.add_many_calls[0] # type: ignore[attr-defined] @@ -211,10 +220,12 @@ def test_add_many_batch_embeds_text_records(): def test_add_many_skips_records_with_manual_embedding(): provider = StubProvider() ctx = _ctx_with_provider(provider) - ctx.add_many([ - {"role": "user", "content": "first", "embedding": [5.0, 5.0]}, - {"role": "assistant", "content": "second"}, - ]) + ctx.add_many( + [ + {"role": "user", "content": "first", "embedding": [5.0, 5.0]}, + {"role": "assistant", "content": "second"}, + ] + ) # Only the second record is sent for embedding. assert provider.calls == [["second"]] records = ctx._inner.add_many_calls[0] # type: ignore[attr-defined]