diff --git a/rust/lance-core/src/datatypes.rs b/rust/lance-core/src/datatypes.rs index 628f9cf9a90..2e5f9561c15 100644 --- a/rust/lance-core/src/datatypes.rs +++ b/rust/lance-core/src/datatypes.rs @@ -147,7 +147,11 @@ fn timeunit_to_str(unit: &TimeUnit) -> &'static str { fn is_supported_fixed_size_list_child(data_type: &DataType, nested: bool) -> bool { match data_type { DataType::Struct(_) => !nested, - DataType::List(_) | DataType::LargeList(_) | DataType::Map(_, _) => false, + DataType::List(_) + | DataType::LargeList(_) + | DataType::ListView(_) + | DataType::LargeListView(_) + | DataType::Map(_, _) => false, DataType::FixedSizeList(field, _) => { is_supported_fixed_size_list_child(field.data_type(), true) } @@ -211,11 +215,11 @@ impl TryFrom<&DataType> for LogicalType { false ) } - DataType::List(elem) => match elem.data_type() { + DataType::List(elem) | DataType::ListView(elem) => match elem.data_type() { DataType::Struct(_) => "list.struct".to_string(), _ => "list".to_string(), }, - DataType::LargeList(elem) => match elem.data_type() { + DataType::LargeList(elem) | DataType::LargeListView(elem) => match elem.data_type() { DataType::Struct(_) => "large_list.struct".to_string(), _ => "large_list".to_string(), }, diff --git a/rust/lance-core/src/datatypes/field.rs b/rust/lance-core/src/datatypes/field.rs index 4c2665a3640..deb62788e09 100644 --- a/rust/lance-core/src/datatypes/field.rs +++ b/rust/lance-core/src/datatypes/field.rs @@ -494,12 +494,27 @@ impl Field { } } DataType::List(_) => { - let list_arr = arr.as_list::(); - self.children[0].set_dictionary(list_arr.values()); + let values = match arr.data_type() { + DataType::List(_) => arr.as_list::().values(), + DataType::ListView(_) => arr.as_list_view::().values(), + data_type => { + panic!("List field had an unexpected array type: {}", data_type); + } + }; + self.children[0].set_dictionary(values); } DataType::LargeList(_) => { - let list_arr = arr.as_list::(); - self.children[0].set_dictionary(list_arr.values()); + let values = match arr.data_type() { + DataType::LargeList(_) => arr.as_list::().values(), + DataType::LargeListView(_) => arr.as_list_view::().values(), + data_type => { + panic!( + "LargeList field had an unexpected array type: {}", + data_type + ); + } + }; + self.children[0].set_dictionary(values); } _ => { // Field types that don't support dictionaries @@ -1094,8 +1109,12 @@ impl TryFrom<&ArrowField> for Field { .iter() .map(|f| Self::try_from(f.as_ref())) .collect::>()?, - DataType::List(item) => vec![Self::try_from(item.as_ref())?], - DataType::LargeList(item) => vec![Self::try_from(item.as_ref())?], + DataType::List(item) | DataType::ListView(item) => { + vec![Self::try_from(item.as_ref())?] + } + DataType::LargeList(item) | DataType::LargeListView(item) => { + vec![Self::try_from(item.as_ref())?] + } DataType::FixedSizeList(item, _) if matches!(item.data_type(), DataType::Struct(_)) => { vec![Self::try_from(item.as_ref())?] } @@ -1170,9 +1189,11 @@ impl TryFrom<&ArrowField> for Field { dt if dt.is_binary_like() => Some(Encoding::VarBinary), DataType::Dictionary(_, _) => Some(Encoding::Dictionary), // Use plain encoder to store the offsets of list and map. - DataType::List(_) | DataType::LargeList(_) | DataType::Map(_, _) => { - Some(Encoding::Plain) - } + DataType::List(_) + | DataType::LargeList(_) + | DataType::ListView(_) + | DataType::LargeListView(_) + | DataType::Map(_, _) => Some(Encoding::Plain), _ => None, }, metadata, @@ -1317,6 +1338,35 @@ mod tests { LogicalType::try_from(&DataType::BinaryView).unwrap().0, "binary" ); + + let item = Arc::new(ArrowField::new("item", DataType::Int32, true)); + let field = Field::try_from(&ArrowField::new( + "l", + DataType::ListView(item.clone()), + true, + )) + .unwrap(); + assert_eq!(field.data_type(), DataType::List(item.clone())); + assert_eq!( + LogicalType::try_from(&DataType::ListView(item.clone())) + .unwrap() + .0, + "list" + ); + + let field = Field::try_from(&ArrowField::new( + "ll", + DataType::LargeListView(item.clone()), + true, + )) + .unwrap(); + assert_eq!(field.data_type(), DataType::LargeList(item.clone())); + assert_eq!( + LogicalType::try_from(&DataType::LargeListView(item)) + .unwrap() + .0, + "large_list" + ); } #[test] diff --git a/rust/lance/src/dataset/utils.rs b/rust/lance/src/dataset/utils.rs index c9770a3167b..d4ab625b646 100644 --- a/rust/lance/src/dataset/utils.rs +++ b/rust/lance/src/dataset/utils.rs @@ -151,11 +151,27 @@ fn physical_field(field: &ArrowField) -> Option { ArrowField::new(field.name(), DataType::Binary, field.is_nullable()) .with_metadata(field.metadata().clone()), ), + DataType::ListView(item) => Some( + ArrowField::new( + field.name(), + DataType::List(Arc::clone(item)), + field.is_nullable(), + ) + .with_metadata(field.metadata().clone()), + ), + DataType::LargeListView(item) => Some( + ArrowField::new( + field.name(), + DataType::LargeList(Arc::clone(item)), + field.is_nullable(), + ) + .with_metadata(field.metadata().clone()), + ), _ => None, } } -/// Cast `Utf8View`/`BinaryView` columns in a batch to their classic offset equivalents. +/// Cast supported Arrow view columns in a batch to their classic offset equivalents. fn downcast_view_columns( batch: &RecordBatch, ) -> std::result::Result { diff --git a/rust/lance/src/dataset/write.rs b/rust/lance/src/dataset/write.rs index 1e73618fc6b..a7a6720456e 100644 --- a/rust/lance/src/dataset/write.rs +++ b/rust/lance/src/dataset/write.rs @@ -1490,7 +1490,11 @@ mod tests { use super::*; use std::collections::HashMap; - use arrow_array::{Int32Array, RecordBatchIterator, RecordBatchReader, StructArray}; + use arrow_array::{ + Int32Array, LargeListArray, LargeListViewArray, ListArray, ListViewArray, + RecordBatchIterator, RecordBatchReader, StructArray, + }; + use arrow_buffer::{NullBuffer, OffsetBuffer, ScalarBuffer}; use arrow_schema::{DataType, Field as ArrowField, Fields, Schema as ArrowSchema}; use datafusion::{error::DataFusionError, physical_plan::stream::RecordBatchStreamAdapter}; use datafusion_physical_plan::RecordBatchStream; @@ -1511,6 +1515,153 @@ mod tests { assert!(!params.skip_auto_cleanup); } + #[tokio::test] + async fn test_write_list_view_columns_as_offset_lists() { + let item = Arc::new(ArrowField::new("item", DataType::Int32, true)); + let list_view = ListViewArray::new( + item.clone(), + ScalarBuffer::from(vec![6, 0, 3, 0]), + ScalarBuffer::from(vec![3, 1, 3, 2]), + Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5, 6, 7, 8, 9])), + None, + ); + let large_list_view = LargeListViewArray::new( + item.clone(), + ScalarBuffer::from(vec![0_i64, 2, 4, 4]), + ScalarBuffer::from(vec![0_i64, 0, 1, 0]), + Arc::new(Int32Array::from(vec![10, 11, 12, 13, 14])), + Some(NullBuffer::from(vec![true, false, true, true])), + ); + + let schema = Arc::new(ArrowSchema::new(vec![ + ArrowField::new("list", DataType::ListView(item.clone()), true), + ArrowField::new("large_list", DataType::LargeListView(item.clone()), true), + ])); + let batch = RecordBatch::try_new( + schema.clone(), + vec![Arc::new(list_view), Arc::new(large_list_view)], + ) + .unwrap(); + + let dataset = Dataset::write( + RecordBatchIterator::new(vec![Ok(batch)], schema), + "memory://", + None, + ) + .await + .unwrap(); + let output = dataset.scan().try_into_batch().await.unwrap(); + let output_schema = output.schema(); + + let DataType::List(output_item) = output_schema.field(0).data_type() else { + panic!( + "expected List output, got {}", + output_schema.field(0).data_type() + ); + }; + assert_eq!(output_item.data_type(), &DataType::Int32); + let DataType::LargeList(output_item) = output_schema.field(1).data_type() else { + panic!( + "expected LargeList output, got {}", + output_schema.field(1).data_type() + ); + }; + assert_eq!(output_item.data_type(), &DataType::Int32); + + let expected_list = ListArray::new( + item.clone(), + OffsetBuffer::from_lengths([3, 1, 3, 2]), + Arc::new(Int32Array::from(vec![7, 8, 9, 1, 4, 5, 6, 1, 2])), + None, + ); + let actual_list = output + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(actual_list, &expected_list); + + let expected_large_list = LargeListArray::new( + item, + OffsetBuffer::from_lengths([0, 0, 1, 0]), + Arc::new(Int32Array::from(vec![14])), + Some(NullBuffer::from(vec![true, false, true, true])), + ); + let actual_large_list = output + .column(1) + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(actual_large_list, &expected_large_list); + } + + #[tokio::test] + async fn test_append_list_view_to_list_column() { + let item = Arc::new(ArrowField::new("item", DataType::Int32, true)); + let initial_list = ListArray::new( + item.clone(), + OffsetBuffer::from_lengths([2, 0]), + Arc::new(Int32Array::from(vec![10, 11])), + None, + ); + let initial_schema = Arc::new(ArrowSchema::new(vec![ArrowField::new( + "list", + DataType::List(item.clone()), + true, + )])); + let initial_batch = + RecordBatch::try_new(initial_schema.clone(), vec![Arc::new(initial_list)]).unwrap(); + + let dataset = Dataset::write( + RecordBatchIterator::new(vec![Ok(initial_batch)], initial_schema), + "memory://", + None, + ) + .await + .unwrap(); + + let appended_list_view = ListViewArray::new( + item.clone(), + ScalarBuffer::from(vec![2, 0, 1]), + ScalarBuffer::from(vec![1, 1, 0]), + Arc::new(Int32Array::from(vec![20, 21, 22])), + None, + ); + let append_schema = Arc::new(ArrowSchema::new(vec![ArrowField::new( + "list", + DataType::ListView(item.clone()), + true, + )])); + let append_batch = + RecordBatch::try_new(append_schema.clone(), vec![Arc::new(appended_list_view)]) + .unwrap(); + + let dataset = Dataset::write( + RecordBatchIterator::new(vec![Ok(append_batch)], append_schema), + Arc::new(dataset), + Some(WriteParams { + mode: WriteMode::Append, + ..Default::default() + }), + ) + .await + .unwrap(); + let output = dataset.scan().try_into_batch().await.unwrap(); + + let expected = ListArray::new( + item, + OffsetBuffer::from_lengths([2, 0, 1, 1, 0]), + Arc::new(Int32Array::from(vec![10, 11, 22, 20])), + None, + ); + let actual = output + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(actual, &expected); + } + #[tokio::test] async fn test_chunking_large_batches() { // Create a stream of 3 batches of 10 rows