Skip to content

Commit b4cca7c

Browse files
committed
Support 3.10+ UnionType
1 parent 1136b74 commit b4cca7c

2 files changed

Lines changed: 30 additions & 1 deletion

File tree

sql_athame/dataclasses.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import datetime
22
import functools
3+
import sys
34
import uuid
45
from collections.abc import AsyncGenerator, Iterable, Mapping
56
from dataclasses import Field, InitVar, dataclass, fields
@@ -105,12 +106,18 @@ def maybe_serialize(self, value: Any) -> Any:
105106
return value
106107

107108

109+
UNION_TYPES: tuple = (Union,)
110+
if sys.version_info >= (3, 10):
111+
from types import UnionType
112+
113+
UNION_TYPES = (Union, UnionType)
114+
108115
NULLABLE_TYPES = (type(None), Any, object)
109116

110117

111118
def split_nullable(typ: type) -> tuple[bool, type]:
112119
nullable = typ in NULLABLE_TYPES
113-
if get_origin(typ) is Union:
120+
if get_origin(typ) in UNION_TYPES:
114121
args = []
115122
for arg in get_args(typ):
116123
if arg in NULLABLE_TYPES:

tests/test_dataclasses.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from __future__ import annotations
44

5+
import sys
56
import uuid
67
from dataclasses import dataclass
78
from typing import Annotated, Any, Optional, Union
@@ -101,6 +102,27 @@ class Test(ModelBase, table_name="table", primary_key="foo"):
101102
]
102103

103104

105+
@pytest.mark.skipif(sys.version_info < (3, 10), reason="needs python3.10 or greater")
106+
def test_py310_unions():
107+
@dataclass
108+
class Test(ModelBase, table_name="table", primary_key="foo"):
109+
foo: int
110+
bar: str
111+
baz: uuid.UUID | None
112+
foo_nullable: int | None
113+
bar_nullable: str | None
114+
115+
assert list(Test.create_table_sql()) == [
116+
'CREATE TABLE IF NOT EXISTS "table" ('
117+
'"foo" INTEGER NOT NULL, '
118+
'"bar" TEXT NOT NULL, '
119+
'"baz" UUID, '
120+
'"foo_nullable" INTEGER, '
121+
'"bar_nullable" TEXT, '
122+
'PRIMARY KEY ("foo"))'
123+
]
124+
125+
104126
def test_modelclass_missing_type():
105127
@dataclass
106128
class Test(ModelBase, table_name="table", primary_key="foo"):

0 commit comments

Comments
 (0)