Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions daft_lance/_lance.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def read_lance(
fragment_group_size: int | None = None,
include_fragment_id: bool | None = None,
checkpoint: CheckpointConfig | None = None,
base_store_params: dict[str, dict[str, str]] | None = None,
) -> DataFrame:
"""Create a DataFrame from a LanceDB table.

Expand Down Expand Up @@ -145,6 +146,7 @@ def read_lance(
index_cache_size=index_cache_size,
default_scan_options=default_scan_options,
metadata_cache_size_bytes=metadata_cache_size_bytes,
base_store_params=base_store_params,
)

lance_operator = LanceDBScanOperator(
Expand Down
69 changes: 63 additions & 6 deletions daft_lance/lance_data_sink.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import uuid
import warnings
from itertools import chain
from typing import TYPE_CHECKING, Literal
from typing import TYPE_CHECKING, Any, Literal

import lance
from lance.fragment import FragmentMetadata
Expand Down Expand Up @@ -53,6 +53,9 @@ def __init__(
use_legacy_format: bool | None = None,
enable_stable_row_ids: bool = False,
storage_options: dict[str, str] | None = None,
initial_bases: list[Any] | None = None,
target_bases: list[str] | None = None,
base_store_params: dict[str, dict[str, str]] | None = None,
use_mem_wal: bool = False,
compact_after_write: bool = True,
) -> None:
Expand All @@ -68,6 +71,9 @@ def __init__(
if storage_options is not None
else io_config_to_storage_options(self._io_config, self._table_uri)
)
self._initial_bases = self._normalize_initial_bases(initial_bases)
self._target_bases = target_bases
self._base_store_params = base_store_params
self._init_lance_knobs(
max_rows_per_file=max_rows_per_file,
max_rows_per_group=max_rows_per_group,
Expand Down Expand Up @@ -157,6 +163,30 @@ def _init_blob_policy(self, blob_columns: list[str] | None) -> None:
self._blob = BlobV2WritePolicy(blob_columns)
self._blob.validate_columns_present(self._pyarrow_schema)

@staticmethod
def _normalize_initial_bases(initial_bases: list[Any] | None) -> list[Any] | None:
"""Assign non-root base IDs to match lance.write_dataset semantics."""
if initial_bases is None:
return None
normalized: list[Any] = []
next_base_id = 1
for base in initial_bases:
base_id = getattr(base, "id", None)
if base_id:
normalized.append(base)
next_base_id = max(next_base_id, int(base_id) + 1)
continue
normalized.append(
lance.DatasetBasePath(
base.path,
name=base.name,
is_dataset_root=base.is_dataset_root,
id=next_base_id,
)
)
next_base_id += 1
return normalized

def _absorb_existing_dataset(self) -> lance.LanceDataset | None:
"""Open the existing dataset (if any), set table-state, and validate the requested mode.

