Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -125,3 +125,4 @@ docs/src/community/project-specific/namespace-impls.md
docs/src/community/project-specific/ray.md
docs/src/community/project-specific/spark.md
docs/src/community/project-specific/trino.md
reproduce.py
39 changes: 39 additions & 0 deletions python/python/lance/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -793,6 +793,45 @@ def __deserialize__(
base_store_params=base_store_params,
)

@classmethod
def from_pydantic_model(
    cls,
    model_class,
    data,
    uri: Optional[Union[str, Path]] = None,
    mode: str = "create",
    **kwargs,
) -> "LanceDataset":
    """Create a LanceDataset from a Pydantic model class and a list of instances.

    Each instance is converted to a plain dict, the dicts are assembled into
    a PyArrow table (schema inferred from the data), and the table is
    written with ``write_dataset``.

    Parameters
    ----------
    model_class : type
        A Pydantic BaseModel subclass. Only its ``__name__`` is used, to
        derive the default ``uri``.
    data : list
        A list of Pydantic model instances. Instances exposing
        ``model_dump()`` are converted with it; otherwise ``dict()`` is
        called instead.
    uri : str or Path, optional
        The URI to write the dataset to. If not provided, the model class
        name converted to snake_case is used as the path, e.g.
        ``UserRecord`` -> ``user_record`` and ``HTTPLog`` -> ``http_log``.
    mode : str, optional
        The write mode. One of "create", "overwrite", or "append".
    **kwargs
        Additional arguments passed to write_dataset().

    Returns
    -------
    LanceDataset
        The value returned by ``write_dataset``.
    """
    import re

    if uri is None:
        # Split only at real word boundaries so acronyms stay together:
        #   * lower-case letter or digit followed by an upper-case letter
        #     ("userRecord" -> "user|Record")
        #   * end of an upper-case run followed by a capitalised word
        #     ("HTTPLog" -> "HTTP|Log")
        # Splitting before every capital would yield "h_t_t_p_log".
        uri = re.sub(
            r"(?<=[a-z0-9])(?=[A-Z])|(?<=[A-Z])(?=[A-Z][a-z])",
            "_",
            model_class.__name__,
        ).lower()
    dicts = [
        item.model_dump() if hasattr(item, "model_dump") else item.dict()
        for item in data
    ]
    table = pa.Table.from_pylist(dicts)
    return write_dataset(table, uri, mode=mode, **kwargs)

def __reduce__(self):
return type(self).__deserialize__, (
self.uri,
Expand Down
10 changes: 10 additions & 0 deletions python/python/lance/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
_RAFT_COMMON_AVAILABLE = True
_HUGGING_FACE_AVAILABLE = True
_TENSORFLOW_AVAILABLE = True
_PYDANTIC_AVAILABLE = True


class _LazyModule(ModuleType):
Expand Down Expand Up @@ -173,6 +174,7 @@ def _lazy_import(module_name: str) -> tuple[ModuleType, bool]:
torch, _TORCH_AVAILABLE = _lazy_import("torch")
datasets, _HUGGING_FACE_AVAILABLE = _lazy_import("datasets")
tensorflow, _TENSORFLOW_AVAILABLE = _lazy_import("tensorflow")
_, _PYDANTIC_AVAILABLE = _lazy_import("pydantic")


@lru_cache(maxsize=None)
Expand Down Expand Up @@ -221,6 +223,12 @@ def _check_for_tensorflow(obj: Any, *, check_type: bool = True) -> bool:
)


def _check_for_pydantic(obj: Any, *, check_type: bool = True) -> bool:
    """Report whether ``obj`` might come from the ``pydantic`` package.

    Returns False straight away when the module-level availability flag for
    pydantic is unset. Otherwise the decision is delegated to ``_might_be``,
    which receives ``type(obj)`` when ``check_type`` is true and ``obj``
    itself when it is false.
    """
    if not _PYDANTIC_AVAILABLE:
        return False
    candidate = type(obj) if check_type else obj
    return _might_be(cast("Hashable", candidate), "pydantic")


