Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion python/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ build-backend = "maturin"
[project.optional-dependencies]
# `moto[server]` pulls in flask + flask-cors so moto.server can be launched
# as a subprocess for the S3 integration tests.
tests = ["pytest", "ruff", "moto[s3,server]", "boto3", "botocore"]
tests = ["pytest", "pytest-asyncio", "ruff", "moto[s3,server]", "boto3", "botocore"]
dev = ["ruff", "pyright"]

[tool.ruff]
Expand Down
8 changes: 6 additions & 2 deletions python/python/lance_context/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
from __future__ import annotations

from .api import Context, __version__ # pyright: ignore[reportMissingImports]
from .api import ( # pyright: ignore[reportMissingImports]
AsyncContext,
Context,
__version__,
)

__all__ = ["Context", "__version__"]
__all__ = ["AsyncContext", "Context", "__version__"]
119 changes: 118 additions & 1 deletion python/python/lance_context/api.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import asyncio
import warnings
from datetime import datetime
from io import BytesIO
Expand All @@ -8,7 +9,7 @@
from ._internal import Context as _Context # pyright: ignore[reportMissingImports]
from ._internal import version as _version # pyright: ignore[reportMissingImports]

__all__ = ["Context", "__version__"]
__all__ = ["AsyncContext", "Context", "__version__"]

__version__ = _version()

Expand Down Expand Up @@ -450,3 +451,119 @@ def _from_inner(cls, inner: _Context) -> Context:
obj = cls.__new__(cls)
obj._inner = inner
return obj


class AsyncContext:
"""Async wrapper around :class:`Context`.

Every I/O method is dispatched to a thread-pool executor via
:func:`asyncio.get_running_loop().run_in_executor`. The underlying Rust
code releases the GIL during I/O, so the executor thread is only occupied
briefly for the Python ↔ Rust boundary crossing.

Usage::

ctx = await AsyncContext.create("/tmp/context.lance")
await ctx.add("user", "hello")
results = await ctx.list()
"""

def __init__(self, sync_ctx: Context) -> None:
self._sync = sync_ctx

@classmethod
async def create(
cls,
uri: str,
**kwargs: Any,
) -> AsyncContext:
loop = asyncio.get_running_loop()
sync_ctx = await loop.run_in_executor(
None, lambda: Context.create(uri, **kwargs)
)
return cls(sync_ctx)

def uri(self) -> str:
return self._sync.uri()

def branch(self) -> str:
return self._sync.branch()

def entries(self) -> int:
return self._sync.entries()

def version(self) -> int:
return self._sync.version()

async def add(
self,
role: str,
content: Any,
content_type: str | None = None,
data_type: str | None = None,
embedding: list[float] | None = None,
bot_id: str | None = None,
session_id: str | None = None,
) -> None:
loop = asyncio.get_running_loop()
await loop.run_in_executor(
None,
lambda: self._sync.add(
role,
content,
content_type=content_type,
data_type=data_type,
embedding=embedding,
bot_id=bot_id,
session_id=session_id,
),
)

def snapshot(self, label: str | None = None) -> str:
return self._sync.snapshot(label)

def fork(self, branch_name: str) -> AsyncContext:
sync_fork = self._sync.fork(branch_name)
return AsyncContext(sync_fork)

async def checkout(self, version_id: int | str) -> None:
loop = asyncio.get_running_loop()
await loop.run_in_executor(None, lambda: self._sync.checkout(version_id))

async def search(
self, query: Any, limit: int | None = None
) -> list[dict[str, Any]]:
loop = asyncio.get_running_loop()
return await loop.run_in_executor(None, lambda: self._sync.search(query, limit))

async def list(
self, limit: int | None = None, offset: int | None = None
) -> list[dict[str, Any]]:
loop = asyncio.get_running_loop()
return await loop.run_in_executor(None, lambda: self._sync.list(limit, offset))

async def compact(
self,
*,
target_rows_per_fragment: int | None = None,
materialize_deletions: bool = True,
) -> dict[str, int]:
loop = asyncio.get_running_loop()
return await loop.run_in_executor(
None,
lambda: self._sync.compact(
target_rows_per_fragment=target_rows_per_fragment,
materialize_deletions=materialize_deletions,
),
)

