diff --git a/.gitignore b/.gitignore index dcc9c5089ff..3da82f3e690 100644 --- a/.gitignore +++ b/.gitignore @@ -125,3 +125,4 @@ docs/src/community/project-specific/namespace-impls.md docs/src/community/project-specific/ray.md docs/src/community/project-specific/spark.md docs/src/community/project-specific/trino.md +reproduce.py diff --git a/python/python/lance/dataset.py b/python/python/lance/dataset.py index 45dc1b253d3..fc45dc159b6 100644 --- a/python/python/lance/dataset.py +++ b/python/python/lance/dataset.py @@ -793,6 +793,45 @@ def __deserialize__( base_store_params=base_store_params, ) + @classmethod + def from_pydantic_model( + cls, + model_class, + data, + uri: Optional[Union[str, Path]] = None, + mode: str = "create", + **kwargs, + ) -> "LanceDataset": + """Create a LanceDataset from a Pydantic model class and a list of instances. + + The table name is inferred from the model class name converted to snake_case. + The schema is inferred from the data. + + Parameters + ---------- + model_class : type + A Pydantic BaseModel subclass. + data : list + A list of Pydantic model instances. + uri : str or Path, optional + The URI to write the dataset to. If not provided, the model class name + converted to snake_case is used as the path. + mode : str, optional + The write mode. One of "create", "overwrite", or "append". + **kwargs + Additional arguments passed to write_dataset(). + """ + import re + + if uri is None: + uri = re.sub(r"(? tuple[ModuleType, bool]: torch, _TORCH_AVAILABLE = _lazy_import("torch") datasets, _HUGGING_FACE_AVAILABLE = _lazy_import("datasets") tensorflow, _TENSORFLOW_AVAILABLE = _lazy_import("tensorflow") + _, _PYDANTIC_AVAILABLE = _lazy_import("pydantic") @lru_cache(maxsize=None) @@ -221,6 +223,12 @@ def _check_for_tensorflow(obj: Any, *, check_type: bool = True) -> bool: ) +def _check_for_pydantic(obj: Any, *, check_type: bool = True) -> bool: + return _PYDANTIC_AVAILABLE and _might_be( + cast("Hashable", type(obj) if check_type else obj), "pydantic" + ) + + __all__ = [ # lazy-load third party libs "datasets", @@ -234,6 +242,7 @@ def _check_for_tensorflow(obj: Any, *, check_type: bool = True) -> bool: "_check_for_numpy", "_check_for_pandas", "_check_for_polars", + "_check_for_pydantic", "_check_for_tensorflow", "_check_for_torch", "_LazyModule", @@ -241,6 +250,7 @@ def _check_for_tensorflow(obj: Any, *, check_type: bool = True) -> bool: "_NUMPY_AVAILABLE", "_PANDAS_AVAILABLE", "_POLARS_AVAILABLE", + "_PYDANTIC_AVAILABLE", "_TORCH_AVAILABLE", "_HUGGING_FACE_AVAILABLE", "_TENSORFLOW_AVAILABLE", diff --git a/python/python/lance/types.py b/python/python/lance/types.py index 41cc191e4d6..2a9e271a660 100644 --- a/python/python/lance/types.py +++ b/python/python/lance/types.py @@ -9,7 +9,11 @@ from pyarrow import RecordBatch from . import dataset -from .dependencies import _check_for_hugging_face, _check_for_pandas +from .dependencies import ( + _check_for_hugging_face, + _check_for_pandas, + _check_for_pydantic, +) from .dependencies import pandas as pd if TYPE_CHECKING: @@ -116,6 +120,20 @@ def batch_iter(): # List of dictionaries batch = pa.RecordBatch.from_pylist(data_obj, schema=schema) return pa.RecordBatchReader.from_batches(batch.schema, [batch]) + elif ( + isinstance(data_obj, list) + and len(data_obj) > 0 + and _check_for_pydantic(data_obj[0]) + ): + from pydantic import BaseModel + + if isinstance(data_obj[0], BaseModel): + dicts = [ + item.model_dump() if hasattr(item, "model_dump") else item.dict() + for item in data_obj + ] + batch = pa.RecordBatch.from_pylist(dicts, schema=schema) + return pa.RecordBatchReader.from_batches(batch.schema, [batch]) # for other iterables, assume they are of type Iterable[RecordBatch] elif isinstance(data_obj, Iterable): if schema is not None: diff --git a/python/python/tests/test_dataset.py b/python/python/tests/test_dataset.py index 45866f3c4da..64936a68474 100644 --- a/python/python/tests/test_dataset.py +++ b/python/python/tests/test_dataset.py @@ -34,9 +34,17 @@ from lance.file import stable_version from lance.schema import LanceSchema from lance.util import validate_vector_index +from pydantic import BaseModel # Various valid inputs for write_dataset input_schema = pa.schema([pa.field("a", pa.float64()), pa.field("b", pa.int64())]) + + +class _InputModel(BaseModel): + a: float + b: int + + input_data = [ # (schema, data) (None, pa.table({"a": [1.0, 2.0], "b": [20, 30]})), @@ -59,6 +67,8 @@ ).to_batches() ), ), + # Pydantic model instances are auto-converted + (None, [_InputModel(a=1.0, b=20), _InputModel(a=2.0, b=30)]), ] @@ -69,6 +79,22 @@ def test_input_data(tmp_path: Path, schema, data): assert dataset.to_table() == input_data[0][1] +def test_from_pydantic_model(tmp_path: Path): + class UserRecord(BaseModel): + name: str + score: float + + data = [UserRecord(name="alice", score=0.9), UserRecord(name="bob", score=0.8)] + uri = str(tmp_path / "user_record") + ds = lance.LanceDataset.from_pydantic_model(UserRecord, data, uri=uri) + + table = ds.to_table() + assert table.num_rows == 2 + assert table.schema.names == ["name", "score"] + assert table.column("name").to_pylist() == ["alice", "bob"] + assert table.column("score").to_pylist() == [0.9, 0.8] + + def test_roundtrip_types(tmp_path: Path): table = pa.table( {