diff --git a/daft_lance/_lance.py b/daft_lance/_lance.py index 21b7c67..d909d60 100644 --- a/daft_lance/_lance.py +++ b/daft_lance/_lance.py @@ -42,6 +42,7 @@ def read_lance( fragment_group_size: int | None = None, include_fragment_id: bool | None = None, checkpoint: CheckpointConfig | None = None, + base_store_params: dict[str, dict[str, str]] | None = None, ) -> DataFrame: """Create a DataFrame from a LanceDB table. @@ -145,6 +146,7 @@ def read_lance( index_cache_size=index_cache_size, default_scan_options=default_scan_options, metadata_cache_size_bytes=metadata_cache_size_bytes, + base_store_params=base_store_params, ) lance_operator = LanceDBScanOperator( diff --git a/daft_lance/lance_data_sink.py b/daft_lance/lance_data_sink.py index 62c36bb..4d9af88 100644 --- a/daft_lance/lance_data_sink.py +++ b/daft_lance/lance_data_sink.py @@ -5,7 +5,7 @@ import uuid import warnings from itertools import chain -from typing import TYPE_CHECKING, Literal +from typing import TYPE_CHECKING, Any, Literal import lance from lance.fragment import FragmentMetadata @@ -53,6 +53,9 @@ def __init__( use_legacy_format: bool | None = None, enable_stable_row_ids: bool = False, storage_options: dict[str, str] | None = None, + initial_bases: list[Any] | None = None, + target_bases: list[str] | None = None, + base_store_params: dict[str, dict[str, str]] | None = None, use_mem_wal: bool = False, compact_after_write: bool = True, ) -> None: @@ -68,6 +71,9 @@ def __init__( if storage_options is not None else io_config_to_storage_options(self._io_config, self._table_uri) ) + self._initial_bases = self._normalize_initial_bases(initial_bases) + self._target_bases = target_bases + self._base_store_params = base_store_params self._init_lance_knobs( max_rows_per_file=max_rows_per_file, max_rows_per_group=max_rows_per_group, @@ -157,6 +163,30 @@ def _init_blob_policy(self, blob_columns: list[str] | None) -> None: self._blob = BlobV2WritePolicy(blob_columns) self._blob.validate_columns_present(self._pyarrow_schema) + @staticmethod + def _normalize_initial_bases(initial_bases: list[Any] | None) -> list[Any] | None: + """Assign non-root base IDs to match lance.write_dataset semantics.""" + if initial_bases is None: + return None + normalized: list[Any] = [] + next_base_id = 1 + for base in initial_bases: + base_id = getattr(base, "id", None) + if base_id: + normalized.append(base) + next_base_id = max(next_base_id, int(base_id) + 1) + continue + normalized.append( + lance.DatasetBasePath( + base.path, + name=base.name, + is_dataset_root=base.is_dataset_root, + id=next_base_id, + ) + ) + next_base_id += 1 + return normalized + def _absorb_existing_dataset(self) -> lance.LanceDataset | None: """Open the existing dataset (if any), set table-state, and validate the requested mode. @@ -167,7 +197,11 @@ def _absorb_existing_dataset(self) -> lance.LanceDataset | None: """ dataset: lance.LanceDataset | None try: - dataset = lance.dataset(self._table_uri, storage_options=self._storage_options) + dataset = lance.dataset( + self._table_uri, + storage_options=self._storage_options, + base_store_params=self._base_store_params, + ) except (ValueError, FileNotFoundError, OSError) as e: # Pinned to the Rust message format; lance has no typed exception. See test_lance_message_format_unchanged. if "was not found" in str(e): @@ -191,6 +225,10 @@ def _absorb_existing_dataset(self) -> lance.LanceDataset | None: if self._mode == "create": raise ValueError("Cannot create a Lance dataset at a location where one already exists.") + if self._mode == "append" and self._initial_bases: + dataset = dataset.add_bases(self._initial_bases) + self._version = dataset.latest_version + if self._mode == "append" and not _pyarrow_schema_castable( blob_aware_schema_for_validation(self._pyarrow_schema, self._table_schema), blob_aware_schema_for_validation(self._table_schema, self._table_schema), @@ -230,6 +268,9 @@ def _write_arrow_table(self, table: pa.Table) -> WriteResult[list[FragmentMetada data_storage_version=self._data_storage_version, use_legacy_format=self._use_legacy_format, enable_stable_row_ids=self._enable_stable_row_ids, + initial_bases=self._initial_bases, + target_bases=self._target_bases, + base_store_params=self._base_store_params, ) # Sum on-disk sizes from fragment metadata. Lance Blob V2 sidecar .blob # files are not tracked in FragmentMetadata.files (out of scope here). @@ -240,7 +281,11 @@ def _write_arrow_table(self, table: pa.Table) -> WriteResult[list[FragmentMetada def _ensure_mem_wal_dataset(self) -> lance.LanceDataset: try: - ds = lance.dataset(self._table_uri, storage_options=self._storage_options) + ds = lance.dataset( + self._table_uri, + storage_options=self._storage_options, + base_store_params=self._base_store_params, + ) except (ValueError, FileNotFoundError, OSError): ds = None @@ -255,6 +300,9 @@ def _ensure_mem_wal_dataset(self) -> lance.LanceDataset: storage_options=self._storage_options, data_storage_version=self._data_storage_version, use_legacy_format=self._use_legacy_format, + initial_bases=self._initial_bases, + target_bases=self._target_bases, + base_store_params=self._base_store_params, ) details = ds.mem_wal_index_details() @@ -326,7 +374,7 @@ def _finalize_cow(self, write_results: list[WriteResult[list[FragmentMetadata]]] operation: lance.LanceOperation.BaseOperation if self._mode == "create" or self._mode == "overwrite": - operation = lance.LanceOperation.Overwrite(self._effective_pyarrow_schema, fragments) + operation = lance.LanceOperation.Overwrite(self._effective_pyarrow_schema, fragments, self._initial_bases) elif self._mode == "append": operation = lance.LanceOperation.Append(fragments) @@ -335,6 +383,7 @@ def _finalize_cow(self, write_results: list[WriteResult[list[FragmentMetadata]]] operation, read_version=self._version, storage_options=self._storage_options, + base_store_params=self._base_store_params, ) stats = dataset.stats.dataset_stats() stats_dict = MicroPartition.from_pydict( @@ -348,7 +397,11 @@ def _finalize_cow(self, write_results: list[WriteResult[list[FragmentMetadata]]] return stats_dict def _finalize_mem_wal(self, write_results: list[WriteResult[list[FragmentMetadata]]]) -> MicroPartition: - dataset = lance.dataset(self._table_uri, storage_options=self._storage_options) + dataset = lance.dataset( + self._table_uri, + storage_options=self._storage_options, + base_store_params=self._base_store_params, + ) if self._compact_after_write: logger.info( @@ -359,7 +412,11 @@ def _finalize_mem_wal(self, write_results: list[WriteResult[list[FragmentMetadat from daft_lance.lance_compaction import compact_files_internal compact_files_internal(dataset) - dataset = lance.dataset(self._table_uri, storage_options=self._storage_options) + dataset = lance.dataset( + self._table_uri, + storage_options=self._storage_options, + base_store_params=self._base_store_params, + ) stats = dataset.stats.dataset_stats() return MicroPartition.from_pydict( diff --git a/tests/io/lancedb/test_lance_blob_v2_write.py b/tests/io/lancedb/test_lance_blob_v2_write.py index b2d996e..e855731 100644 --- a/tests/io/lancedb/test_lance_blob_v2_write.py +++ b/tests/io/lancedb/test_lance_blob_v2_write.py @@ -6,7 +6,7 @@ import lance.blob import pyarrow as pa import pytest -from lance import Blob +from lance import Blob, DatasetBasePath from lance.blob import BlobType from pytest import TempPathFactory @@ -64,6 +64,69 @@ def test_blob_columns_dedicated(lance_path: str) -> None: assert _kinds(lance_path) == [KIND_DEDICATED, KIND_DEDICATED] +def test_blob_columns_dedicated_multi_base(tmp_path) -> None: + dataset_path = tmp_path / "dataset" + blob_base = tmp_path / "blob-base" + blob_base.mkdir() + payload = b"multi-base-payload" * 300_000 + + daft.from_pydict({"id": [1], "data": [payload]}).write_lance( + str(dataset_path), + blob_columns=["data"], + initial_bases=[ + DatasetBasePath( + str(blob_base), + name="blob_base", + is_dataset_root=False, + ) + ], + target_bases=["blob_base"], + ).collect() + + ds = lance.dataset(str(dataset_path)) + fragment_files = ds.get_fragments()[0].metadata.to_json()["files"] + assert fragment_files[0]["base_id"] == 1 + assert list(blob_base.rglob("*.blob")) + row_ids = ds.to_table(columns=[], with_row_id=True).column("_rowid").to_pylist() + assert ds.take_blobs("data", row_ids)[0].read() == payload + + +def test_blob_columns_append_multi_base(tmp_path) -> None: + dataset_path = tmp_path / "dataset" + blob_base = tmp_path / "blob-base" + blob_base.mkdir() + first = b"first-multi-base" * 300_000 + second = b"second-multi-base" * 300_000 + + daft.from_pydict({"id": [1], "data": [first]}).write_lance( + str(dataset_path), + blob_columns=["data"], + initial_bases=[ + DatasetBasePath( + str(blob_base), + name="blob_base", + is_dataset_root=False, + ) + ], + target_bases=["blob_base"], + ).collect() + daft.from_pydict({"id": [2], "data": [second]}).write_lance( + str(dataset_path), + mode="append", + target_bases=["blob_base"], + ).collect() + + ds = lance.dataset(str(dataset_path)) + assert [file["base_id"] for fragment in ds.get_fragments() for file in fragment.metadata.to_json()["files"]] == [ + 1, + 1, + ] + assert len(list(blob_base.rglob("*.blob"))) == 2 + rows = ds.to_table(columns=["id"], with_row_id=True).to_pylist() + blobs = ds.take_blobs("data", [row["_rowid"] for row in rows]) + assert {row["id"]: blob.read() for row, blob in zip(rows, blobs, strict=True)} == {1: first, 2: second} + + def test_blob_columns_default_storage_version(lance_path: str) -> None: """No explicit data_storage_version should still produce a 2.2 dataset.""" df = daft.from_pydict({"id": [1], "data": [b"hi"]})