@@ -141,8 +141,10 @@ class FieldMetadata:
141141 number : int
142142 # Protobuf type name
143143 proto_type : str
144+
144145 # Map information if the proto_type is a map
145- map_types : tuple [str , str ] | None = None
146+ map_meta : tuple [FieldMetadata , FieldMetadata ] | None = None
147+
146148 # Groups several "one-of" fields together
147149 group : str | None = None
148150
@@ -160,12 +162,24 @@ def get(field: dataclasses.Field) -> FieldMetadata:
160162 return field .metadata ["betterproto" ]
161163
162164
165+ def map_meta (
166+ proto_type_1 : str ,
167+ proto_type_2 : str ,
168+ * ,
169+ unwrap_2 : Callable [[], type ] | None = None ,
170+ ) -> tuple [FieldMetadata , FieldMetadata ]:
171+ key_meta = FieldMetadata (1 , proto_type_1 )
172+ value_meta = FieldMetadata (2 , proto_type_2 , unwrap = unwrap_2 )
173+
174+ return key_meta , value_meta
175+
176+
163177def field (
164178 number : int ,
165179 proto_type : str ,
166180 * ,
167181 default_factory : Callable [[], Any ] | None = None ,
168- map_types : tuple [str , str ] | None = None ,
182+ map_meta : tuple [FieldMetadata , FieldMetadata ] | None = None ,
169183 group : str | None = None ,
170184 unwrap : Callable [[], type ] | None = None ,
171185 optional : bool = False ,
@@ -202,7 +216,7 @@ def field(
202216
203217 return dataclasses .field (
204218 default_factory = default_factory ,
205- metadata = {"betterproto" : FieldMetadata (number , proto_type , map_types , group , unwrap , optional , repeated )},
219+ metadata = {"betterproto" : FieldMetadata (number , proto_type , map_meta , group , unwrap , optional , repeated )},
206220 )
207221
208222
@@ -485,7 +499,7 @@ def _get_cls_by_field(cls: type[Message], fields: Iterable[dataclasses.Field]) -
485499 for field_ in fields :
486500 meta = FieldMetadata .get (field_ )
487501 if meta .proto_type == TYPE_MAP :
488- assert meta .map_types
502+ assert meta .map_meta
489503 kt = cls ._cls_for (field_ , index = 0 )
490504 vt = cls ._cls_for (field_ , index = 1 )
491505 field_cls [field_ .name ] = dataclasses .make_dataclass (
@@ -494,12 +508,12 @@ def _get_cls_by_field(cls: type[Message], fields: Iterable[dataclasses.Field]) -
494508 (
495509 "key" ,
496510 kt ,
497- field (1 , meta .map_types [0 ], default_factory = kt ),
511+ field (1 , meta .map_meta [0 ]. proto_type , default_factory = kt ),
498512 ),
499513 (
500514 "value" ,
501515 vt ,
502- field (2 , meta .map_types [1 ], default_factory = vt ),
516+ field (2 , meta .map_meta [1 ]. proto_type , default_factory = vt ),
503517 ),
504518 ],
505519 bases = (Message ,),
@@ -720,9 +734,9 @@ def __bytes__(self) -> bytes:
720734
721735 elif isinstance (value , dict ):
722736 for k , v in value .items ():
723- assert meta .map_types
724- sk = _serialize_single (1 , meta .map_types [0 ], k )
725- sv = _serialize_single (2 , meta .map_types [1 ], v )
737+ assert meta .map_meta
738+ sk = _serialize_single (1 , meta .map_meta [0 ]. proto_type , k )
739+ sv = _serialize_single (2 , meta .map_meta [1 ]. proto_type , v , unwrap = meta . map_meta [ 1 ]. unwrap )
726740 stream .write (_serialize_single (meta .number , meta .proto_type , sk + sv ))
727741 else :
728742 stream .write (
@@ -1007,13 +1021,12 @@ def to_dict(
10071021 output [cased_name ] = output_value
10081022
10091023 elif meta .proto_type == TYPE_MAP :
1010- assert meta .map_types is not None
1024+ assert meta .map_meta is not None
10111025 field_type_k = field_types [field_name ].__args__ [0 ]
10121026 field_type_v = field_types [field_name ].__args__ [1 ]
1013- # TODO wrapped types don't work in maps
10141027 output_map = {
1015- _value_to_dict (k , meta .map_types [0 ], field_type_k , None , ** kwargs )[0 ]: _value_to_dict (
1016- v , meta .map_types [1 ], field_type_v , None , ** kwargs
1028+ _value_to_dict (k , meta .map_meta [0 ]. proto_type , field_type_k , None , ** kwargs )[0 ]: _value_to_dict (
1029+ v , meta .map_meta [1 ]. proto_type , field_type_v , meta . map_meta [ 1 ]. unwrap , ** kwargs
10171030 )[0 ]
10181031 for k , v in value .items ()
10191032 }
@@ -1058,7 +1071,7 @@ def _from_dict_init(cls, mapping: Mapping[str, Any]) -> Mapping[str, Any]:
10581071 value , meta .proto_type , cls ._betterproto .cls_by_field [field_name ], meta .unwrap
10591072 )
10601073
1061- elif meta .map_types and meta .map_types [1 ] == TYPE_MESSAGE :
1074+ elif meta .map_meta and meta .map_meta [1 ]. proto_type == TYPE_MESSAGE :
10621075 sub_cls = cls ._betterproto .cls_by_field [f"{ field_name } .value" ]
10631076 value = {k : sub_cls .from_dict (v ) for k , v in value .items ()}
10641077 else :
@@ -1209,7 +1222,7 @@ def from_pydict(self: T, value: Mapping[str, Any]) -> T:
12091222 v = value [key ]
12101223 else :
12111224 v = cls ().from_pydict (value [key ])
1212- elif meta .map_types and meta .map_types [1 ] == TYPE_MESSAGE :
1225+ elif meta .map_meta and meta .map_meta [1 ]. proto_type == TYPE_MESSAGE :
12131226 v = getattr (self , field_name )
12141227 cls = self ._betterproto .cls_by_field [f"{ field_name } .value" ]
12151228 for k in value [key ]:
0 commit comments