Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 78 additions & 1 deletion daft_lance/lance_compaction.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,21 @@
from __future__ import annotations

import logging
from dataclasses import dataclass
from typing import Any

import lance
from lance import LanceDataset
from lance.optimize import Compaction, CompactionMetrics, CompactionOptions, CompactionTask, RewriteResult

import daft
from daft.dependencies import pa
from daft_lance._blob import detect_blob_v2_columns

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Option names accepted in a user-supplied ``compaction_options`` dict: every
# annotated attribute of ``CompactionOptions`` plus
# "materialize_deletions_threshold", which is added explicitly.
# NOTE(review): presumably the explicit entry covers Lance versions whose
# ``CompactionOptions`` annotations do not declare it -- confirm.
_VALID_COMPACTION_OPTION_KEYS = set(CompactionOptions.__annotations__) | {"materialize_deletions_threshold"}


class CompactionTaskUDF:
"""UDF to execute a batch of Lance CompactionTasks on remote workers and return execution result dictionaries."""
Expand All @@ -25,15 +31,86 @@ def __call__(self, task: CompactionTask) -> RewriteResult:
return rewrite


@dataclass(frozen=True)
class BlobV2CompactionMetrics:
    """Immutable before/after counts reported by the Blob V2 compaction path.

    The "removed" fields describe the dataset as it was before the rewrite and
    the "added" fields describe the dataset produced by the rewrite.
    """

    # Number of fragments in the dataset before it was rewritten.
    fragments_removed: int
    # Number of fragments in the rewritten dataset.
    fragments_added: int
    # Total data files across all fragments before the rewrite.
    files_removed: int
    # Total data files across all fragments after the rewrite.
    files_added: int


def _read_blob_payloads(lance_ds: LanceDataset, column: str, row_ids: list[int]) -> list[bytes | None]:
    """Read the blobs of ``column`` for ``row_ids`` fully into memory, closing each handle."""
    payloads: list[bytes | None] = []
    for blob in lance_ds.take_blobs(column, row_ids):
        if blob is None:
            payloads.append(None)
            continue
        # Blob handles are file-like objects; release each one as soon as it
        # has been drained instead of leaving them open until garbage collection.
        with blob:
            payloads.append(blob.read())
    return payloads


def _compact_blob_v2_dataset(
    lance_ds: LanceDataset,
    blob_v2_columns: list[str],
    compaction_options: dict[str, Any] | None,
) -> BlobV2CompactionMetrics | None:
    """Compact a Blob V2 dataset by materializing its visible rows and rewriting them.

    The visible (non-deleted) rows are read back, every blob column is re-read
    through ``take_blobs`` and the resulting table is written over the dataset
    in ``overwrite`` mode.

    Args:
        lance_ds: Dataset to compact. It must be checked out at its latest version.
        blob_v2_columns: Names of the Blob V2 columns in the dataset schema.
        compaction_options: Optional compaction settings. Keys must be in
            ``_VALID_COMPACTION_OPTION_KEYS``. The keys read here are
            ``materialize_deletions``, ``target_rows_per_fragment``,
            ``max_rows_per_file``, ``max_rows_per_group`` and
            ``max_bytes_per_file``.

    Returns:
        Before/after fragment and file counts, or ``None`` when the dataset has
        at most one fragment and there are no deletions to materialize.

    Raises:
        ValueError: If an unknown option key is supplied, or if ``lance_ds`` is
            not checked out at its latest version.
    """
    options = compaction_options or {}
    unknown_options = set(options) - _VALID_COMPACTION_OPTION_KEYS
    if unknown_options:
        raise ValueError(f"Invalid compaction options: {sorted(unknown_options)}")

    if lance_ds.version != lance_ds.latest_version:
        raise ValueError("Blob V2 compaction fallback only supports compacting the latest dataset version")

    # Take one snapshot of the fragment list and derive every "before" figure
    # from it, rather than re-querying the dataset for each metric.
    fragments = list(lance_ds.get_fragments())
    fragments_before = len(fragments)
    deletions_before = sum(fragment.num_deletions for fragment in fragments)
    materialize_deletions = options.get("materialize_deletions", True)
    if fragments_before <= 1 and (not materialize_deletions or deletions_before == 0):
        logger.info("No Blob V2 compaction needed")
        return None

    # Non-blob columns come from a regular scan; the row ids of that scan are
    # then used to fetch the matching blobs column by column.
    blob_column_set = set(blob_v2_columns)
    non_blob_columns = [field.name for field in lance_ds.schema if field.name not in blob_column_set]
    visible_rows = lance_ds.to_table(columns=non_blob_columns, with_row_id=True)
    row_ids = visible_rows.column("_rowid").to_pylist()

    # Rebuild the columns in the original schema order.
    arrays: list[pa.Array[Any] | pa.ChunkedArray[Any]] = []
    fields: list[pa.Field[Any]] = []
    for field in lance_ds.schema:
        if field.name in blob_column_set:
            # NOTE: every blob of the column is held in memory at once.
            blob_array = lance.blob_array(_read_blob_payloads(lance_ds, field.name, row_ids))
            arrays.append(blob_array)
            fields.append(pa.field(field.name, blob_array.type, nullable=field.nullable, metadata=field.metadata))
        else:
            arrays.append(visible_rows.column(field.name))
            fields.append(field)

    table = pa.Table.from_arrays(arrays, schema=pa.schema(fields, metadata=lance_ds.schema.metadata))
    files_before = sum(len(fragment.metadata.files) for fragment in fragments)
    compacted = lance.write_dataset(
        table,
        lance_ds,
        mode="overwrite",
        data_storage_version=getattr(lance_ds, "data_storage_version", None) or "2.2",
        max_rows_per_file=options.get("target_rows_per_fragment") or options.get("max_rows_per_file") or 1024 * 1024,
        max_rows_per_group=options.get("max_rows_per_group") or 1024,
        max_bytes_per_file=options.get("max_bytes_per_file") or 90 * 1024 * 1024 * 1024,
    )

    fragments_after = list(compacted.get_fragments())
    return BlobV2CompactionMetrics(
        fragments_removed=fragments_before,
        fragments_added=len(fragments_after),
        files_removed=files_before,
        files_added=sum(len(fragment.metadata.files) for fragment in fragments_after),
    )