__all__ = [
# lazy-load third party libs
"datasets",
Expand All @@ -234,13 +242,15 @@ def _check_for_tensorflow(obj: Any, *, check_type: bool = True) -> bool:
"_check_for_numpy",
"_check_for_pandas",
"_check_for_polars",
"_check_for_pydantic",
"_check_for_tensorflow",
"_check_for_torch",
"_LazyModule",
# exported flags/guards
"_NUMPY_AVAILABLE",
"_PANDAS_AVAILABLE",
"_POLARS_AVAILABLE",
"_PYDANTIC_AVAILABLE",
"_TORCH_AVAILABLE",
"_HUGGING_FACE_AVAILABLE",
"_TENSORFLOW_AVAILABLE",
Expand Down
20 changes: 19 additions & 1 deletion python/python/lance/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,11 @@
from pyarrow import RecordBatch

from . import dataset
from .dependencies import _check_for_hugging_face, _check_for_pandas
from .dependencies import (
_check_for_hugging_face,
_check_for_pandas,
_check_for_pydantic,
)
from .dependencies import pandas as pd

if TYPE_CHECKING:
Expand Down Expand Up @@ -116,6 +120,20 @@ def batch_iter():
# List of dictionaries
batch = pa.RecordBatch.from_pylist(data_obj, schema=schema)
return pa.RecordBatchReader.from_batches(batch.schema, [batch])
elif (
isinstance(data_obj, list)
and len(data_obj) > 0
and _check_for_pydantic(data_obj[0])
):
from pydantic import BaseModel

if isinstance(data_obj[0], BaseModel):
dicts = [
item.model_dump() if hasattr(item, "model_dump") else item.dict()
for item in data_obj
]
batch = pa.RecordBatch.from_pylist(dicts, schema=schema)
return pa.RecordBatchReader.from_batches(batch.schema, [batch])
# for other iterables, assume they are of type Iterable[RecordBatch]
elif isinstance(data_obj, Iterable):
if schema is not None:
Expand Down
26 changes: 26 additions & 0 deletions python/python/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,17 @@
from lance.file import stable_version
from lance.schema import LanceSchema
from lance.util import validate_vector_index
from pydantic import BaseModel

# Various valid inputs for write_dataset
input_schema = pa.schema([pa.field("a", pa.float64()), pa.field("b", pa.int64())])


# Pydantic model whose fields mirror ``input_schema`` (a: float64, b: int64);
# instances of it are used as one of the ``input_data`` cases below.
class _InputModel(BaseModel):
    # Field names match the "a" / "b" columns of ``input_schema``.
    a: float
    b: int


input_data = [
# (schema, data)
(None, pa.table({"a": [1.0, 2.0], "b": [20, 30]})),
Expand All @@ -59,6 +67,8 @@
).to_batches()
),
),
# Pydantic model instances are auto-converted
(None, [_InputModel(a=1.0, b=20), _InputModel(a=2.0, b=30)]),
]


Expand All @@ -69,6 +79,22 @@ def test_input_data(tmp_path: Path, schema, data):
assert dataset.to_table() == input_data[0][1]


def test_from_pydantic_model(tmp_path: Path):
    """Write Pydantic instances via ``from_pydantic_model`` and read them back."""

    class UserRecord(BaseModel):
        name: str
        score: float

    records = [
        UserRecord(name="alice", score=0.9),
        UserRecord(name="bob", score=0.8),
    ]
    target = tmp_path / "user_record"
    written = lance.LanceDataset.from_pydantic_model(
        UserRecord, records, uri=str(target)
    )

    # The stored table must keep row count, column order and values.
    result = written.to_table()
    assert result.num_rows == 2
    assert result.schema.names == ["name", "score"]
    assert result.column("name").to_pylist() == ["alice", "bob"]
    assert result.column("score").to_pylist() == [0.9, 0.8]


def test_roundtrip_types(tmp_path: Path):
table = pa.table(
{
Expand Down
Loading