diff --git a/daft_lance/_lance.py b/daft_lance/_lance.py index 21b7c67..8bc1f15 100644 --- a/daft_lance/_lance.py +++ b/daft_lance/_lance.py @@ -270,6 +270,7 @@ def merge_columns_df( batch_size: int | None = None, left_on: str | None = "_rowaddr", right_on: str | None = "_rowaddr", + blob_columns: list[str] | None = None, ) -> Any: """Row-level merge columns entrypoint using a DataFrame. @@ -351,6 +352,7 @@ def merge_columns_df( left_on=left_on, right_on=effective_right_on, batch_size=effective_batch_size, + blob_columns=blob_columns, ) diff --git a/daft_lance/lance_merge_column.py b/daft_lance/lance_merge_column.py index c0e0cb9..991163b 100644 --- a/daft_lance/lance_merge_column.py +++ b/daft_lance/lance_merge_column.py @@ -106,6 +106,7 @@ def __init__( read_columns: list[str] | None = None, reader_schema: pa.Schema | None = None, batch_size: int | None = 9223372036854775807, + blob_columns: list[str] | None = None, ): """Per-group merge handler that directly invokes Lance fragment.merge with keyed join. @@ -123,6 +124,7 @@ def __init__( self.read_columns = read_columns or [] self.reader_schema = reader_schema self.batch_size = batch_size + self.blob_columns = set(blob_columns or []) @method.batch(return_dtype=_FRAGMENT_HANDLER_RETURN_DTYPE) def __call__(self, *cols: Any) -> list[dict[str, bytes]]: @@ -158,7 +160,7 @@ def __call__(self, *cols: Any) -> list[dict[str, bytes]]: # Convert all arrays to a consistent type to avoid mypy errors arrays.append(key_arr.cast(_pa.int64())) else: - arr = _pa.array(pylist) + arr = lance.blob_array(pylist) if col_name in self.blob_columns else _pa.array(pylist) if _pa.types.is_floating(arr.type): arrays.append(arr) elif _pa.types.is_integer(arr.type): @@ -243,11 +245,13 @@ def __init__( uri: str, new_column_names: list[str], storage_options: dict[str, str] | None = None, + blob_columns: list[str] | None = None, ): self.lance_ds = lance_ds self.uri = str(uri) self.new_column_names = new_column_names self.storage_options = storage_options + self.blob_columns = set(blob_columns or []) @method.batch(return_dtype=_FRAGMENT_HANDLER_RETURN_DTYPE) def __call__(self, *cols: Any) -> list[dict[str, bytes]]: @@ -269,8 +273,9 @@ def __call__(self, *cols: Any) -> list[dict[str, bytes]]: # Build table of new columns arrays = [] - for s in data_cols: - arr = _pa.array(s.to_pylist() if hasattr(s, "to_pylist") else list(s)) + for col_name, s in zip(self.new_column_names, data_cols): + pylist = s.to_pylist() if hasattr(s, "to_pylist") else list(s) + arr = lance.blob_array(pylist) if col_name in self.blob_columns else _pa.array(pylist) arrays.append(arr) tbl = _pa.table({name: arr for name, arr in zip(self.new_column_names, arrays)}) @@ -360,6 +365,7 @@ def merge_columns_from_df( left_on: str | None = "_rowaddr", right_on: str | None = None, batch_size: int | None = 9223372036854775807, + blob_columns: list[str] | None = None, ) -> lance.LanceDataset: # Validate required keys if "fragment_id" not in df.column_names: @@ -396,10 +402,10 @@ def merge_columns_from_df( read_columns = [join_key] + new_cols # Decide: fast path (raw file write) or slow path (keyed join) - use_fast_path = _can_use_fast_path(df, lance_ds, join_key) + use_fast_path = _can_use_fast_path(df, lance_ds, join_key) and not blob_columns if use_fast_path: - return _merge_fast_path(df, lance_ds, uri, new_cols, storage_options=storage_options) + return _merge_fast_path(df, lance_ds, uri, new_cols, storage_options=storage_options, blob_columns=blob_columns) else: return _merge_slow_path( df, @@ -411,6 +417,7 @@ def merge_columns_from_df( reader_schema, batch_size, storage_options=storage_options, + blob_columns=blob_columns, ) @@ -420,9 +427,16 @@ def _merge_fast_path( uri: str | pathlib.Path, new_column_names: list[str], storage_options: dict[str, Any] | None = None, + blob_columns: list[str] | None = None, ) -> lance.LanceDataset: """Metadata-only add_columns: write raw .lance files and stitch into fragment metadata.""" - handler = FastPathFragmentWriter(lance_ds, str(uri), new_column_names, storage_options=storage_options) + handler = FastPathFragmentWriter( + lance_ds, + str(uri), + new_column_names, + storage_options=storage_options, + blob_columns=blob_columns, + ) grouped = df.groupby("fragment_id").map_groups( handler(*(df[c] for c in new_column_names), df["_rowaddr"], df["fragment_id"]).alias("commit_message") # type: ignore[attr-defined] @@ -472,6 +486,7 @@ def _merge_slow_path( reader_schema: pa.Schema | None, batch_size: int | None, storage_options: dict[str, Any] | None = None, + blob_columns: list[str] | None = None, ) -> lance.LanceDataset: """Original keyed-join merge path: rewrites fragment data.""" handler_udf = GroupFragmentMergeUDF( @@ -481,6 +496,7 @@ def _merge_slow_path( read_columns, reader_schema, batch_size, + blob_columns, ) grouped = df.groupby("fragment_id").map_groups( diff --git a/tests/io/lancedb/test_fast_path_merge.py b/tests/io/lancedb/test_fast_path_merge.py index cde4235..aaa83ac 100644 --- a/tests/io/lancedb/test_fast_path_merge.py +++ b/tests/io/lancedb/test_fast_path_merge.py @@ -14,8 +14,10 @@ import lance import pyarrow as pa import pytest +from lance.blob import BlobType import daft +import daft_lance from daft_lance.lance_merge_column import _can_use_fast_path, merge_columns_from_df # --------------------------------------------------------------------------- @@ -127,6 +129,31 @@ def test_fast_path_computed_column(self, ds_path): for x, y, z in zip(result["x"], result["y"], result["z"]): assert x + y == z + def test_fast_path_adds_blob_v2_column(self, ds_path): + lance.write_dataset( + pa.table({"id": [1, 2, 3], "val": [10, 20, 30]}), + ds_path, + data_storage_version="2.2", + ) + df = read_with_metadata(ds_path) + + @daft.func.batch(return_dtype=daft.DataType.binary()) + def make_payload(ids): # type: ignore[no-untyped-def] + return pa.array([f"payload-{id_}".encode() for id_ in ids.to_pylist()], type=pa.large_binary()) + + df = df.with_column("payload", make_payload(df["id"])) + ds2 = daft_lance.merge_columns_df(df, ds_path, blob_columns=["payload"]) + + field = ds2.schema.field("payload") + assert isinstance(field.type, BlobType) + rows = ds2.to_table(columns=["id"], with_row_id=True).to_pylist() + blobs = ds2.take_blobs("payload", [row["_rowid"] for row in rows]) + assert {row["id"]: blob.read() for row, blob in zip(rows, blobs, strict=True)} == { + 1: b"payload-1", + 2: b"payload-2", + 3: b"payload-3", + } + # --------------------------------------------------------------------------- # 2. Fragment integrity