Skip to content

Commit 00acbd1

Browse files
More validation (#104)
* Pydantic validation method
* Validate the message before exporting it
* Validate after parsing
* Fix uint parsing
* Simplify code
* Fix sint parsing
* Add test
* Test manual validation
* Fix typechecking
1 parent b2ce674 commit 00acbd1

3 files changed

Lines changed: 89 additions & 3 deletions

File tree

src/betterproto2/__init__.py

Lines changed: 43 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -20,6 +20,13 @@
2020

2121
from typing_extensions import Self
2222

23+
try:
24+
import pydantic
25+
import pydantic_core
26+
except ImportError:
27+
pydantic = None
28+
pydantic_core = None
29+
2330
import betterproto2.validators as validators
2431
from betterproto2.message_pool import MessagePool
2532
from betterproto2.utils import unwrap
@@ -697,6 +704,26 @@ def _betterproto(cls: type[Self]) -> ProtoClassMetadata: # type: ignore
697704
cls._betterproto_meta = ProtoClassMetadata(cls)
698705
return cls._betterproto_meta
699706

707+
def _is_pydantic(self) -> bool:
708+
"""
709+
Check if the message is a pydantic dataclass.
710+
"""
711+
return pydantic is not None and pydantic.dataclasses.is_pydantic_dataclass(type(self))
712+
713+
def _validate(self) -> None:
714+
"""
715+
Manually validate the message using pydantic.
716+
717+
This is useful since pydantic does not revalidate the message when fields are changed. This function doesn't
718+
validate the fields recursively.
719+
"""
720+
if not self._is_pydantic():
721+
raise TypeError("Validation is only available for pydantic dataclasses.")
722+
723+
dict = self.__dict__.copy()
724+
del dict["_unknown_fields"]
725+
pydantic_core.SchemaValidator(self.__pydantic_core_schema__).validate_python(dict) # type: ignore
726+
700727
def dump(self, stream: SupportsWrite[bytes], delimit: bool = False) -> None:
701728
"""
702729
Dumps the binary encoded Protobuf message to the stream.
@@ -720,6 +747,9 @@ def __bytes__(self) -> bytes:
720747
"""
721748
Get the binary encoded Protobuf representation of this message instance.
722749
"""
750+
if self._is_pydantic():
751+
self._validate()
752+
723753
with BytesIO() as stream:
724754
for field_name, meta in self._betterproto.meta_by_field_name.items():
725755
value = getattr(self, field_name)
@@ -822,13 +852,17 @@ def _postprocess_single(self, wire_type: int, meta: FieldMetadata, field_name: s
822852
"""Adjusts values after parsing."""
823853
if wire_type == WIRE_VARINT:
824854
if meta.proto_type in (TYPE_INT32, TYPE_INT64):
825-
bits = int(meta.proto_type[3:])
855+
bits = 32 if meta.proto_type == TYPE_INT32 else 64
826856
value = value & ((1 << bits) - 1)
827857
signbit = 1 << (bits - 1)
828858
value = int((value ^ signbit) - signbit)
859+
elif meta.proto_type in (TYPE_UINT32, TYPE_UINT64):
860+
bits = 32 if meta.proto_type == TYPE_UINT32 else 64
861+
value = value & ((1 << bits) - 1)
829862
elif meta.proto_type in (TYPE_SINT32, TYPE_SINT64):
830-
# Undo zig-zag encoding
831-
value = (value >> 1) ^ (-(value & 1))
863+
bits = 32 if meta.proto_type == TYPE_SINT32 else 64
864+
value = value & ((1 << bits) - 1)
865+
value = (value >> 1) ^ (-(value & 1)) # Undo zig-zag encoding
832866
elif meta.proto_type == TYPE_BOOL:
833867
# Booleans use a varint encoding, so convert it to true/false.
834868
value = value > 0
@@ -947,6 +981,9 @@ def load(
947981
" or the expected size may have been incorrect."
948982
)
949983

984+
if self._is_pydantic():
985+
self._validate()
986+
950987
return self
951988

952989
@classmethod
@@ -1017,6 +1054,9 @@ def to_dict(
10171054
Dict[:class:`str`, Any]
10181055
The JSON serializable dict representation of this object.
10191056
"""
1057+
if self._is_pydantic():
1058+
self._validate()
1059+
10201060
kwargs = { # For recursive calls
10211061
"output_format": output_format,
10221062
"casing": casing,

tests/test_encoding_decoding.py

Lines changed: 23 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,23 @@
1+
def test_int_overflow():
2+
"""Make sure that overflows in encoded values are handled correctly."""
3+
from tests.output_betterproto_pydantic.encoding_decoding import Overflow32, Overflow64
4+
5+
b = bytes(Overflow64(uint=2**50 + 42))
6+
msg = Overflow32.parse(b)
7+
assert msg.uint == 42
8+
9+
b = bytes(Overflow64(int=2**50 + 42))
10+
msg = Overflow32.parse(b)
11+
assert msg.int == 42
12+
13+
b = bytes(Overflow64(int=2**50 - 42))
14+
msg = Overflow32.parse(b)
15+
assert msg.int == -42
16+
17+
b = bytes(Overflow64(sint=2**50 + 42))
18+
msg = Overflow32.parse(b)
19+
assert msg.sint == 42
20+
21+
b = bytes(Overflow64(sint=-(2**50) - 42))
22+
msg = Overflow32.parse(b)
23+
assert msg.sint == -42

tests/test_manual_validation.py

Lines changed: 23 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,23 @@
1+
import pydantic
2+
import pytest
3+
4+
5+
def test_manual_validation():
6+
from tests.output_betterproto_pydantic.manual_validation import Msg
7+
8+
msg = Msg()
9+
10+
msg.x = 12
11+
msg._validate()
12+
13+
msg.x = 2**50 # This is an invalid int32 value
14+
with pytest.raises(pydantic.ValidationError):
15+
msg._validate()
16+
17+
18+
def test_manual_validation_non_pydantic():
19+
from tests.output_betterproto.manual_validation import Msg
20+
21+
# Validation is not available for non-pydantic messages
22+
with pytest.raises(TypeError):
23+
Msg()._validate()

0 commit comments

Comments
 (0)