diff --git a/README.md b/README.md index 11e7c3f..a7133b6 100644 --- a/README.md +++ b/README.md @@ -38,6 +38,29 @@ from daft_lance import merge_columns_df merge_columns_df(df, "s3://bucket/my_dataset") ``` +### Namespace Tables + +```python +import daft +import daft_lance # installs daft.read_lance / DataFrame.write_lance namespace support + +table_id = ["my_table"] +namespace_properties = {"root": "/tmp/lance_tables"} + +df.write_lance( + namespace_impl="dir", + namespace_properties=namespace_properties, + table_id=table_id, + mode="create", +).collect() + +df = daft.read_lance( + namespace_impl="dir", + namespace_properties=namespace_properties, + table_id=table_id, +) +``` + ## Migration The migration only requires replacing `daft.io.lance` with `daft_lance`. diff --git a/daft_lance/__init__.py b/daft_lance/__init__.py index cb07862..8a1ecba 100644 --- a/daft_lance/__init__.py +++ b/daft_lance/__init__.py @@ -1,3 +1,7 @@ +from __future__ import annotations + +from typing import Any, Literal + try: import daft except ImportError: @@ -11,6 +15,54 @@ merge_columns_df, read_lance, ) +from .lance_data_sink import LanceDataSink + + +def _patch_daft_lance_api() -> None: + """Expose the daft-lance implementation through Daft's convenience APIs.""" + daft.read_lance = read_lance # type: ignore[assignment] + + from daft.dataframe import DataFrame + + original_write_lance = getattr(DataFrame, "write_lance") + if getattr(original_write_lance, "_daft_lance_namespace_patch", False): + return + + def write_lance( + self: DataFrame, + uri: Any = None, + mode: Literal["create", "append", "overwrite", "merge"] = "create", + io_config: Any | None = None, + schema: Any | None = None, + left_on: str | None = None, + right_on: str | None = None, + **kwargs: Any, + ) -> DataFrame: + if mode == "merge": + if any(k in kwargs for k in ("namespace_impl", "namespace_properties", "table_id")): + raise ValueError("write_lance(mode='merge') does not support namespace parameters yet.") + 
return original_write_lance( + self, + uri, + mode=mode, + io_config=io_config, + schema=schema, + left_on=left_on, + right_on=right_on, + **kwargs, + ) + + if schema is None: + schema = self.schema() + sanitized_kwargs = {k: v for k, v in kwargs.items() if k not in ("left_on", "right_on")} + sink = LanceDataSink(uri, schema, mode, io_config, **sanitized_kwargs) + return self.write_sink(sink) + + write_lance._daft_lance_namespace_patch = True # type: ignore[attr-defined] + DataFrame.write_lance = write_lance # type: ignore[assignment] + + +_patch_daft_lance_api() __all__ = [ "compact_files", diff --git a/daft_lance/_lance.py b/daft_lance/_lance.py index 21b7c67..7b5f55c 100644 --- a/daft_lance/_lance.py +++ b/daft_lance/_lance.py @@ -30,7 +30,7 @@ @PublicAPI def read_lance( - uri: str | pathlib.Path, + uri: str | pathlib.Path | None = None, io_config: IOConfig | None = None, version: str | int | None = None, asof: str | None = None, @@ -42,6 +42,10 @@ def read_lance( fragment_group_size: int | None = None, include_fragment_id: bool | None = None, checkpoint: CheckpointConfig | None = None, + *, + table_id: list[str] | None = None, + namespace_impl: str | None = None, + namespace_properties: dict[str, str] | None = None, ) -> DataFrame: """Create a DataFrame from a LanceDB table. @@ -123,8 +127,8 @@ def read_lance( >>> df = daft.read_lance("s3://daft-oss-public-data/lance/words-test-dataset", io_config=io_config) >>> df.show() """ - uri_str = str(uri) - if uri_str.startswith("rest://"): + uri_str = str(uri) if uri is not None else None + if uri_str is not None and uri_str.startswith("rest://"): raise ValueError( "rest:// Lance URIs are no longer supported by daft.read_lance. 
" "The previous REST-namespace integration did not match the real " @@ -133,11 +137,14 @@ def read_lance( ) io_config = context.get_context().daft_planning_config.default_io_config if io_config is None else io_config - storage_options = io_config_to_storage_options(io_config, uri_str) + storage_options = io_config_to_storage_options(io_config, uri_str) if uri_str is not None else None ds = construct_lance_dataset( uri_str, storage_options=storage_options, + namespace_impl=namespace_impl, + namespace_properties=namespace_properties, + table_id=table_id, version=version, asof=asof, block_size=block_size, diff --git a/daft_lance/lance_data_sink.py b/daft_lance/lance_data_sink.py index 62c36bb..2ec69a7 100644 --- a/daft_lance/lance_data_sink.py +++ b/daft_lance/lance_data_sink.py @@ -26,6 +26,13 @@ detect_blob_v2_columns, resolve_storage_version, ) +from daft_lance.namespace import ( + get_namespace_kwargs, + get_write_fragments_kwargs, + merge_storage_options, + resolve_namespace_table, + validate_uri_or_namespace, +) if TYPE_CHECKING: from collections.abc import Iterator @@ -40,11 +47,14 @@ class LanceDataSink(DataSink[list[FragmentMetadata]]): def __init__( self, - uri: str | pathlib.Path, + uri: str | pathlib.Path | None, schema: Schema | pa.Schema, mode: Literal["create", "append", "overwrite"] = "create", io_config: IOConfig | None = None, *, + table_id: list[str] | None = None, + namespace_impl: str | None = None, + namespace_properties: dict[str, str] | None = None, blob_columns: list[str] | None = None, max_rows_per_file: int = 1024 * 1024, max_rows_per_group: int = 1024, @@ -57,17 +67,26 @@ def __init__( compact_after_write: bool = True, ) -> None: self._reject_unsupported_modes(mode, use_legacy_format) - if not isinstance(uri, (str, pathlib.Path)): + validate_uri_or_namespace(uri, namespace_impl, table_id) + if uri is not None and not isinstance(uri, (str, pathlib.Path)): raise TypeError(f"Expected URI to be str or pathlib.Path, got {type(uri)}") - 
self._table_uri = str(uri) self._mode = mode + self._namespace_impl = namespace_impl + self._namespace_properties = namespace_properties + self._table_id = table_id self._io_config = get_context().daft_planning_config.default_io_config if io_config is None else io_config - self._storage_options = ( - storage_options - if storage_options is not None - else io_config_to_storage_options(self._io_config, self._table_uri) + base_storage_options = ( + ( + storage_options + if storage_options is not None + else io_config_to_storage_options(self._io_config, str(uri)) + ) + if uri is not None + else storage_options ) + self._table_uri, namespace_storage_options = self._resolve_table_uri(uri) + self._storage_options = merge_storage_options(base_storage_options, namespace_storage_options) self._init_lance_knobs( max_rows_per_file=max_rows_per_file, max_rows_per_group=max_rows_per_group, @@ -109,6 +128,28 @@ def __init__( ] ) + @property + def _namespace_kwargs(self) -> dict[str, object]: + return get_namespace_kwargs(self._namespace_impl, self._namespace_properties, self._table_id) + + @property + def _dataset_uri_arg(self) -> str | None: + return None if self._namespace_impl is not None and self._table_id is not None else self._table_uri + + def _resolve_table_uri(self, uri: str | pathlib.Path | None) -> tuple[str, dict[str, str] | None]: + if uri is not None: + return str(uri), None + mode = "create" if self._mode == "create" else "overwrite" if self._mode == "overwrite" else "read" + resolved_uri, namespace_storage_options = resolve_namespace_table( + namespace_impl=self._namespace_impl, + namespace_properties=self._namespace_properties, + table_id=self._table_id, + mode=mode, + ) + if resolved_uri is None: + raise ValueError("Unable to resolve Lance dataset URI from namespace.") + return resolved_uri, namespace_storage_options + @staticmethod def _reject_unsupported_modes( mode: Literal["create", "append", "overwrite"], use_legacy_format: bool | None @@ -167,7 +208,9 @@ 
def _absorb_existing_dataset(self) -> lance.LanceDataset | None: """ dataset: lance.LanceDataset | None try: - dataset = lance.dataset(self._table_uri, storage_options=self._storage_options) + dataset = lance.dataset( + self._dataset_uri_arg, storage_options=self._storage_options, **self._namespace_kwargs + ) except (ValueError, FileNotFoundError, OSError) as e: # Pinned to the Rust message format; lance has no typed exception. See test_lance_message_format_unchanged. if "was not found" in str(e): @@ -230,6 +273,7 @@ def _write_arrow_table(self, table: pa.Table) -> WriteResult[list[FragmentMetada data_storage_version=self._data_storage_version, use_legacy_format=self._use_legacy_format, enable_stable_row_ids=self._enable_stable_row_ids, + **get_write_fragments_kwargs(self._namespace_impl, self._namespace_properties, self._table_id), ) # Sum on-disk sizes from fragment metadata. Lance Blob V2 sidecar .blob # files are not tracked in FragmentMetadata.files (out of scope here). @@ -240,7 +284,7 @@ def _write_arrow_table(self, table: pa.Table) -> WriteResult[list[FragmentMetada def _ensure_mem_wal_dataset(self) -> lance.LanceDataset: try: - ds = lance.dataset(self._table_uri, storage_options=self._storage_options) + ds = lance.dataset(self._dataset_uri_arg, storage_options=self._storage_options, **self._namespace_kwargs) except (ValueError, FileNotFoundError, OSError): ds = None @@ -250,11 +294,12 @@ def _ensure_mem_wal_dataset(self) -> lance.LanceDataset: {f.name: pa.array([], type=f.type) for f in self._effective_pyarrow_schema}, schema=self._effective_pyarrow_schema, ), - self._table_uri, + self._dataset_uri_arg, mode="create", storage_options=self._storage_options, data_storage_version=self._data_storage_version, use_legacy_format=self._use_legacy_format, + **self._namespace_kwargs, ) details = ds.mem_wal_index_details() @@ -335,6 +380,7 @@ def _finalize_cow(self, write_results: list[WriteResult[list[FragmentMetadata]]] operation, read_version=self._version, 
storage_options=self._storage_options, + **self._namespace_kwargs, ) stats = dataset.stats.dataset_stats() stats_dict = MicroPartition.from_pydict( @@ -348,7 +394,7 @@ def _finalize_cow(self, write_results: list[WriteResult[list[FragmentMetadata]]] return stats_dict def _finalize_mem_wal(self, write_results: list[WriteResult[list[FragmentMetadata]]]) -> MicroPartition: - dataset = lance.dataset(self._table_uri, storage_options=self._storage_options) + dataset = lance.dataset(self._dataset_uri_arg, storage_options=self._storage_options, **self._namespace_kwargs) if self._compact_after_write: logger.info( @@ -359,7 +405,9 @@ def _finalize_mem_wal(self, write_results: list[WriteResult[list[FragmentMetadat from daft_lance.lance_compaction import compact_files_internal compact_files_internal(dataset) - dataset = lance.dataset(self._table_uri, storage_options=self._storage_options) + dataset = lance.dataset( + self._dataset_uri_arg, storage_options=self._storage_options, **self._namespace_kwargs + ) stats = dataset.stats.dataset_stats() return MicroPartition.from_pydict( diff --git a/daft_lance/lance_scan.py b/daft_lance/lance_scan.py index de3349f..f52de22 100644 --- a/daft_lance/lance_scan.py +++ b/daft_lance/lance_scan.py @@ -16,6 +16,7 @@ from daft.recordbatch import RecordBatch from ._metadata import convert_lance_schema +from .namespace import get_namespace_kwargs from .point_lookup import detect_point_lookup_columns from .utils import combine_filters_to_arrow @@ -40,7 +41,15 @@ def _lancedb_table_factory_function( "Use nearest with fragment_ids=None for index-driven global vector search." 
) - ds = lance.dataset(ds_uri, **(open_kwargs or {})) + open_kwargs = dict(open_kwargs or {}) + namespace_impl = open_kwargs.pop("namespace_impl", None) + namespace_properties = open_kwargs.pop("namespace_properties", None) + table_id = open_kwargs.pop("table_id", None) + ds = lance.dataset( + None if namespace_impl is not None and table_id is not None else ds_uri, + **get_namespace_kwargs(namespace_impl, namespace_properties, table_id), + **open_kwargs, + ) def _iter_batches() -> Iterator[PyRecordBatch]: # Iterate fragments individually; append a fragment_id column only when requested @@ -117,7 +126,15 @@ def _lancedb_count_result_function( filter: pa.compute.Expression | None = None, ) -> Iterator[PyRecordBatch]: """Use LanceDB's API to count rows and return a record batch with the count result.""" - ds = lance.dataset(ds_uri, **(open_kwargs or {})) + open_kwargs = dict(open_kwargs or {}) + namespace_impl = open_kwargs.pop("namespace_impl", None) + namespace_properties = open_kwargs.pop("namespace_properties", None) + table_id = open_kwargs.pop("table_id", None) + ds = lance.dataset( + None if namespace_impl is not None and table_id is not None else ds_uri, + **get_namespace_kwargs(namespace_impl, namespace_properties, table_id), + **open_kwargs, + ) logger.debug("Using metadata for counting all rows") count = ds.count_rows(filter=filter) diff --git a/daft_lance/namespace.py b/daft_lance/namespace.py new file mode 100644 index 0000000..2abe99c --- /dev/null +++ b/daft_lance/namespace.py @@ -0,0 +1,217 @@ +from __future__ import annotations + +import os +from functools import lru_cache +from typing import Any +from urllib.parse import urlparse + +import lance + +_NAMESPACE_CACHE_SIZE = int(os.environ.get("DAFT_LANCE_NAMESPACE_CACHE_SIZE", "16")) +_PYLANCE_5 = (5, 0, 0) + + +@lru_cache(maxsize=1) +def _pylance_version() -> tuple[int, ...]: + version = getattr(lance, "__version__", "0.0.0") + parts = [] + for part in version.split(".")[:3]: + digits = "".join(ch for 
ch in part if ch.isdigit()) + parts.append(int(digits or "0")) + while len(parts) < 3: + parts.append(0) + return tuple(parts) + + +def has_namespace_params(namespace_impl: str | None, table_id: list[str] | None) -> bool: + return namespace_impl is not None and table_id is not None + + +def validate_uri_or_namespace( + uri: str | os.PathLike[str] | None, + namespace_impl: str | None, + table_id: list[str] | None, +) -> None: + has_uri = uri is not None + has_ns = has_namespace_params(namespace_impl, table_id) + + if namespace_impl is not None and table_id is None: + raise ValueError("'table_id' must be provided when 'namespace_impl' is provided.") + if table_id is not None and namespace_impl is None: + raise ValueError("'namespace_impl' must be provided when 'table_id' is provided.") + if has_uri and has_ns: + raise ValueError( + "Cannot provide both 'uri' and namespace parameters. Use either 'uri' OR ('namespace_impl' + 'table_id')." + ) + if not has_uri and not has_ns: + raise ValueError("Must provide either 'uri' OR ('namespace_impl' + 'table_id').") + + +def _normalize_file_uri(location: str) -> str: + parsed = urlparse(location) + if parsed.scheme == "file": + return parsed.path + return location + + +@lru_cache(maxsize=_NAMESPACE_CACHE_SIZE) +def _get_cached_namespace(namespace_impl: str, namespace_properties_tuple: tuple[tuple[str, str], ...] 
| None) -> Any: + import lance_namespace as ln + + namespace_properties = dict(namespace_properties_tuple) if namespace_properties_tuple else {} + return ln.connect(namespace_impl, namespace_properties) + + +def get_or_create_namespace(namespace_impl: str | None, namespace_properties: dict[str, str] | None) -> Any | None: + if namespace_impl is None: + return None + namespace_properties_tuple = tuple(sorted(namespace_properties.items())) if namespace_properties else None + return _get_cached_namespace(namespace_impl, namespace_properties_tuple) + + +def _create_storage_options_provider( + namespace_impl: str | None, + namespace_properties: dict[str, str] | None, + table_id: list[str] | None, +) -> Any | None: + if not has_namespace_params(namespace_impl, table_id): + return None + namespace = get_or_create_namespace(namespace_impl, namespace_properties) + if namespace is None or not hasattr(lance, "LanceNamespaceStorageOptionsProvider"): + return None + return lance.LanceNamespaceStorageOptionsProvider(namespace=namespace, table_id=table_id) + + +def get_namespace_kwargs( + namespace_impl: str | None, + namespace_properties: dict[str, str] | None, + table_id: list[str] | None, +) -> dict[str, Any]: + if not has_namespace_params(namespace_impl, table_id): + return {} + + namespace = get_or_create_namespace(namespace_impl, namespace_properties) + if namespace is None: + return {} + + kwargs: dict[str, Any] = {"table_id": table_id} + if _pylance_version() >= _PYLANCE_5: + kwargs["namespace_client"] = namespace + else: + kwargs["namespace"] = namespace + provider = _create_storage_options_provider(namespace_impl, namespace_properties, table_id) + if provider is not None: + kwargs["storage_options_provider"] = provider + return kwargs + + +def get_write_fragments_kwargs( + namespace_impl: str | None, + namespace_properties: dict[str, str] | None, + table_id: list[str] | None, +) -> dict[str, Any]: + if not has_namespace_params(namespace_impl, table_id): + return {} + if 
_pylance_version() >= _PYLANCE_5: + namespace = get_or_create_namespace(namespace_impl, namespace_properties) + if namespace is None: + return {} + return {"namespace_client": namespace, "table_id": table_id} + provider = _create_storage_options_provider(namespace_impl, namespace_properties, table_id) + return {"storage_options_provider": provider} if provider is not None else {} + + +def _storage_options(response: Any) -> dict[str, str] | None: + storage_options = getattr(response, "storage_options", None) + if storage_options is None: + return None + return dict(storage_options) + + +def _response_location(response: Any) -> str: + location = getattr(response, "location", None) or getattr(response, "table_uri", None) + if not location: + raise ValueError("Namespace response did not include a table location.") + return _normalize_file_uri(str(location)) + + +def _declare_table(namespace: Any, table_id: list[str]) -> tuple[str, dict[str, str] | None]: + try: + from lance_namespace import DeclareTableRequest + + response = namespace.declare_table(DeclareTableRequest(id=table_id, location=None)) + return _response_location(response), _storage_options(response) + except (ImportError, AttributeError, NotImplementedError): + from lance_namespace import CreateEmptyTableRequest + + response = namespace.create_empty_table(CreateEmptyTableRequest(id=table_id)) + return _response_location(response), _storage_options(response) + + +def resolve_namespace_table( + *, + namespace_impl: str | None, + namespace_properties: dict[str, str] | None, + table_id: list[str] | None, + mode: str = "read", +) -> tuple[str | None, dict[str, str] | None]: + namespace = get_or_create_namespace(namespace_impl, namespace_properties) + if namespace is None or table_id is None: + return None, None + + if mode == "create": + return _declare_table(namespace, table_id) + + from lance_namespace import DescribeTableRequest + + try: + response = namespace.describe_table(DescribeTableRequest(id=table_id)) + return 
_response_location(response), _storage_options(response) + except Exception: + if mode == "overwrite": + return _declare_table(namespace, table_id) + raise + + +def merge_storage_options( + storage_options: dict[str, Any] | None, + namespace_storage_options: dict[str, str] | None, +) -> dict[str, Any] | None: + merged: dict[str, Any] = {} + if storage_options: + merged.update(storage_options) + if namespace_storage_options: + merged.update(namespace_storage_options) + return merged or None + + +def open_lance_dataset( + uri: str | os.PathLike[str] | None, + *, + storage_options: dict[str, Any] | None = None, + namespace_impl: str | None = None, + namespace_properties: dict[str, str] | None = None, + table_id: list[str] | None = None, + mode: str = "read", + **kwargs: Any, +) -> lance.LanceDataset: + validate_uri_or_namespace(uri, namespace_impl, table_id) + resolved_uri = str(uri) if uri is not None else None + namespace_storage_options = None + if resolved_uri is None: + resolved_uri, namespace_storage_options = resolve_namespace_table( + namespace_impl=namespace_impl, + namespace_properties=namespace_properties, + table_id=table_id, + mode=mode, + ) + if resolved_uri is None: + raise ValueError("Unable to resolve Lance dataset URI.") + + merged_storage_options = merge_storage_options(storage_options, namespace_storage_options) + return lance.dataset( + None if has_namespace_params(namespace_impl, table_id) else resolved_uri, + storage_options=merged_storage_options, + **get_namespace_kwargs(namespace_impl, namespace_properties, table_id), + **kwargs, + ) diff --git a/daft_lance/utils.py b/daft_lance/utils.py index 0488262..799fe4c 100644 --- a/daft_lance/utils.py +++ b/daft_lance/utils.py @@ -7,6 +7,13 @@ from daft.dependencies import pa from daft.logical.schema import Schema as DaftSchema +from daft_lance.namespace import ( + get_namespace_kwargs, + has_namespace_params, + merge_storage_options, + resolve_namespace_table, + validate_uri_or_namespace, +) if 
TYPE_CHECKING: import pathlib @@ -82,12 +89,29 @@ def distribute_fragments_balanced(fragments: list[Any], fragment_group_size: int def construct_lance_dataset( - uri: str | pathlib.Path, + uri: str | pathlib.Path | None, version: int | str | None = None, storage_options: dict[str, Any] | None = None, + namespace_impl: str | None = None, + namespace_properties: dict[str, str] | None = None, + table_id: list[str] | None = None, **kwargs: Any, ) -> lance.LanceDataset: """Construct a Lance dataset with common options.""" + validate_uri_or_namespace(uri, namespace_impl, table_id) + resolved_uri = str(uri) if uri is not None else None + namespace_storage_options = None + if resolved_uri is None: + resolved_uri, namespace_storage_options = resolve_namespace_table( + namespace_impl=namespace_impl, + namespace_properties=namespace_properties, + table_id=table_id, + mode="read", + ) + if resolved_uri is None: + raise ValueError("Unable to resolve Lance dataset URI.") + merged_storage_options = merge_storage_options(storage_options, namespace_storage_options) + original_default_scan_options = kwargs.pop("default_scan_options", None) safe_default_scan_options = None if isinstance(original_default_scan_options, dict): @@ -98,11 +122,21 @@ def construct_lance_dataset( # Non-dict defaults are forwarded as-is. 
kwargs["default_scan_options"] = original_default_scan_options - ds = lance.dataset(uri, storage_options=storage_options, version=version, **kwargs) + dataset_uri = None if has_namespace_params(namespace_impl, table_id) else resolved_uri + ds = lance.dataset( + dataset_uri, + storage_options=merged_storage_options, + version=version, + **get_namespace_kwargs(namespace_impl, namespace_properties, table_id), + **kwargs, + ) effective_kwargs = { - "storage_options": storage_options, + "storage_options": merged_storage_options, "version": version, + "namespace_impl": namespace_impl, + "namespace_properties": namespace_properties, + "table_id": table_id, } effective_kwargs.update(kwargs or {}) try: diff --git a/tests/io/lancedb/test_namespace.py b/tests/io/lancedb/test_namespace.py new file mode 100644 index 0000000..c573a56 --- /dev/null +++ b/tests/io/lancedb/test_namespace.py @@ -0,0 +1,84 @@ +from __future__ import annotations + +import pytest + +import daft +import daft_lance + + +def test_namespace_write_read_append_roundtrip(tmp_path): + namespace_properties = {"root": str(tmp_path)} + table_id = ["roundtrip"] + + df1 = daft.from_pydict({"id": [1, 2], "label": ["a", "b"]}) + df2 = daft.from_pydict({"id": [3], "label": ["c"]}) + + df1.write_lance( + namespace_impl="dir", + namespace_properties=namespace_properties, + table_id=table_id, + mode="create", + ).collect() + df2.write_lance( + namespace_impl="dir", + namespace_properties=namespace_properties, + table_id=table_id, + mode="append", + ).collect() + + result = daft.read_lance( + namespace_impl="dir", + namespace_properties=namespace_properties, + table_id=table_id, + ).to_pydict() + + assert result == {"id": [1, 2, 3], "label": ["a", "b", "c"]} + + +def test_namespace_read_supports_pushdowns(tmp_path): + namespace_properties = {"root": str(tmp_path)} + table_id = ["pushdowns"] + + daft.from_pydict( + { + "id": [1, 2, 3], + "label": ["a", "b", "c"], + "score": [10, 20, 30], + } + ).write_lance( + 
namespace_impl="dir", + namespace_properties=namespace_properties, + table_id=table_id, + mode="create", + ).collect() + + result = ( + daft.read_lance( + namespace_impl="dir", + namespace_properties=namespace_properties, + table_id=table_id, + ) + .where(daft.col("score") > 10) + .select("label") + .to_pydict() + ) + + assert result == {"label": ["b", "c"]} + + +def test_namespace_rejects_uri_and_namespace(tmp_path): + with pytest.raises(ValueError, match="Cannot provide both 'uri' and namespace parameters"): + daft_lance.read_lance( + str(tmp_path / "dataset"), + namespace_impl="dir", + namespace_properties={"root": str(tmp_path)}, + table_id=["tbl"], + ) + + +def test_namespace_requires_table_id(tmp_path): + with pytest.raises(ValueError, match="'table_id' must be provided"): + daft_lance.read_lance( + namespace_impl="dir", + namespace_properties={"root": str(tmp_path)}, + ) diff --git a/tests/io/lancedb/test_namespace_rest_e2e.py b/tests/io/lancedb/test_namespace_rest_e2e.py new file mode 100644 index 0000000..d9f27e9 --- /dev/null +++ b/tests/io/lancedb/test_namespace_rest_e2e.py @@ -0,0 +1,94 @@ +from __future__ import annotations + +import os +import uuid + +import pytest + +import daft +import daft_lance # noqa: F401 - patches daft.read_lance and DataFrame.write_lance. 
+ +pytestmark = pytest.mark.skipif( + os.environ.get("DAFT_LANCE_REST_URI") is None, + reason="Set DAFT_LANCE_REST_URI to run the Lance REST namespace integration test.", +) + + +def test_rest_namespace_write_read_append_roundtrip() -> None: + import lance + import lance_namespace as ln + from lance_namespace import CreateNamespaceRequest, DescribeTableRequest, NamespaceExistsRequest + + namespace_properties = {"uri": os.environ["DAFT_LANCE_REST_URI"]} + catalog = os.environ.get("DAFT_LANCE_REST_CATALOG", "lance_catalog") + schema = os.environ.get("DAFT_LANCE_REST_SCHEMA", "daft_ns_e2e") + table_id = [catalog, schema, f"orders_{uuid.uuid4().hex[:8]}"] + + namespace = ln.connect("rest", namespace_properties) + try: + namespace.namespace_exists(NamespaceExistsRequest(id=[catalog, schema])) + except Exception: + namespace.create_namespace(CreateNamespaceRequest(id=[catalog, schema], mode="CREATE")) + + daft.from_pydict( + { + "id": [1, 2, 3], + "label": ["a", "b", "c"], + "score": [10, 20, 30], + } + ).write_lance( + namespace_impl="rest", + namespace_properties=namespace_properties, + table_id=table_id, + mode="create", + ).collect() + + daft.from_pydict( + { + "id": [4, 5], + "label": ["d", "e"], + "score": [40, 50], + } + ).write_lance( + namespace_impl="rest", + namespace_properties=namespace_properties, + table_id=table_id, + mode="append", + ).collect() + + describe = namespace.describe_table(DescribeTableRequest(id=table_id)) + location = getattr(describe, "location", None) or getattr(describe, "table_uri", None) + assert location + + result = daft.read_lance( + namespace_impl="rest", + namespace_properties=namespace_properties, + table_id=table_id, + ).to_pydict() + assert result == { + "id": [1, 2, 3, 4, 5], + "label": ["a", "b", "c", "d", "e"], + "score": [10, 20, 30, 40, 50], + } + + filtered = ( + daft.read_lance( + namespace_impl="rest", + namespace_properties=namespace_properties, + table_id=table_id, + ) + .where(daft.col("score") >= 30) + .select("id", 
"label") + .to_pydict() + ) + assert filtered == {"id": [3, 4, 5], "label": ["c", "d", "e"]} + + assert ( + daft.read_lance( + namespace_impl="rest", + namespace_properties=namespace_properties, + table_id=table_id, + ).count_rows() + == 5 + ) + assert lance.dataset(None, namespace_client=namespace, table_id=table_id).count_rows() == 5