diff --git a/daft_lance/__init__.py b/daft_lance/__init__.py index cb07862..26b50bd 100644 --- a/daft_lance/__init__.py +++ b/daft_lance/__init__.py @@ -7,16 +7,20 @@ from ._lance import ( compact_files, create_scalar_index, + delete_from_lance, merge_columns, merge_columns_df, read_lance, + update_lance, ) __all__ = [ "compact_files", "create_scalar_index", + "delete_from_lance", "merge_columns", "merge_columns_df", "read_lance", "take_blobs", + "update_lance", ] diff --git a/daft_lance/_lance.py b/daft_lance/_lance.py index 95a67bf..17c827b 100644 --- a/daft_lance/_lance.py +++ b/daft_lance/_lance.py @@ -551,3 +551,72 @@ def compact_files( partition_num=partition_num, concurrency=concurrency, ) + + +@PublicAPI +def update_lance( + uri: str | pathlib.Path, + updates: dict[str, str], + *, + where: str | None = None, + io_config: IOConfig | None = None, +) -> dict[str, int]: + """Update rows in a Lance dataset matching the given SQL predicate. + + Args: + uri: The URI of the Lance dataset. Accepts a local path or an + object-store URI like ``"s3://bucket/path"``. + updates: Mapping of column names to SQL expressions, + e.g. ``{"age": "age + 1", "name": "'updated'"}``. + where: Optional SQL predicate indicating which rows to update, + e.g. ``"age > 30"``. If not provided, all rows are updated. + io_config: Optional IOConfig to use when accessing Lance data. + + Returns: + dict with a ``"num_rows_updated"`` key. + + Example: + >>> import daft_lance + >>> result = daft_lance.update_lance( + ... "/path/to/dataset", updates={"age": "age + 1"}, where="age > 40" + ... ) + >>> print(result["num_rows_updated"]) + + See Also: + `Lance docs — Updating rows `_ + """ + io_config = context.get_context().daft_planning_config.default_io_config if io_config is None else io_config + storage_options = io_config_to_storage_options(io_config, str(uri) if isinstance(uri, pathlib.Path) else uri) + + dataset = lance.dataset(uri, storage_options=storage_options) + return dataset.update(updates, where=where) + + +@PublicAPI +def delete_from_lance( + uri: str | pathlib.Path, + where: str, + *, + io_config: IOConfig | None = None, +) -> None: + """Delete rows from a Lance dataset matching the given SQL predicate. + + Args: + uri: The URI of the Lance dataset. Accepts a local path or an + object-store URI like ``"s3://bucket/path"``. + where: SQL predicate indicating which rows to delete, + e.g. ``"name IS NULL"``. + io_config: Optional IOConfig to use when accessing Lance data. + + Example: + >>> import daft_lance + >>> daft_lance.delete_from_lance("/path/to/dataset", where="name IS NULL") + + See Also: + `Lance docs — Deleting rows `_ + """ + io_config = context.get_context().daft_planning_config.default_io_config if io_config is None else io_config + storage_options = io_config_to_storage_options(io_config, str(uri) if isinstance(uri, pathlib.Path) else uri) + + dataset = lance.dataset(uri, storage_options=storage_options) + dataset.delete(where) diff --git a/tests/io/lancedb/test_lancedb_update_delete.py b/tests/io/lancedb/test_lancedb_update_delete.py new file mode 100644 index 0000000..bfc8849 --- /dev/null +++ b/tests/io/lancedb/test_lancedb_update_delete.py @@ -0,0 +1,142 @@ +from __future__ import annotations + +import pathlib +import tempfile + +import lance +import pandas as pd +import pytest + +import daft_lance + + +@pytest.fixture +def temp_dir(): + with tempfile.TemporaryDirectory() as tmp: + yield pathlib.Path(tmp) + + +@pytest.fixture +def lance_dataset(temp_dir): + path = str(temp_dir / "test.lance") + df = pd.DataFrame( + { + "id": [1, 2, 3, 4, 5], + "name": ["Alice", "Bob", "Charlie", "David", "Eve"], + "age": [25, 30, 35, 40, 45], + } + ) + lance.write_dataset(df, path) + return path + + +@pytest.fixture +def lance_dataset_with_nulls(temp_dir): + path = str(temp_dir / "nulls.lance") + df = pd.DataFrame( + { + "id": [1, 2, 3, 4, 5], + "name": ["Alice", None, "Charlie", "David", None], + "age": [25, 30, None, 40, 45], + } + ) + lance.write_dataset(df, path) + return path + + +class TestUpdateLance: + def test_update_matching_rows(self, lance_dataset): + result = daft_lance.update_lance( + lance_dataset, updates={"age": "age + 1"}, where="age > 35" + ) + assert result["num_rows_updated"] == 2 + + ds = lance.dataset(lance_dataset) + table = ds.to_table().to_pandas() + assert table.loc[table["id"] == 3, "age"].values[0] == 35 + assert table.loc[table["id"] == 4, "age"].values[0] == 41 + assert table.loc[table["id"] == 5, "age"].values[0] == 46 + + def test_update_all_rows(self, lance_dataset): + result = daft_lance.update_lance(lance_dataset, updates={"age": "age + 10"}) + assert result["num_rows_updated"] == 5 + + ds = lance.dataset(lance_dataset) + table = ds.to_table().to_pandas() + assert table["age"].to_list() == [35, 40, 45, 50, 55] + + def test_update_string_column(self, lance_dataset): + result = daft_lance.update_lance( + lance_dataset, updates={"name": "'Updated'"}, where="id = 1" + ) + assert result["num_rows_updated"] == 1 + + ds = lance.dataset(lance_dataset) + table = ds.to_table().to_pandas() + assert table.loc[table["id"] == 1, "name"].values[0] == "Updated" + assert table.loc[table["id"] == 2, "name"].values[0] == "Bob" + + def test_update_no_matching_rows(self, lance_dataset): + result = daft_lance.update_lance( + lance_dataset, updates={"age": "age + 1"}, where="age > 100" + ) + assert result["num_rows_updated"] == 0 + + def test_update_multiple_columns(self, lance_dataset): + result = daft_lance.update_lance( + lance_dataset, + updates={"age": "age + 5", "name": "'Modified'"}, + where="id IN (1, 2)", + ) + assert result["num_rows_updated"] == 2 + + ds = lance.dataset(lance_dataset) + table = ds.to_table().to_pandas() + row1 = table[table["id"] == 1] + assert row1["age"].values[0] == 30 + assert row1["name"].values[0] == "Modified" + row3 = table[table["id"] == 3] + assert row3["name"].values[0] == "Charlie" + + +class TestDeleteFromLance: + def test_delete_matching_rows(self, lance_dataset): + daft_lance.delete_from_lance(lance_dataset, where="age > 35") + + ds = lance.dataset(lance_dataset) + table = ds.to_table().to_pandas() + assert len(table) == 3 + assert set(table["id"].to_list()) == {1, 2, 3} + + def test_delete_single_row(self, lance_dataset): + daft_lance.delete_from_lance(lance_dataset, where="id = 3") + + ds = lance.dataset(lance_dataset) + table = ds.to_table().to_pandas() + assert len(table) == 4 + assert 3 not in table["id"].values + + def test_delete_no_matching_rows(self, lance_dataset): + daft_lance.delete_from_lance(lance_dataset, where="age > 100") + + ds = lance.dataset(lance_dataset) + table = ds.to_table().to_pandas() + assert len(table) == 5 + + def test_delete_with_null_check(self, lance_dataset_with_nulls): + daft_lance.delete_from_lance(lance_dataset_with_nulls, where="name IS NULL") + + ds = lance.dataset(lance_dataset_with_nulls) + table = ds.to_table().to_pandas() + assert len(table) == 3 + for name in table["name"]: + assert name is not None + + def test_delete_with_is_not_null(self, lance_dataset_with_nulls): + daft_lance.delete_from_lance( + lance_dataset_with_nulls, where="name IS NOT NULL" + ) + + ds = lance.dataset(lance_dataset_with_nulls) + table = ds.to_table().to_pandas() + assert len(table) == 2