Skip to content

Commit 4af468d

Browse files
author
Zahari Kassabov
committed
Support forward refs
Use get_type_hints instead of the raw annotations to resolve references within dataclasses, namedtuples and typed dicts.
1 parent 82a214f commit 4af468d

3 files changed

Lines changed: 76 additions & 7 deletions

File tree

validobj/tests/test_custom.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import pytest
22
import dataclasses
3-
from typing import Any
3+
from typing import Any, TypedDict
44

55
try:
66
from validobj.custom import Parser, Validator, InputType
@@ -13,7 +13,7 @@
1313

1414

1515
@pytest.mark.skipif(not HAVE_CUSTOM, reason="Custom type not found")
16-
def test_custom():
16+
def test_custom_dataclass():
1717
def my_float(inp: str) -> float:
1818
return float(inp) + 1
1919

@@ -30,6 +30,23 @@ class Container:
3030
with pytest.raises(ValidationError):
3131
parse_input({"value": 5}, Container)
3232

33+
@pytest.mark.skipif(not HAVE_CUSTOM, reason="Custom type not found")
34+
def test_custom_typeddict():
35+
def my_float(inp: str) -> float:
36+
return float(inp) + 1
37+
38+
MyFloat = Parser(my_float)
39+
assert MyFloat.__origin__ is float
40+
assert isinstance(MyFloat.__metadata__[0], InputType)
41+
assert isinstance(MyFloat.__metadata__[1], Validator)
42+
43+
class Container(TypedDict):
44+
value: MyFloat
45+
46+
assert parse_input({"value": "5"}, Container) == Container(value=6)
47+
with pytest.raises(ValidationError):
48+
parse_input({"value": 5}, Container)
49+
3350

3451
@pytest.mark.skipif(not HAVE_CUSTOM, reason="Custom type not found")
3552
def test_no_annotations():

validobj/tests/test_forward.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
from typing import NamedTuple, TypedDict, Optional
2+
import dataclasses
3+
4+
import validobj
5+
6+
7+
class C(TypedDict):
8+
a: 'A'
9+
b: 'B'
10+
c: Optional['C'] = None
11+
12+
13+
class B(NamedTuple):
14+
a0: 'A'
15+
a1: 'A'
16+
17+
18+
@dataclasses.dataclass
19+
class A:
20+
children: list['A']
21+
22+
23+
def test_dataclass():
24+
assert validobj.parse_input({"children": [{"children": []}]}, A)
25+
26+
27+
def test_namedtuple():
28+
assert validobj.parse_input([{"children": []}, {"children": []}], B)
29+
30+
31+
def test_typeddict():
32+
assert validobj.parse_input(
33+
{
34+
'a': {'children': [{'children': []}]},
35+
'b': [{"children": []}, {"children": []}],
36+
'c': None,
37+
},
38+
C,
39+
)

validobj/validation.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@
1010
1111
1212
"""
13-
from typing import Set, Union, Any, Optional, TypeVar, Type, Literal
13+
14+
from typing import Set, Union, Any, Optional, TypeVar, Type, Literal, get_type_hints
1415
import sys
1516

1617
try:
@@ -209,10 +210,18 @@ def _parse_dataclass(value, spec):
209210
header=f"Cannot process value into {_typename(spec)!r} because "
210211
f"fields do not match.",
211212
)
213+
214+
# Use this to resolve forward references.
215+
annotations = get_type_hints(spec, include_extras=True)
216+
212217
res = {}
213218
field_dict = {
214219
# Look inside InitVar
215-
f.name: f.type if not isinstance(f.type, dataclasses.InitVar) else f.type.type
220+
f.name: (
221+
annotations[f.name]
222+
if not isinstance(f.type, dataclasses.InitVar)
223+
else annotations[f.name].type
224+
)
216225
for f in fields
217226
}
218227
for k, v in value.items():
@@ -239,10 +248,12 @@ def _parse_typed_dict(value, spec):
239248
header=f"Cannot process value into {_typename(spec)!r} because "
240249
f"fields do not match.",
241250
)
251+
# Resolve forward references.
252+
annotations = get_type_hints(spec, include_extras=True)
242253
res = {}
243254
for k, v in value.items():
244255
try:
245-
res[k] = parse_input(v, spec.__annotations__[k])
256+
res[k] = parse_input(v, annotations[k])
246257
except ValidationError as e:
247258
raise WrongFieldError(
248259
f"Cannot process field {k!r} of value into the "
@@ -281,10 +292,12 @@ def _parse_namedtuple(value, spec):
281292

282293
res = {}
283294

295+
annotations = get_type_hints(spec, include_extras=True)
296+
284297
for i, (k, v) in enumerate(field_inputs.items()):
285-
if k in spec.__annotations__:
298+
if k in annotations:
286299
try:
287-
res[k] = parse_input(v, spec.__annotations__[k])
300+
res[k] = parse_input(v, annotations[k])
288301
except ValidationError as e:
289302
raise WrongListItemError(
290303
f"Cannot process list item {i+1} into the field {k!r} of {_typename(spec)!r}",

0 commit comments

Comments
 (0)