async def compaction_stats(self) -> dict[str, Any]:
loop = asyncio.get_running_loop()
return await loop.run_in_executor(None, self._sync.compaction_stats)

def __repr__(self) -> str:
return (
f"AsyncContext(uri={self._sync.uri()!r}, "
f"branch={self._sync.branch()!r}, "
f"entries={self._sync.entries()})"
)
138 changes: 138 additions & 0 deletions python/tests/test_async.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
"""Tests for AsyncContext wrapper."""

from __future__ import annotations

from typing import TYPE_CHECKING

import pytest

if TYPE_CHECKING:
from pathlib import Path

from lance_context import AsyncContext


@pytest.mark.asyncio
async def test_create_and_add(tmp_path: Path) -> None:
uri = str(tmp_path / "ctx.lance")
ctx = await AsyncContext.create(uri)

await ctx.add("user", "hello")
await ctx.add("assistant", "hi there")

assert ctx.entries() == 2
assert ctx.uri() == uri


@pytest.mark.asyncio
async def test_list(tmp_path: Path) -> None:
uri = str(tmp_path / "ctx.lance")
ctx = await AsyncContext.create(uri)

for i in range(5):
await ctx.add("user", f"msg-{i}")

results = await ctx.list()
assert len(results) == 5
texts = {r["text"] for r in results}
for i in range(5):
assert f"msg-{i}" in texts


@pytest.mark.asyncio
async def test_list_with_limit_and_offset(tmp_path: Path) -> None:
uri = str(tmp_path / "ctx.lance")
ctx = await AsyncContext.create(uri)

for i in range(10):
await ctx.add("user", f"msg-{i}")

page = await ctx.list(limit=3, offset=2)
assert len(page) == 3


@pytest.mark.asyncio
async def test_compact(tmp_path: Path) -> None:
uri = str(tmp_path / "ctx.lance")
ctx = await AsyncContext.create(uri)

for i in range(10):
await ctx.add("user", f"entry-{i}")

metrics = await ctx.compact()
assert isinstance(metrics, dict)
assert "fragments_removed" in metrics


@pytest.mark.asyncio
async def test_compaction_stats(tmp_path: Path) -> None:
uri = str(tmp_path / "ctx.lance")
ctx = await AsyncContext.create(uri)
await ctx.add("user", "hello")

stats = await ctx.compaction_stats()
assert isinstance(stats, dict)
assert "total_fragments" in stats


@pytest.mark.asyncio
async def test_snapshot_and_checkout(tmp_path: Path) -> None:
uri = str(tmp_path / "ctx.lance")
ctx = await AsyncContext.create(uri)

await ctx.add("user", "v1")
v1 = ctx.version()

await ctx.add("user", "v2")
assert ctx.entries() == 2

await ctx.checkout(v1)


@pytest.mark.asyncio
async def test_fork(tmp_path: Path) -> None:
uri = str(tmp_path / "ctx.lance")
ctx = await AsyncContext.create(uri)

await ctx.add("user", "main-msg")
forked = ctx.fork("experiment")

assert forked.branch() == "experiment"
assert isinstance(forked, AsyncContext)


@pytest.mark.asyncio
async def test_search(tmp_path: Path) -> None:
uri = str(tmp_path / "ctx.lance")
ctx = await AsyncContext.create(uri)

dim = 1536
emb_a = [1.0] + [0.0] * (dim - 1)
emb_b = [0.0] + [1.0] + [0.0] * (dim - 2)

await ctx.add("user", "hello", embedding=emb_a)
await ctx.add("user", "world", embedding=emb_b)

results = await ctx.search(emb_a, limit=1)
assert len(results) == 1
assert results[0]["text"] == "hello"


@pytest.mark.asyncio
async def test_repr(tmp_path: Path) -> None:
uri = str(tmp_path / "ctx.lance")
ctx = await AsyncContext.create(uri)
r = repr(ctx)
assert r.startswith("AsyncContext(")
assert uri in r


@pytest.mark.asyncio
async def test_create_with_options(tmp_path: Path) -> None:
"""AsyncContext.create forwards kwargs to Context.create."""
uri = str(tmp_path / "ctx.lance")
ctx = await AsyncContext.create(uri, id_index_type="btree")

await ctx.add("user", "indexed")
results = await ctx.list()
assert len(results) == 1
Loading
Loading