Expand All @@ -167,7 +197,11 @@ def _absorb_existing_dataset(self) -> lance.LanceDataset | None:
"""
dataset: lance.LanceDataset | None
try:
dataset = lance.dataset(self._table_uri, storage_options=self._storage_options)
dataset = lance.dataset(
self._table_uri,
storage_options=self._storage_options,
base_store_params=self._base_store_params,
)
except (ValueError, FileNotFoundError, OSError) as e:
# Pinned to the Rust message format; lance has no typed exception. See test_lance_message_format_unchanged.
if "was not found" in str(e):
Expand All @@ -191,6 +225,10 @@ def _absorb_existing_dataset(self) -> lance.LanceDataset | None:
if self._mode == "create":
raise ValueError("Cannot create a Lance dataset at a location where one already exists.")

if self._mode == "append" and self._initial_bases:
dataset = dataset.add_bases(self._initial_bases)
self._version = dataset.latest_version

if self._mode == "append" and not _pyarrow_schema_castable(
blob_aware_schema_for_validation(self._pyarrow_schema, self._table_schema),
blob_aware_schema_for_validation(self._table_schema, self._table_schema),
Expand Down Expand Up @@ -230,6 +268,9 @@ def _write_arrow_table(self, table: pa.Table) -> WriteResult[list[FragmentMetada
data_storage_version=self._data_storage_version,
use_legacy_format=self._use_legacy_format,
enable_stable_row_ids=self._enable_stable_row_ids,
initial_bases=self._initial_bases,
target_bases=self._target_bases,
base_store_params=self._base_store_params,
)
# Sum on-disk sizes from fragment metadata. Lance Blob V2 sidecar .blob
# files are not tracked in FragmentMetadata.files (out of scope here).
Expand All @@ -240,7 +281,11 @@ def _write_arrow_table(self, table: pa.Table) -> WriteResult[list[FragmentMetada

def _ensure_mem_wal_dataset(self) -> lance.LanceDataset:
try:
ds = lance.dataset(self._table_uri, storage_options=self._storage_options)
ds = lance.dataset(
self._table_uri,
storage_options=self._storage_options,
base_store_params=self._base_store_params,
)
except (ValueError, FileNotFoundError, OSError):
ds = None

Expand All @@ -255,6 +300,9 @@ def _ensure_mem_wal_dataset(self) -> lance.LanceDataset:
storage_options=self._storage_options,
data_storage_version=self._data_storage_version,
use_legacy_format=self._use_legacy_format,
initial_bases=self._initial_bases,
target_bases=self._target_bases,
base_store_params=self._base_store_params,
)

details = ds.mem_wal_index_details()
Expand Down Expand Up @@ -326,7 +374,7 @@ def _finalize_cow(self, write_results: list[WriteResult[list[FragmentMetadata]]]

operation: lance.LanceOperation.BaseOperation
if self._mode == "create" or self._mode == "overwrite":
operation = lance.LanceOperation.Overwrite(self._effective_pyarrow_schema, fragments)
operation = lance.LanceOperation.Overwrite(self._effective_pyarrow_schema, fragments, self._initial_bases)
elif self._mode == "append":
operation = lance.LanceOperation.Append(fragments)

Expand All @@ -335,6 +383,7 @@ def _finalize_cow(self, write_results: list[WriteResult[list[FragmentMetadata]]]
operation,
read_version=self._version,
storage_options=self._storage_options,
base_store_params=self._base_store_params,
)
stats = dataset.stats.dataset_stats()
stats_dict = MicroPartition.from_pydict(
Expand All @@ -348,7 +397,11 @@ def _finalize_cow(self, write_results: list[WriteResult[list[FragmentMetadata]]]
return stats_dict

def _finalize_mem_wal(self, write_results: list[WriteResult[list[FragmentMetadata]]]) -> MicroPartition:
dataset = lance.dataset(self._table_uri, storage_options=self._storage_options)
dataset = lance.dataset(
self._table_uri,
storage_options=self._storage_options,
base_store_params=self._base_store_params,
)

if self._compact_after_write:
logger.info(
Expand All @@ -359,7 +412,11 @@ def _finalize_mem_wal(self, write_results: list[WriteResult[list[FragmentMetadat
from daft_lance.lance_compaction import compact_files_internal

compact_files_internal(dataset)
dataset = lance.dataset(self._table_uri, storage_options=self._storage_options)
dataset = lance.dataset(
self._table_uri,
storage_options=self._storage_options,
base_store_params=self._base_store_params,
)

stats = dataset.stats.dataset_stats()
return MicroPartition.from_pydict(
Expand Down
65 changes: 64 additions & 1 deletion tests/io/lancedb/test_lance_blob_v2_write.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import lance.blob
import pyarrow as pa
import pytest
from lance import Blob
from lance import Blob, DatasetBasePath
from lance.blob import BlobType
from pytest import TempPathFactory

Expand Down Expand Up @@ -64,6 +64,69 @@ def test_blob_columns_dedicated(lance_path: str) -> None:
assert _kinds(lance_path) == [KIND_DEDICATED, KIND_DEDICATED]


def test_blob_columns_dedicated_multi_base(tmp_path) -> None:
    """A blob column written with a registered extra base is stored under that base and reads back intact."""
    blob_dir = tmp_path / "blob-base"
    blob_dir.mkdir()
    table_uri = str(tmp_path / "dataset")
    expected_bytes = b"multi-base-payload" * 300_000

    frame = daft.from_pydict({"id": [1], "data": [expected_bytes]})
    extra_base = DatasetBasePath(
        str(blob_dir),
        name="blob_base",
        is_dataset_root=False,
    )
    write_job = frame.write_lance(
        table_uri,
        blob_columns=["data"],
        initial_bases=[extra_base],
        target_bases=["blob_base"],
    )
    write_job.collect()

    ds = lance.dataset(table_uri)

    # The first data file of the first fragment must reference base ID 1.
    first_fragment = ds.get_fragments()[0]
    files_meta = first_fragment.metadata.to_json()["files"]
    assert files_meta[0]["base_id"] == 1

    # At least one blob sidecar file has to exist under the extra base directory.
    sidecars = list(blob_dir.rglob("*.blob"))
    assert sidecars

    # Round-trip the payload through take_blobs using the row IDs of the dataset.
    rowid_table = ds.to_table(columns=[], with_row_id=True)
    row_ids = rowid_table.column("_rowid").to_pylist()
    fetched = ds.take_blobs("data", row_ids)
    assert fetched[0].read() == expected_bytes


def test_blob_columns_append_multi_base(tmp_path) -> None:
    """Appending with ``target_bases`` keeps writing blobs to the base registered by the first write."""
    blob_dir = tmp_path / "blob-base"
    blob_dir.mkdir()
    table_uri = str(tmp_path / "dataset")
    first = b"first-multi-base" * 300_000
    second = b"second-multi-base" * 300_000

    # Initial write: registers the extra base and targets it.
    initial_frame = daft.from_pydict({"id": [1], "data": [first]})
    extra_base = DatasetBasePath(
        str(blob_dir),
        name="blob_base",
        is_dataset_root=False,
    )
    initial_frame.write_lance(
        table_uri,
        blob_columns=["data"],
        initial_bases=[extra_base],
        target_bases=["blob_base"],
    ).collect()

    # Append: only names the already-registered base as the target.
    appended_frame = daft.from_pydict({"id": [2], "data": [second]})
    appended_frame.write_lance(
        table_uri,
        mode="append",
        target_bases=["blob_base"],
    ).collect()

    ds = lance.dataset(table_uri)

    # Every data file across all fragments must point at base ID 1.
    seen_base_ids = []
    for fragment in ds.get_fragments():
        for file_meta in fragment.metadata.to_json()["files"]:
            seen_base_ids.append(file_meta["base_id"])
    assert seen_base_ids == [
        1,
        1,
    ]

    # One blob sidecar per write is expected under the extra base directory.
    sidecar_count = len(list(blob_dir.rglob("*.blob")))
    assert sidecar_count == 2

    # Map each row's id to the blob bytes fetched for it and compare with what was written.
    rows = ds.to_table(columns=["id"], with_row_id=True).to_pylist()
    wanted_row_ids = [row["_rowid"] for row in rows]
    blobs = ds.take_blobs("data", wanted_row_ids)
    payload_by_id = {}
    for row, blob in zip(rows, blobs, strict=True):
        payload_by_id[row["id"]] = blob.read()
    assert payload_by_id == {1: first, 2: second}


def test_blob_columns_default_storage_version(lance_path: str) -> None:
"""No explicit data_storage_version should still produce a 2.2 dataset."""
df = daft.from_pydict({"id": [1], "data": [b"hi"]})
Expand Down
Loading