Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions daft_lance/_lance.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,7 @@ def merge_columns_df(
batch_size: int | None = None,
left_on: str | None = "_rowaddr",
right_on: str | None = "_rowaddr",
blob_columns: list[str] | None = None,
) -> Any:
"""Row-level merge columns entrypoint using a DataFrame.

Expand Down Expand Up @@ -351,6 +352,7 @@ def merge_columns_df(
left_on=left_on,
right_on=effective_right_on,
batch_size=effective_batch_size,
blob_columns=blob_columns,
)


Expand Down
28 changes: 22 additions & 6 deletions daft_lance/lance_merge_column.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ def __init__(
read_columns: list[str] | None = None,
reader_schema: pa.Schema | None = None,
batch_size: int | None = 9223372036854775807,
blob_columns: list[str] | None = None,
):
"""Per-group merge handler that directly invokes Lance fragment.merge with keyed join.

Expand All @@ -123,6 +124,7 @@ def __init__(
self.read_columns = read_columns or []
self.reader_schema = reader_schema
self.batch_size = batch_size
self.blob_columns = set(blob_columns or [])

@method.batch(return_dtype=_FRAGMENT_HANDLER_RETURN_DTYPE)
def __call__(self, *cols: Any) -> list[dict[str, bytes]]:
Expand Down Expand Up @@ -158,7 +160,7 @@ def __call__(self, *cols: Any) -> list[dict[str, bytes]]:
# Convert all arrays to a consistent type to avoid mypy errors
arrays.append(key_arr.cast(_pa.int64()))
else:
arr = _pa.array(pylist)
arr = lance.blob_array(pylist) if col_name in self.blob_columns else _pa.array(pylist)
if _pa.types.is_floating(arr.type):
arrays.append(arr)
elif _pa.types.is_integer(arr.type):
Expand Down Expand Up @@ -243,11 +245,13 @@ def __init__(
uri: str,
new_column_names: list[str],
storage_options: dict[str, str] | None = None,
blob_columns: list[str] | None = None,
):
self.lance_ds = lance_ds
self.uri = str(uri)
self.new_column_names = new_column_names
self.storage_options = storage_options
self.blob_columns = set(blob_columns or [])

@method.batch(return_dtype=_FRAGMENT_HANDLER_RETURN_DTYPE)
def __call__(self, *cols: Any) -> list[dict[str, bytes]]:
Expand All @@ -269,8 +273,9 @@ def __call__(self, *cols: Any) -> list[dict[str, bytes]]:

# Build table of new columns
arrays = []
for s in data_cols:
arr = _pa.array(s.to_pylist() if hasattr(s, "to_pylist") else list(s))
for col_name, s in zip(self.new_column_names, data_cols):
pylist = s.to_pylist() if hasattr(s, "to_pylist") else list(s)
arr = lance.blob_array(pylist) if col_name in self.blob_columns else _pa.array(pylist)
arrays.append(arr)
tbl = _pa.table({name: arr for name, arr in zip(self.new_column_names, arrays)})

Expand Down Expand Up @@ -360,6 +365,7 @@ def merge_columns_from_df(
left_on: str | None = "_rowaddr",
right_on: str | None = None,
batch_size: int | None = 9223372036854775807,
blob_columns: list[str] | None = None,
) -> lance.LanceDataset:
# Validate required keys
if "fragment_id" not in df.column_names:
Expand Down Expand Up @@ -396,10 +402,10 @@ def merge_columns_from_df(
read_columns = [join_key] + new_cols

# Decide: fast path (raw file write) or slow path (keyed join)
use_fast_path = _can_use_fast_path(df, lance_ds, join_key)
use_fast_path = _can_use_fast_path(df, lance_ds, join_key) and not blob_columns

if use_fast_path:
return _merge_fast_path(df, lance_ds, uri, new_cols, storage_options=storage_options)
return _merge_fast_path(df, lance_ds, uri, new_cols, storage_options=storage_options, blob_columns=blob_columns)
else:
return _merge_slow_path(
df,
Expand All @@ -411,6 +417,7 @@ def merge_columns_from_df(
reader_schema,
batch_size,
storage_options=storage_options,
blob_columns=blob_columns,
)


Expand All @@ -420,9 +427,16 @@ def _merge_fast_path(
uri: str | pathlib.Path,
new_column_names: list[str],
storage_options: dict[str, Any] | None = None,
blob_columns: list[str] | None = None,
) -> lance.LanceDataset:
"""Metadata-only add_columns: write raw .lance files and stitch into fragment metadata."""
handler = FastPathFragmentWriter(lance_ds, str(uri), new_column_names, storage_options=storage_options)
handler = FastPathFragmentWriter(
lance_ds,
str(uri),
new_column_names,
storage_options=storage_options,
blob_columns=blob_columns,
)

grouped = df.groupby("fragment_id").map_groups(
handler(*(df[c] for c in new_column_names), df["_rowaddr"], df["fragment_id"]).alias("commit_message") # type: ignore[attr-defined]
Expand Down Expand Up @@ -472,6 +486,7 @@ def _merge_slow_path(
reader_schema: pa.Schema | None,
batch_size: int | None,
storage_options: dict[str, Any] | None = None,
blob_columns: list[str] | None = None,
) -> lance.LanceDataset:
"""Original keyed-join merge path: rewrites fragment data."""
handler_udf = GroupFragmentMergeUDF(
Expand All @@ -481,6 +496,7 @@ def _merge_slow_path(
read_columns,
reader_schema,
batch_size,
blob_columns,
)

grouped = df.groupby("fragment_id").map_groups(
Expand Down
27 changes: 27 additions & 0 deletions tests/io/lancedb/test_fast_path_merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@
import lance
import pyarrow as pa
import pytest
from lance.blob import BlobType

import daft
import daft_lance
from daft_lance.lance_merge_column import _can_use_fast_path, merge_columns_from_df

# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -127,6 +129,31 @@ def test_fast_path_computed_column(self, ds_path):
for x, y, z in zip(result["x"], result["y"], result["z"]):
assert x + y == z

def test_fast_path_adds_blob_v2_column(self, ds_path):
    """Merge a computed binary column as a Lance blob column and read it back.

    Writes a three-row dataset with ``data_storage_version="2.2"``, derives a
    ``payload`` column from ``id``, merges it with ``blob_columns=["payload"]``
    and checks both the resulting field type and the stored bytes.

    NOTE(review): despite the test name, ``merge_columns_from_df`` appears to
    skip the fast path whenever ``blob_columns`` is non-empty, so this likely
    exercises the keyed-join path -- confirm against the implementation.
    """
    base_table = pa.table({"id": [1, 2, 3], "val": [10, 20, 30]})
    lance.write_dataset(base_table, ds_path, data_storage_version="2.2")
    df = read_with_metadata(ds_path)

    @daft.func.batch(return_dtype=daft.DataType.binary())
    def make_payload(ids):  # type: ignore[no-untyped-def]
        encoded = [f"payload-{id_}".encode() for id_ in ids.to_pylist()]
        return pa.array(encoded, type=pa.large_binary())

    df = df.with_column("payload", make_payload(df["id"]))
    ds2 = daft_lance.merge_columns_df(df, ds_path, blob_columns=["payload"])

    # The merged column must be registered with the blob extension type.
    payload_field = ds2.schema.field("payload")
    assert isinstance(payload_field.type, BlobType)

    # Fetch blobs by row id so each payload can be paired with its "id" value.
    rows = ds2.to_table(columns=["id"], with_row_id=True).to_pylist()
    row_ids = [row["_rowid"] for row in rows]
    blobs = ds2.take_blobs("payload", row_ids)

    payload_by_id = {}
    for row, blob in zip(rows, blobs, strict=True):
        payload_by_id[row["id"]] = blob.read()

    assert payload_by_id == {
        1: b"payload-1",
        2: b"payload-2",
        3: b"payload-3",
    }


# ---------------------------------------------------------------------------
# 2. Fragment integrity
Expand Down
Loading