Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,29 @@ from daft_lance import merge_columns_df
merge_columns_df(df, "s3://bucket/my_dataset")
```

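### Namespace Tables

Instead of a dataset URI, `daft.read_lance` and `DataFrame.write_lance` also accept `namespace_impl`, `namespace_properties`, and `table_id` to address a table through a Lance namespace. `write_lance(mode="merge")` does not support namespace parameters yet.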

```python
import daft
import daft_lance # installs daft.read_lance / DataFrame.write_lance namespace support

table_id = ["my_table"]
namespace_properties = {"root": "/tmp/lance_tables"}

df.write_lance(
namespace_impl="dir",
namespace_properties=namespace_properties,
table_id=table_id,
mode="create",
).collect()

df = daft.read_lance(
namespace_impl="dir",
namespace_properties=namespace_properties,
table_id=table_id,
)
```

## Migration

The migration only requires replacing `daft.io.lance` with `daft_lance`.
Expand Down
52 changes: 52 additions & 0 deletions daft_lance/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
from __future__ import annotations

from typing import Any, Literal

try:
import daft
except ImportError:
Expand All @@ -11,6 +15,54 @@
merge_columns_df,
read_lance,
)
from .lance_data_sink import LanceDataSink


def _patch_daft_lance_api() -> None:
    """Expose the daft-lance implementation through Daft's convenience APIs.

    Rebinds ``daft.read_lance`` to this package's ``read_lance`` and replaces
    ``DataFrame.write_lance`` with a wrapper that sends every non-merge write
    through ``LanceDataSink``, so the namespace keyword arguments
    (``namespace_impl``, ``namespace_properties``, ``table_id``) reach the sink.

    The wrapper is tagged with ``_daft_lance_namespace_patch``. If the current
    ``DataFrame.write_lance`` already carries that tag, this function returns
    without wrapping it a second time.
    """
    daft.read_lance = read_lance  # type: ignore[assignment]

    # Imported here rather than at module top, as in the original code.
    from daft.dataframe import DataFrame

    original_write_lance = DataFrame.write_lance
    if getattr(original_write_lance, "_daft_lance_namespace_patch", False):
        return

    namespace_keys = ("namespace_impl", "namespace_properties", "table_id")

    def write_lance(
        self: DataFrame,
        uri: Any = None,
        mode: Literal["create", "append", "overwrite", "merge"] = "create",
        io_config: Any | None = None,
        schema: Any | None = None,
        left_on: str | None = None,
        right_on: str | None = None,
        **kwargs: Any,
    ) -> DataFrame:
        """Write this DataFrame to Lance, addressed by URI or by namespace table.

        Args:
            uri: Dataset URI; defaults to ``None`` so a namespace table can be
                given through ``kwargs`` instead.
            mode: ``"merge"`` is delegated to the ``write_lance`` that was in
                place before patching; every other mode builds a
                ``LanceDataSink`` and writes through ``write_sink``.
            io_config: Forwarded to the delegate or the sink.
            schema: Schema to write with; defaults to ``self.schema()`` for
                non-merge modes.
            left_on: Forwarded in merge mode only; ignored otherwise.
            right_on: Forwarded in merge mode only; ignored otherwise.
            **kwargs: Extra options forwarded to the delegate (merge) or to
                ``LanceDataSink`` (all other modes).

        Raises:
            ValueError: If ``mode="merge"`` is combined with any namespace
                keyword argument.
        """
        if mode == "merge":
            if any(key in kwargs for key in namespace_keys):
                raise ValueError("write_lance(mode='merge') does not support namespace parameters yet.")
            return original_write_lance(
                self,
                uri,
                mode=mode,
                io_config=io_config,
                schema=schema,
                left_on=left_on,
                right_on=right_on,
                **kwargs,
            )

        if schema is None:
            schema = self.schema()
        # ``left_on`` / ``right_on`` are named parameters, so they are never
        # present in ``kwargs`` and no filtering is needed before forwarding.
        sink = LanceDataSink(uri, schema, mode, io_config, **kwargs)
        return self.write_sink(sink)

    # Marker checked above so that re-running the patch is a no-op.
    write_lance._daft_lance_namespace_patch = True  # type: ignore[attr-defined]
    DataFrame.write_lance = write_lance  # type: ignore[assignment]


# Runs at import time: importing this package is what installs the patched
# daft.read_lance / DataFrame.write_lance.
_patch_daft_lance_api()

__all__ = [
"compact_files",
Expand Down
15 changes: 11 additions & 4 deletions daft_lance/_lance.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@

@PublicAPI
def read_lance(
uri: str | pathlib.Path,
uri: str | pathlib.Path | None = None,
io_config: IOConfig | None = None,
version: str | int | None = None,
asof: str | None = None,
Expand All @@ -42,6 +42,10 @@ def read_lance(
fragment_group_size: int | None = None,
include_fragment_id: bool | None = None,
checkpoint: CheckpointConfig | None = None,
*,
table_id: list[str] | None = None,
namespace_impl: str | None = None,
namespace_properties: dict[str, str] | None = None,
) -> DataFrame:
"""Create a DataFrame from a LanceDB table.

Expand Down Expand Up @@ -123,8 +127,8 @@ def read_lance(
>>> df = daft.read_lance("s3://daft-oss-public-data/lance/words-test-dataset", io_config=io_config)
>>> df.show()
"""
uri_str = str(uri)
if uri_str.startswith("rest://"):
uri_str = str(uri) if uri is not None else None
if uri_str is not None and uri_str.startswith("rest://"):
raise ValueError(
"rest:// Lance URIs are no longer supported by daft.read_lance. "
"The previous REST-namespace integration did not match the real "
Expand All @@ -133,11 +137,14 @@ def read_lance(
)

io_config = context.get_context().daft_planning_config.default_io_config if io_config is None else io_config
storage_options = io_config_to_storage_options(io_config, uri_str)
storage_options = io_config_to_storage_options(io_config, uri_str) if uri_str is not None else None

ds = construct_lance_dataset(
uri_str,
storage_options=storage_options,
namespace_impl=namespace_impl,
namespace_properties=namespace_properties,
table_id=table_id,
version=version,
asof=asof,
block_size=block_size,
Expand Down
72 changes: 60 additions & 12 deletions daft_lance/lance_data_sink.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,13 @@
detect_blob_v2_columns,
resolve_storage_version,
)
from daft_lance.namespace import (
get_namespace_kwargs,
get_write_fragments_kwargs,
merge_storage_options,
resolve_namespace_table,
validate_uri_or_namespace,
)

if TYPE_CHECKING:
from collections.abc import Iterator
Expand All @@ -40,11 +47,14 @@ class LanceDataSink(DataSink[list[FragmentMetadata]]):

def __init__(
self,
uri: str | pathlib.Path,
uri: str | pathlib.Path | None,
schema: Schema | pa.Schema,
mode: Literal["create", "append", "overwrite"] = "create",
io_config: IOConfig | None = None,
*,
table_id: list[str] | None = None,
namespace_impl: str | None = None,
namespace_properties: dict[str, str] | None = None,
blob_columns: list[str] | None = None,
max_rows_per_file: int = 1024 * 1024,
max_rows_per_group: int = 1024,
Expand All @@ -57,17 +67,26 @@ def __init__(
compact_after_write: bool = True,
) -> None:
self._reject_unsupported_modes(mode, use_legacy_format)
if not isinstance(uri, (str, pathlib.Path)):
validate_uri_or_namespace(uri, namespace_impl, table_id)
if uri is not None and not isinstance(uri, (str, pathlib.Path)):
raise TypeError(f"Expected URI to be str or pathlib.Path, got {type(uri)}")

self._table_uri = str(uri)
self._mode = mode
self._namespace_impl = namespace_impl
self._namespace_properties = namespace_properties
self._table_id = table_id
self._io_config = get_context().daft_planning_config.default_io_config if io_config is None else io_config
self._storage_options = (
storage_options
if storage_options is not None
else io_config_to_storage_options(self._io_config, self._table_uri)
base_storage_options = (
(
storage_options
if storage_options is not None
else io_config_to_storage_options(self._io_config, str(uri))
)
if uri is not None
else storage_options
)
self._table_uri, namespace_storage_options = self._resolve_table_uri(uri)
self._storage_options = merge_storage_options(base_storage_options, namespace_storage_options)
self._init_lance_knobs(
max_rows_per_file=max_rows_per_file,
max_rows_per_group=max_rows_per_group,
Expand Down Expand Up @@ -109,6 +128,28 @@ def __init__(
]
)

@property
def _namespace_kwargs(self) -> dict[str, object]:
    """Keyword arguments built by ``get_namespace_kwargs`` from this sink's namespace settings."""
    impl = self._namespace_impl
    properties = self._namespace_properties
    table_id = self._table_id
    return get_namespace_kwargs(impl, properties, table_id)

@property
def _dataset_uri_arg(self) -> str | None:
return None if self._namespace_impl is not None and self._table_id is not None else self._table_uri

def _resolve_table_uri(self, uri: str | pathlib.Path | None) -> tuple[str, dict[str, str] | None]:
if uri is not None:
return str(uri), None
mode = "create" if self._mode == "create" else "overwrite" if self._mode == "overwrite" else "read"
resolved_uri, namespace_storage_options = resolve_namespace_table(
namespace_impl=self._namespace_impl,
namespace_properties=self._namespace_properties,
table_id=self._table_id,
mode=mode,
)
if resolved_uri is None:
raise ValueError("Unable to resolve Lance dataset URI from namespace.")
return resolved_uri, namespace_storage_options

@staticmethod
def _reject_unsupported_modes(
mode: Literal["create", "append", "overwrite"], use_legacy_format: bool | None
Expand Down Expand Up @@ -167,7 +208,9 @@ def _absorb_existing_dataset(self) -> lance.LanceDataset | None:
"""
dataset: lance.LanceDataset | None
try:
dataset = lance.dataset(self._table_uri, storage_options=self._storage_options)
dataset = lance.dataset(
self._dataset_uri_arg, storage_options=self._storage_options, **self._namespace_kwargs
)
except (ValueError, FileNotFoundError, OSError) as e:
# Pinned to the Rust message format; lance has no typed exception. See test_lance_message_format_unchanged.
if "was not found" in str(e):
Expand Down Expand Up @@ -230,6 +273,7 @@ def _write_arrow_table(self, table: pa.Table) -> WriteResult[list[FragmentMetada
data_storage_version=self._data_storage_version,
use_legacy_format=self._use_legacy_format,
enable_stable_row_ids=self._enable_stable_row_ids,
**get_write_fragments_kwargs(self._namespace_impl, self._namespace_properties, self._table_id),
)
# Sum on-disk sizes from fragment metadata. Lance Blob V2 sidecar .blob
# files are not tracked in FragmentMetadata.files (out of scope here).
Expand All @@ -240,7 +284,7 @@ def _write_arrow_table(self, table: pa.Table) -> WriteResult[list[FragmentMetada

def _ensure_mem_wal_dataset(self) -> lance.LanceDataset:
try:
ds = lance.dataset(self._table_uri, storage_options=self._storage_options)
ds = lance.dataset(self._dataset_uri_arg, storage_options=self._storage_options, **self._namespace_kwargs)
except (ValueError, FileNotFoundError, OSError):
ds = None

Expand All @@ -250,11 +294,12 @@ def _ensure_mem_wal_dataset(self) -> lance.LanceDataset:
{f.name: pa.array([], type=f.type) for f in self._effective_pyarrow_schema},
schema=self._effective_pyarrow_schema,
),
self._table_uri,
self._dataset_uri_arg,
mode="create",
storage_options=self._storage_options,
data_storage_version=self._data_storage_version,
use_legacy_format=self._use_legacy_format,
**self._namespace_kwargs,
)

details = ds.mem_wal_index_details()
Expand Down Expand Up @@ -335,6 +380,7 @@ def _finalize_cow(self, write_results: list[WriteResult[list[FragmentMetadata]]]
operation,
read_version=self._version,
storage_options=self._storage_options,
**self._namespace_kwargs,
)
stats = dataset.stats.dataset_stats()
stats_dict = MicroPartition.from_pydict(
Expand All @@ -348,7 +394,7 @@ def _finalize_cow(self, write_results: list[WriteResult[list[FragmentMetadata]]]
return stats_dict

def _finalize_mem_wal(self, write_results: list[WriteResult[list[FragmentMetadata]]]) -> MicroPartition:
dataset = lance.dataset(self._table_uri, storage_options=self._storage_options)
dataset = lance.dataset(self._dataset_uri_arg, storage_options=self._storage_options, **self._namespace_kwargs)

if self._compact_after_write:
logger.info(
Expand All @@ -359,7 +405,9 @@ def _finalize_mem_wal(self, write_results: list[WriteResult[list[FragmentMetadat
from daft_lance.lance_compaction import compact_files_internal

compact_files_internal(dataset)
dataset = lance.dataset(self._table_uri, storage_options=self._storage_options)
dataset = lance.dataset(
self._dataset_uri_arg, storage_options=self._storage_options, **self._namespace_kwargs
)

stats = dataset.stats.dataset_stats()
return MicroPartition.from_pydict(
Expand Down
21 changes: 19 additions & 2 deletions daft_lance/lance_scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from daft.recordbatch import RecordBatch

from ._metadata import convert_lance_schema
from .namespace import get_namespace_kwargs
from .point_lookup import detect_point_lookup_columns
from .utils import combine_filters_to_arrow

Expand All @@ -40,7 +41,15 @@ def _lancedb_table_factory_function(
"Use nearest with fragment_ids=None for index-driven global vector search."
)

ds = lance.dataset(ds_uri, **(open_kwargs or {}))
open_kwargs = dict(open_kwargs or {})
namespace_impl = open_kwargs.pop("namespace_impl", None)
namespace_properties = open_kwargs.pop("namespace_properties", None)
table_id = open_kwargs.pop("table_id", None)
ds = lance.dataset(
None if namespace_impl is not None and table_id is not None else ds_uri,
**get_namespace_kwargs(namespace_impl, namespace_properties, table_id),
**open_kwargs,
)

def _iter_batches() -> Iterator[PyRecordBatch]:
# Iterate fragments individually; append a fragment_id column only when requested
Expand Down Expand Up @@ -117,7 +126,15 @@ def _lancedb_count_result_function(
filter: pa.compute.Expression | None = None,
) -> Iterator[PyRecordBatch]:
"""Use LanceDB's API to count rows and return a record batch with the count result."""
ds = lance.dataset(ds_uri, **(open_kwargs or {}))
open_kwargs = dict(open_kwargs or {})
namespace_impl = open_kwargs.pop("namespace_impl", None)
namespace_properties = open_kwargs.pop("namespace_properties", None)
table_id = open_kwargs.pop("table_id", None)
ds = lance.dataset(
None if namespace_impl is not None and table_id is not None else ds_uri,
**get_namespace_kwargs(namespace_impl, namespace_properties, table_id),
**open_kwargs,
)
logger.debug("Using metadata for counting all rows")
count = ds.count_rows(filter=filter)

Expand Down
Loading