def compact_files_internal(
lance_ds: LanceDataset,
*,
compaction_options: dict[str, Any] | None = None,
partition_num: int | None = None,
concurrency: int | None = None,
) -> CompactionMetrics | None:
) -> CompactionMetrics | BlobV2CompactionMetrics | None:
"""Execute Lance file compaction in distributed environment using Daft UDF style."""
logger.info("Starting UDF-style distributed compaction")
blob_v2_columns = detect_blob_v2_columns(lance_ds.schema)
if blob_v2_columns:
return _compact_blob_v2_dataset(lance_ds, blob_v2_columns, compaction_options)

plan = Compaction.plan(
lance_ds,
CompactionOptions(
Expand Down
66 changes: 66 additions & 0 deletions tests/io/lancedb/test_lancedb_compaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,3 +140,69 @@ def test_compaction_with_partition_num(tmp_path: Path):
post_rows = dataset.count_rows()
assert post_fragments == 2, "Fragment count should be reduced after compaction"
assert post_rows == pre_rows, "Row count should remain unchanged after compaction"


def _blob_v2_table(ids: list[int], blobs: list[bytes]) -> pa.Table:
    """Build a two-column table: an int64 ``id`` column and a Lance blob column ``blob``."""
    id_column = pa.array(ids, type=pa.int64())
    blob_column = lance.blob_array(blobs)
    return pa.table({"id": id_column, "blob": blob_column})


def _read_blob_bytes_by_id(uri: str) -> dict[int, bytes]:
    """Open the dataset at ``uri`` and map each row's ``id`` to the bytes of its ``blob``."""
    dataset = lance.dataset(uri)
    id_rows = dataset.to_table(columns=["id"], with_row_id=True).to_pylist()
    # Fetch the blobs in the same order as the scanned rows so they can be paired up.
    row_ids = [entry["_rowid"] for entry in id_rows]
    blob_files = dataset.take_blobs("blob", row_ids)

    payload_by_id: dict[int, bytes] = {}
    for entry, blob_file in zip(id_rows, blob_files, strict=True):
        payload_by_id[entry["id"]] = blob_file.read()
    return payload_by_id


def test_blob_v2_compaction_materializes_deletions_and_preserves_bytes(tmp_path: Path):
    """Compacting a two-fragment Blob V2 dataset drops the deleted row and keeps blob bytes intact."""
    dataset_path = tmp_path / "test_blob_v2_deletion_compaction"
    dataset_uri = str(dataset_path)
    # Blob payloads of very different sizes, keyed by row id.
    payloads = {
        1: b"inline",
        2: b"x" * 100_000,
        3: b"y" * 5_000_000,
        4: b"survivor",
    }

    # Two separate writes of two rows each produce two fragments.
    first_batch = _blob_v2_table([1, 2], [payloads[1], payloads[2]])
    lance.write_dataset(
        first_batch,
        dataset_path,
        data_storage_version="2.2",
        max_rows_per_file=2,
    )
    second_batch = _blob_v2_table([3, 4], [payloads[3], payloads[4]])
    lance.write_dataset(
        second_batch,
        dataset_path,
        mode="append",
        data_storage_version="2.2",
        max_rows_per_file=2,
    )

    ds = lance.dataset(dataset_uri)
    assert len(ds.get_fragments()) == 2
    assert _read_blob_bytes_by_id(dataset_uri) == payloads
    ds.delete("id = 2")

    options = {
        "materialize_deletions": True,
        "materialize_deletions_threshold": 0.0,
        "target_rows_per_fragment": 100,
        "num_threads": 1,
    }
    metrics = compact_files(uri=dataset_uri, compaction_options=options)

    assert metrics is not None
    assert getattr(metrics, "fragments_removed", None) == 2
    assert getattr(metrics, "fragments_added", None) == 1

    # Re-open the dataset to observe the compacted version.
    ds = lance.dataset(dataset_uri)
    assert len(ds.get_fragments()) == 1
    assert ds.count_rows() == 3
    surviving_payloads = {row_id: payloads[row_id] for row_id in (1, 3, 4)}
    assert _read_blob_bytes_by_id(dataset_uri) == surviving_payloads
Loading