2020
2121from typing_extensions import Self
2222
23+ try :
24+ import pydantic
25+ import pydantic_core
26+ except ImportError :
27+ pydantic = None
28+ pydantic_core = None
29+
2330import betterproto2 .validators as validators
2431from betterproto2 .message_pool import MessagePool
2532from betterproto2 .utils import unwrap
@@ -697,6 +704,26 @@ def _betterproto(cls: type[Self]) -> ProtoClassMetadata: # type: ignore
697704 cls ._betterproto_meta = ProtoClassMetadata (cls )
698705 return cls ._betterproto_meta
699706
707+ def _is_pydantic (self ) -> bool :
708+ """
709+ Check if the message is a pydantic dataclass.
710+ """
711+ return pydantic is not None and pydantic .dataclasses .is_pydantic_dataclass (type (self ))
712+
713+ def _validate (self ) -> None :
714+ """
715+ Manually validate the message using pydantic.
716+
717+ This is useful since pydantic does not revalidate the message when fields are changed. This function doesn't
718+ validate the fields recursively.
719+ """
720+ if not self ._is_pydantic ():
721+ raise TypeError ("Validation is only available for pydantic dataclasses." )
722+
723+ dict = self .__dict__ .copy ()
724+ del dict ["_unknown_fields" ]
725+ pydantic_core .SchemaValidator (self .__pydantic_core_schema__ ).validate_python (dict ) # type: ignore
726+
700727 def dump (self , stream : SupportsWrite [bytes ], delimit : bool = False ) -> None :
701728 """
702729 Dumps the binary encoded Protobuf message to the stream.
@@ -720,6 +747,9 @@ def __bytes__(self) -> bytes:
720747 """
721748 Get the binary encoded Protobuf representation of this message instance.
722749 """
750+ if self ._is_pydantic ():
751+ self ._validate ()
752+
723753 with BytesIO () as stream :
724754 for field_name , meta in self ._betterproto .meta_by_field_name .items ():
725755 value = getattr (self , field_name )
@@ -822,13 +852,17 @@ def _postprocess_single(self, wire_type: int, meta: FieldMetadata, field_name: s
822852 """Adjusts values after parsing."""
823853 if wire_type == WIRE_VARINT :
824854 if meta .proto_type in (TYPE_INT32 , TYPE_INT64 ):
825- bits = int ( meta .proto_type [ 3 :])
855+ bits = 32 if meta .proto_type == TYPE_INT32 else 64
826856 value = value & ((1 << bits ) - 1 )
827857 signbit = 1 << (bits - 1 )
828858 value = int ((value ^ signbit ) - signbit )
859+ elif meta .proto_type in (TYPE_UINT32 , TYPE_UINT64 ):
860+ bits = 32 if meta .proto_type == TYPE_UINT32 else 64
861+ value = value & ((1 << bits ) - 1 )
829862 elif meta .proto_type in (TYPE_SINT32 , TYPE_SINT64 ):
830- # Undo zig-zag encoding
831- value = (value >> 1 ) ^ (- (value & 1 ))
863+ bits = 32 if meta .proto_type == TYPE_SINT32 else 64
864+ value = value & ((1 << bits ) - 1 )
865+ value = (value >> 1 ) ^ (- (value & 1 )) # Undo zig-zag encoding
832866 elif meta .proto_type == TYPE_BOOL :
833867 # Booleans use a varint encoding, so convert it to true/false.
834868 value = value > 0
@@ -947,6 +981,9 @@ def load(
947981 " or the expected size may have been incorrect."
948982 )
949983
984+ if self ._is_pydantic ():
985+ self ._validate ()
986+
950987 return self
951988
952989 @classmethod
@@ -1017,6 +1054,9 @@ def to_dict(
10171054 Dict[:class:`str`, Any]
10181055 The JSON serializable dict representation of this object.
10191056 """
1057+ if self ._is_pydantic ():
1058+ self ._validate ()
1059+
10201060 kwargs = { # For recursive calls
10211061 "output_format" : output_format ,
10221062 "casing" : casing ,
0 commit comments