diff --git a/protarrow/arrow_to_proto.py b/protarrow/arrow_to_proto.py index 31dc6ee..b36b0db 100644 --- a/protarrow/arrow_to_proto.py +++ b/protarrow/arrow_to_proto.py @@ -1,7 +1,7 @@ import collections.abc import dataclasses import datetime -from typing import Any, Callable, Iterable, Iterator, List, Optional, Tuple, Type +from typing import Any, Callable, Iterable, Iterator, List, Optional, Tuple, Type, Union import pyarrow as pa from google.protobuf.descriptor import Descriptor, EnumDescriptor, FieldDescriptor @@ -457,11 +457,34 @@ def _extract_struct_field( _extract_array_messages(array, field_descriptor.message_type, nested_list) +def _convert_list_back_to_map( + array: Union[pa.ListArray, pa.LargeListArray], +) -> pa.MapArray: + """ + Converts a list of structs back to a map + :param array: A list of structs. + :return: A map. + """ + assert pa.types.is_struct(array.values.type), array.values.type + assert len(array.values.type.fields) == 2, ( + f"Must have only 2 fields, got {array.values.type.fields}." + ) + return pa.MapArray.from_arrays( + offsets=array.offsets, + keys=array.values.field(0), + items=array.values.field(1), + ) + + def _extract_map_field( - array: pa.MapArray, + array: Union[pa.MapArray, pa.ListArray, pa.LargeListArray], field_descriptor: FieldDescriptor, messages: Iterable[Message], ) -> None: + + if pa.types.is_list(array.type) or pa.types.is_large_list(array.type): + array = _convert_list_back_to_map(array=array) + assert pa.types.is_map(array.type), array.type value_descriptor = field_descriptor.message_type.fields_by_name["value"] diff --git a/protarrow/cast_to_proto.py b/protarrow/cast_to_proto.py index d0c4cd9..bce3716 100644 --- a/protarrow/cast_to_proto.py +++ b/protarrow/cast_to_proto.py @@ -17,6 +17,7 @@ from protarrow.proto_to_arrow import ( _PROTO_DESCRIPTOR_TO_PYARROW, _PROTO_PRIMITIVE_TYPE_TO_PYARROW, + _map_as_list_from_arrays, field_descriptor_to_field, get_map_descriptors, is_map, @@ -107,26 +108,47 @@ def _cast_array( config: ProtarrowConfig, ) -> pa.Array: if is_map(field_descriptor): - assert isinstance(array, pa.MapArray) key_field, value_field = get_map_descriptors(field_descriptor) - map_array = pa.MapArray.from_arrays( - # TODO: remove when https://github.com/apache/arrow/issues/40750 is fixed - # and library is pinned to pyarrow>=17.0.0 - maybe_copy_offsets(array.offsets), - _cast_array(array.keys, key_field, config), - _cast_array(array.items, value_field, config), - ) - return map_array.cast( - pa.map_( - map_array.type.key_type, - pa.field( - config.map_value_name, - map_array.type.item_type, - nullable=config.map_value_nullable, - metadata=config.field_metadata(field_descriptor.number), - ), + + if pa.types.is_map(array.type): + keys = array.keys + values = array.items + else: + assert pa.types.is_list(array.type) or pa.types.is_large_list(array.type), ( + array.type + ) + assert pa.types.is_struct(array.values.type), array.values.type + assert len(array.values.type.fields) == 2, ( + f"Must have only 2 fields, got {array.values.type.fields}." + ) + keys = array.values.field(0) + values = array.values.field(1) + + # TODO: remove when https://github.com/apache/arrow/issues/40750 is fixed + # and library is pinned to pyarrow>=17.0.0 + offsets = maybe_copy_offsets(array.offsets) + keys = _cast_array(keys, key_field, config) + values = _cast_array(values, value_field, config) + + if config.map_as_list: + return _map_as_list_from_arrays( + offsets=offsets, + keys=keys, + values=values, + config=config, + ) + else: + return pa.MapArray.from_arrays(offsets, keys, values).cast( + pa.map_( + keys.type, + pa.field( + config.map_value_name, + values.type, + nullable=config.map_value_nullable, + metadata=config.field_metadata(field_descriptor.number), + ), + ) ) - ) elif field_descriptor.is_repeated: assert isinstance(array, (pa.ListArray, pa.LargeListArray)) diff --git a/protarrow/common.py b/protarrow/common.py index 97a2c21..35b6333 100644 --- a/protarrow/common.py +++ b/protarrow/common.py @@ -61,6 +61,7 @@ class ProtarrowConfig: binary_type: pa.DataType = pa.binary() list_array_type: type = pa.ListArray skip_recursive_messages: bool = False + map_as_list: bool = False def __post_init__(self): _validate_enum_type(self.enum_type, self.string_type, self.binary_type) diff --git a/protarrow/message_extractor.py b/protarrow/message_extractor.py index 605c918..acaa124 100644 --- a/protarrow/message_extractor.py +++ b/protarrow/message_extractor.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Dict, Generic, List, Type, TypeVar +from typing import Any, Callable, Dict, Generic, List, Type, TypeVar, Union import pyarrow as pa from google.protobuf.descriptor import Descriptor, FieldDescriptor @@ -65,6 +65,35 @@ def __call__(self, scalar: pa.MapScalar) -> Dict[Any, Any]: return {} +class MapAsListConverterAdapter: + def __init__( + self, + list_type: Union[pa.ListType, pa.LargeListType], + key_descriptor: FieldDescriptor, + value_descriptor: FieldDescriptor, + ): + struct_type = list_type.value_field.type + assert pa.types.is_struct(struct_type) + key_field, value_field = struct_type.fields + + self._key_converter = get_flat_field_converter(key_field.type, key_descriptor) + self._value_converter = get_flat_field_converter( + value_field.type, value_descriptor + ) + + def __call__( + self, + scalar: Union[pa.ListScalar, pa.LargeListScalar], + ) -> Dict[Any, Any]: + if scalar.is_valid: + return { + self._key_converter(item.get(0)): self._value_converter(item.get(1)) + for item in scalar.values + } + else: + return {} + + class NullableConverterAdapter: def __init__( self, converter: Callable[[pa.Scalar], Any], message_type: Type[Message] @@ -99,7 +128,10 @@ def get_field_converter( ) -> Callable[[pa.Scalar], Any]: if is_map(field_descriptor): key, value = get_map_descriptors(field_descriptor) - return MapConverterAdapter(field.type, key, value) + if pa.types.is_map(field.type): + return MapConverterAdapter(field.type, key, value) + else: + return MapAsListConverterAdapter(field.type, key, value) else: if field_descriptor.is_repeated: return RepeatedConverterAdapter( diff --git a/protarrow/proto_to_arrow.py b/protarrow/proto_to_arrow.py index 190f169..bce942a 100644 --- a/protarrow/proto_to_arrow.py +++ b/protarrow/proto_to_arrow.py @@ -275,12 +275,31 @@ def field_descriptor_to_field( value_type = field_descriptor_to_data_type( value_field, config, descriptor_trace ) - return pa.field( - field_descriptor.name, - pa.map_( + if config.map_as_list: + map_type = config.list_( + item_type=pa.struct( + fields=[ + pa.field( + name="key", + type=key_type, + nullable=False, + ), + pa.field( + name=config.map_value_name, + type=value_type, + nullable=config.map_value_nullable, + ), + ] + ) + ) + else: + map_type = pa.map_( key_type, pa.field(config.map_value_name, value_type, config.map_value_nullable), - ), + ) + return pa.field( + field_descriptor.name, + map_type, nullable=config.map_nullable, metadata=config.field_metadata(field_descriptor.number), ) @@ -472,6 +491,58 @@ def _repeated_proto_to_array( ) +def _map_as_list_from_arrays( + offsets: pa.Array, + keys: pa.Array, + values: pa.Array, + config: ProtarrowConfig, +) -> pa.Array: + """ + Creates a "map as list" from arrays. + + :param offsets: An array of offsets. + :param keys: An array of keys. + :param values: An array of values. + :param config: The ProtarrowConfig. + :return: An Array. + """ + return config.list_array_type.from_arrays( + offsets=offsets, + values=pa.StructArray.from_arrays( + arrays=[keys, values], + fields=[ + pa.field( + name="key", + type=keys.type, + nullable=False, + ), + pa.field( + name=config.map_value_name, + type=values.type, + nullable=config.map_value_nullable, + ), + ], + ), + ).cast( + config.list_( + item_type=pa.struct( + fields=[ + pa.field( + name="key", + type=keys.type, + nullable=False, + ), + pa.field( + name=config.map_value_name, + type=values.type, + nullable=config.map_value_nullable, + ), + ] + ) + ) + ) + + def _proto_map_to_array( maps: Iterable[MessageMap], field_descriptor: FieldDescriptor, @@ -498,14 +569,25 @@ def _proto_map_to_array( config=config, descriptor_trace=descriptor_trace, ) - return pa.MapArray.from_arrays(offsets, keys, values).cast( - pa.map_( - keys.type, - pa.field( - config.map_value_name, values.type, nullable=config.map_value_nullable - ), + if config.map_as_list: + array = _map_as_list_from_arrays( + offsets=offsets, + keys=keys, + values=values, + config=config, ) - ) + else: + array = pa.MapArray.from_arrays(offsets, keys, values).cast( + pa.map_( + keys.type, + pa.field( + config.map_value_name, + values.type, + nullable=config.map_value_nullable, + ), + ) + ) + return array def _proto_field_nullable( diff --git a/tests/test_conversion.py b/tests/test_conversion.py index ba08945..264826e 100644 --- a/tests/test_conversion.py +++ b/tests/test_conversion.py @@ -51,7 +51,12 @@ from tests.random_generator import generate_messages, truncate_messages, truncate_nanos TEST_MESSAGE_COUNT = 5 -MESSAGES = [ExampleMessage, NestedExampleMessage, SuperNestedExampleMessage] +MESSAGES = [ + ExampleMessage, + NestedExampleMessage, + SuperNestedExampleMessage, +] + CONFIGS = [ ProtarrowConfig(), ProtarrowConfig(enum_type=pa.binary()), @@ -86,6 +91,7 @@ ProtarrowConfig(field_number_key=b"PARQUET:field_id"), ProtarrowConfig(string_type=pa.large_string()), ProtarrowConfig(binary_type=pa.large_binary()), + ProtarrowConfig(map_as_list=True), ProtarrowConfig(list_array_type=pa.LargeListArray), ] diff --git a/tests/test_coverage.py b/tests/test_coverage.py index 4e8fbc8..0d6bc46 100644 --- a/tests/test_coverage.py +++ b/tests/test_coverage.py @@ -26,9 +26,12 @@ _extract_map_field, _extract_record_batch_messages, convert_scalar, + record_batch_to_messages, ) from protarrow.cast_to_proto import get_arrow_default_value +from protarrow.common import ProtarrowConfig from protarrow.message_extractor import ( + MapAsListConverterAdapter, MapConverterAdapter, NullableConverterAdapter, RepeatedConverterAdapter, @@ -39,6 +42,8 @@ _get_converter, field_descriptor_to_data_type, get_enum_converter, + message_type_to_schema, + messages_to_record_batch, ) from protarrow_protos.bench_pb2 import ( ExampleMessage, @@ -61,6 +66,88 @@ def test_map_converter_adapter(): assert map_converter_adapter(pa.scalar(None, map_type)) == {} +def test_map_as_list_converter_adapter(): + list_type = pa.list_(pa.struct([("key", pa.int32()), ("value", pa.float64())])) + map_field = ExampleMessage.DESCRIPTOR.fields_by_name["double_int32_map"] + map_converter_adapter = MapAsListConverterAdapter( + list_type=list_type, + key_descriptor=map_field.message_type.fields_by_name["key"], + value_descriptor=map_field.message_type.fields_by_name["value"], + ) + assert map_converter_adapter(pa.scalar([(123, 1.0)], list_type)) == {123: 1.0} + assert map_converter_adapter(pa.scalar([], list_type)) == {} + assert map_converter_adapter(pa.scalar(None, list_type)) == {} + + +@pytest.mark.parametrize( + "config", + [ + ProtarrowConfig(list_array_type=pa.ListArray, map_as_list=True), + ProtarrowConfig(list_array_type=pa.LargeListArray, map_as_list=True), + ProtarrowConfig( + list_array_type=pa.LargeListArray, + map_as_list=True, + map_value_name="something", + ), + ], +) +def test_map_as_list_example_message_1(config: ProtarrowConfig): + input_data_batch = [ + { + 1: 0.4, + -4: -0.6, + 5: 0.2, + 2: 0.3, + }, + { + 6: 0.2, + }, + {}, + ] + message_batch = [] + for input_ in input_data_batch: + message = ExampleMessage() + for k, v in input_.items(): + message.double_int32_map[k] = v + message_batch.append(message) + + record_batch = messages_to_record_batch(message_batch, ExampleMessage, config) + + output_data_batch = [] + for record in record_batch.to_pylist(): + output_data = {} + for e in record["double_int32_map"]: + output_data[e["key"]] = e[config.map_value_name] + output_data_batch.append(output_data) + + for input_data, output_data in zip(input_data_batch, output_data_batch): + assert input_data == output_data + + +@pytest.mark.parametrize( + "config", + [ + ProtarrowConfig(list_array_type=pa.ListArray, map_as_list=True), + ProtarrowConfig(list_array_type=pa.LargeListArray, map_as_list=True), + ProtarrowConfig( + list_array_type=pa.LargeListArray, + map_as_list=True, + map_value_name="something", + ), + ], +) +def test_map_as_list_example_message_2(config: ProtarrowConfig): + schema = message_type_to_schema(ExampleMessage, config) + input_data = {1: 0.4, 2: 0.3, -4: -0.6, 5: 0.2} + record_batch = pa.RecordBatch.from_pylist( + [{"double_int32_map": list(input_data.items())}], + schema=schema, + ) + (message,) = record_batch_to_messages(record_batch, ExampleMessage) + output_data = message.double_int32_map + assert input_data == output_data + + def test_nullable_converter_adapter(): nullable_converter_adapter = NullableConverterAdapter(convert_scalar, DoubleValue) assert nullable_converter_adapter(pa.scalar(1.0, pa.float64())) == DoubleValue(