Skip to content

Commit 39b58bb

Browse files
author
Zahari Kassabov
committed
Fix processing of dataclasses types
The workaround for python/cpython#137891 is more complicated than just accessing the __annotations__ of the dataclass instead of field.type. Indeed one has to also process all the base classes. Add a test for derived classes as well.
1 parent e16fd5b commit 39b58bb

2 files changed

Lines changed: 42 additions & 2 deletions

File tree

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,41 @@
11
from typing import Any
22
import dataclasses
33

4-
from validobj import parse_input
4+
import pytest
5+
6+
from validobj import parse_input, ValidationError
57

68
@dataclasses.dataclass
79
class Linked:
810
value: Any
911
parents: Linked | list[Linked] | None = None
1012

13+
@dataclasses.dataclass
14+
class Parent:
15+
parent_field: int = 1
16+
override_field: str = 'parent'
17+
18+
@dataclasses.dataclass
19+
class Child(Parent):
20+
child_field: int = 2
21+
override_field: int = 3
22+
1123

1224
def test_delayed_annotations():
1325
inp = {'value': 1, 'parents': {'value': 2, 'parents': [{'value': 3}, {'value': 4, 'parents': {'value': 5}}]}}
1426
assert parse_input(inp, Linked).parents.value == 2
1527

28+
def test_derived_dataclasses():
29+
inp = {'child_field': 3, 'parent_field': 4, 'override_field': 5}
30+
res = parse_input(inp, Child)
31+
assert res.child_field == 3
32+
assert res.parent_field == 4
33+
assert res.override_field == 5
34+
35+
with pytest.raises(ValidationError):
36+
parse_input({'override_field': 'x'}, Child)
37+
38+
39+
40+
1641

validobj/validation.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,21 @@ def _dataclasses_fields(class_or_instance):
181181
or f._field_type is dataclasses._FIELD_INITVAR
182182
)
183183

184+
def _dataclass_types(cls):
185+
""" Workaround for https://github.com/python/cpython/issues/137891
186+
187+
Note that, contrary to dataclasses.fields, the annotations for the base fields are
188+
not propagated automatically, so they need to be extracted from the base classes.
189+
"""
190+
res = {}
191+
192+
# Base classes, including current one, from parent to child, excluding object
193+
for base in cls.__mro__[1::-1]:
194+
if dataclasses.is_dataclass(base):
195+
res.update(base.__annotations__)
196+
197+
return res
198+
184199

185200
def _dataclass_required_allowed(fields):
186201
allowed = set()
@@ -210,7 +225,7 @@ def _parse_dataclass(value, spec):
210225
f"fields do not match.",
211226
)
212227
# Note: We don't use field.type because of https://github.com/python/cpython/issues/137891
213-
types = spec.__annotations__
228+
types = _dataclass_types(spec)
214229

215230
res = {}
216231
field_dict = {

0 commit comments

Comments
 (0)