Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions python/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ requires = ["maturin>=1.4"]
build-backend = "maturin"

[project.optional-dependencies]
openai = ["openai>=1.0"]
sentence-transformers = ["sentence-transformers>=2.0"]
graph = ["lance-graph==0.5.4"]
lance-python = [
"pylance>=7.0.0,<8",
Expand Down
3 changes: 2 additions & 1 deletion python/python/lance_context/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
from .api import ( # pyright: ignore[reportMissingImports]
AsyncContext,
Context,
EmbeddingProvider,
__version__,
)

__all__ = ["AsyncContext", "Context", "__version__"]
__all__ = ["AsyncContext", "Context", "EmbeddingProvider", "__version__"]
48 changes: 45 additions & 3 deletions python/python/lance_context/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@

from ._internal import Context as _Context # pyright: ignore[reportMissingImports]
from ._internal import version as _version # pyright: ignore[reportMissingImports]
from .embeddings import EmbeddingProvider, _build_provider

__all__ = ["AsyncContext", "Context", "__version__"]
__all__ = ["AsyncContext", "Context", "EmbeddingProvider", "__version__"]

__version__ = _version()

Expand Down Expand Up @@ -313,6 +314,8 @@ def __init__(
id_index_type: str | None = None,
embedding_dim: int | None = None,
distance_metric: str | None = None,
# --- Embedding provider. ---
embedding_provider: EmbeddingProvider | dict[str, Any] | None = None,
) -> None:
options = _merge_storage_options(
storage_options,
Expand Down Expand Up @@ -350,6 +353,13 @@ def __init__(
else:
self._inner = _Context.create(uri)

if isinstance(embedding_provider, dict):
self._embedding_provider: EmbeddingProvider | None = _build_provider(
embedding_provider
)
else:
self._embedding_provider = embedding_provider

@classmethod
def create(
cls,
Expand All @@ -370,6 +380,7 @@ def create(
id_index_type: str | None = None,
embedding_dim: int | None = None,
distance_metric: str | None = None,
embedding_provider: EmbeddingProvider | dict[str, Any] | None = None,
) -> Context:
return cls(
uri,
Expand All @@ -388,6 +399,7 @@ def create(
id_index_type=id_index_type,
embedding_dim=embedding_dim,
distance_metric=distance_metric,
embedding_provider=embedding_provider,
)

def uri(self) -> str:
Expand Down Expand Up @@ -427,6 +439,9 @@ def add(
if content_type is None:
content_type = data_type
payload, resolved_type = _normalize_content(content, content_type)
provider = getattr(self, "_embedding_provider", None)
if embedding is None and provider is not None and isinstance(payload, str):
embedding = provider.embed_texts([payload])[0]
self._inner.add(
role,
payload,
Expand Down Expand Up @@ -477,6 +492,9 @@ def upsert(
content_type = data_type

payload, resolved_type = _normalize_content(content, content_type)
provider = getattr(self, "_embedding_provider", None)
if embedding is None and provider is not None and isinstance(payload, str):
embedding = provider.embed_texts([payload])[0]
result = self._inner.upsert(
role,
payload,
Expand Down Expand Up @@ -610,14 +628,32 @@ def add_many(self, records: Iterable[Mapping[str, Any]]) -> None:
}
)

self._auto_embed_batch(normalized)
self._inner.add_many(normalized)

def _auto_embed_batch(self, records: list[dict[str, Any]]) -> None:
"""Embed text records without an embedding in one provider call."""
provider = getattr(self, "_embedding_provider", None)
if provider is None:
return
indices = [
i
for i, r in enumerate(records)
if r.get("embedding") is None and isinstance(r.get("content"), str)
]
if not indices:
return
texts = [records[i]["content"] for i in indices]
vectors = provider.embed_texts(texts)
for i, vec in zip(indices, vectors):
records[i]["embedding"] = vec

def snapshot(self, label: str | None = None) -> str:
return self._inner.snapshot(label)

def fork(self, branch_name: str) -> Context:
inner = self._inner.fork(branch_name)
return self._from_inner(inner)
return self._from_inner(inner, self._embedding_provider)

def checkout(self, version_id: int | str) -> None:
self._inner.checkout(int(version_id))
Expand All @@ -632,6 +668,9 @@ def search(
include_retired: bool = False,
include_relationships: bool = False,
) -> list[dict[str, Any]]:
provider = getattr(self, "_embedding_provider", None)
if isinstance(query, str) and provider is not None:
query = provider.embed_texts([query])[0]
vector = _coerce_vector(query)
results = self._inner.search(
vector,
Expand Down Expand Up @@ -815,9 +854,12 @@ def __repr__(self) -> str:
)

@classmethod
def _from_inner(cls, inner: _Context) -> Context:
def _from_inner(
cls, inner: _Context, embedding_provider: EmbeddingProvider | None = None
) -> Context:
obj = cls.__new__(cls)
obj._inner = inner
obj._embedding_provider = embedding_provider
return obj


Expand Down
115 changes: 115 additions & 0 deletions python/python/lance_context/embeddings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
from __future__ import annotations

from typing import Any, Protocol, runtime_checkable

__all__ = ["EmbeddingProvider", "OpenAIProvider", "SentenceTransformersProvider"]


@runtime_checkable
class EmbeddingProvider(Protocol):
"""Minimal interface for pluggable embedding providers.

Implement this protocol to plug in any embedding backend. Providers
receive lists so :meth:`Context.add_many` can embed in a single call.
"""

@property
def dims(self) -> int:
"""Dimensionality of the vectors produced by this provider."""
...

def embed_texts(self, texts: list[str]) -> list[list[float]]:
"""Return one embedding vector per input text."""
...


def _build_provider(config: dict[str, Any]) -> EmbeddingProvider:
"""Instantiate a built-in provider from a config dict.

Config keys:
provider: "openai" | "sentence-transformers"
model: model name / ID (optional, uses provider default)
**kwargs: forwarded to the provider constructor
"""
name = config.get("provider")
if not name:
raise ValueError("embedding config must include a 'provider' key")
kwargs = {k: v for k, v in config.items() if k != "provider"}
try:
cls = _REGISTRY[name]
except KeyError:
raise ValueError(
f"Unknown embedding provider {name!r}. Available: {sorted(_REGISTRY)}"
) from None
return cls(**kwargs)


class OpenAIProvider:
"""Embedding provider backed by the OpenAI embeddings API.

Requires ``pip install lance-context[openai]``.
"""

def __init__(
self, model: str = "text-embedding-3-small", **client_kwargs: Any
) -> None:
try:
from openai import OpenAI # pyright: ignore[reportMissingImports]
except ImportError as exc:
raise ImportError(
"openai is required for the OpenAI embedding provider. "
"Install it with: pip install lance-context[openai]"
) from exc
self._client = OpenAI(**client_kwargs)
self._model = model
self._dims: int | None = None

@property
def dims(self) -> int:
if self._dims is None:
# Probe dims with a dummy call the first time they're requested.
result = self._client.embeddings.create(input=["probe"], model=self._model)
self._dims = len(result.data[0].embedding)
return self._dims

def embed_texts(self, texts: list[str]) -> list[list[float]]:
result = self._client.embeddings.create(input=texts, model=self._model)
ordered = sorted(result.data, key=lambda d: d.index)
vecs = [item.embedding for item in ordered]
if self._dims is None and vecs:
self._dims = len(vecs[0])
return vecs


class SentenceTransformersProvider:
"""Embedding provider backed by sentence-transformers (local / offline).

Requires ``pip install lance-context[sentence-transformers]``.
"""

def __init__(self, model: str = "all-MiniLM-L6-v2", **model_kwargs: Any) -> None:
try:
from sentence_transformers import ( # pyright: ignore[reportMissingImports]
SentenceTransformer,
)
except ImportError as exc:
raise ImportError(
"sentence-transformers is required for the "
"SentenceTransformers provider. "
"Install it with: pip install lance-context[sentence-transformers]"
) from exc
self._model = SentenceTransformer(model, **model_kwargs)

@property
def dims(self) -> int:
return int(self._model.get_sentence_embedding_dimension())

def embed_texts(self, texts: list[str]) -> list[list[float]]:
vectors = self._model.encode(texts, convert_to_numpy=True)
return [v.tolist() for v in vectors]


_REGISTRY: dict[str, type] = {
"openai": OpenAIProvider,
"sentence-transformers": SentenceTransformersProvider,
}
Loading
Loading