Skip to content

Commit 69ce167

Browse files
authored
Merge pull request #22 from bdowning/updates
Try to detect nullability from Python types
2 parents 71e1272 + 5bb6710 commit 69ce167

4 files changed

Lines changed: 51 additions & 28 deletions

File tree

.bumpversion.cfg

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
[bumpversion]
2-
current_version = 0.4.0-alpha-7
2+
current_version = 0.4.0-alpha-8
33
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(-(?P<release>.*)-(?P<build>\d+))?
44
serialize =
55
{major}.{minor}.{patch}-{release}-{build}

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "sql-athame"
3-
version = "0.4.0-alpha-7"
3+
version = "0.4.0-alpha-8"
44
description = "Python tool for slicing and dicing SQL"
55
authors = ["Brian Downing <bdowning@lavos.net>"]
66
license = "MIT"

sql_athame/dataclasses.py

Lines changed: 34 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
Optional,
1111
TypeVar,
1212
Union,
13+
get_args,
1314
get_origin,
1415
get_type_hints,
1516
)
@@ -84,30 +85,38 @@ def create_table_string(self) -> str:
8485
return " ".join(parts)
8586

8687

88+
NULLABLE_TYPES = (type(None), Any, object)
89+
90+
91+
def split_nullable(typ: type) -> tuple[bool, type]:
92+
nullable = typ in NULLABLE_TYPES
93+
if get_origin(typ) is Union:
94+
args = []
95+
for arg in get_args(typ):
96+
if arg in NULLABLE_TYPES:
97+
nullable = True
98+
else:
99+
args.append(arg)
100+
return nullable, Union[tuple(args)] # type: ignore
101+
return nullable, typ
102+
103+
87104
sql_create_type_map = {
88105
"BIGSERIAL": "BIGINT",
89106
"SERIAL": "INTEGER",
90107
"SMALLSERIAL": "SMALLINT",
91108
}
92109

93110

94-
sql_type_map: dict[Any, tuple[str, bool]] = {
95-
Optional[bool]: ("BOOLEAN", True),
96-
Optional[bytes]: ("BYTEA", True),
97-
Optional[datetime.date]: ("DATE", True),
98-
Optional[datetime.datetime]: ("TIMESTAMP", True),
99-
Optional[float]: ("DOUBLE PRECISION", True),
100-
Optional[int]: ("INTEGER", True),
101-
Optional[str]: ("TEXT", True),
102-
Optional[uuid.UUID]: ("UUID", True),
103-
bool: ("BOOLEAN", False),
104-
bytes: ("BYTEA", False),
105-
datetime.date: ("DATE", False),
106-
datetime.datetime: ("TIMESTAMP", False),
107-
float: ("DOUBLE PRECISION", False),
108-
int: ("INTEGER", False),
109-
str: ("TEXT", False),
110-
uuid.UUID: ("UUID", False),
111+
sql_type_map: dict[Any, str] = {
112+
bool: "BOOLEAN",
113+
bytes: "BYTEA",
114+
datetime.date: "DATE",
115+
datetime.datetime: "TIMESTAMP",
116+
float: "DOUBLE PRECISION",
117+
int: "INTEGER",
118+
str: "TEXT",
119+
uuid.UUID: "UUID",
111120
}
112121

113122

@@ -171,16 +180,16 @@ def type_hints(cls) -> dict[str, type]:
171180
def column_info_for_field(cls, field: Field) -> ConcreteColumnInfo:
172181
type_info = cls.type_hints()[field.name]
173182
base_type = type_info
183+
metadata = []
174184
if get_origin(type_info) is Annotated:
175-
base_type = type_info.__origin__ # type: ignore
176-
info = []
185+
base_type, *metadata = get_args(type_info)
186+
nullable, base_type = split_nullable(base_type)
187+
info = [ColumnInfo(nullable=nullable)]
177188
if base_type in sql_type_map:
178-
_type, nullable = sql_type_map[base_type]
179-
info.append(ColumnInfo(type=_type, nullable=nullable))
180-
if get_origin(type_info) is Annotated:
181-
for md in type_info.__metadata__: # type: ignore
182-
if isinstance(md, ColumnInfo):
183-
info.append(md)
189+
info.append(ColumnInfo(type=sql_type_map[base_type]))
190+
for md in metadata:
191+
if isinstance(md, ColumnInfo):
192+
info.append(md)
184193
return ConcreteColumnInfo.from_column_info(field.name, *info)
185194

186195
@classmethod

tests/test_dataclasses.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import uuid
66
from dataclasses import dataclass
7-
from typing import Annotated, Optional
7+
from typing import Annotated, Any, Optional, Union
88

99
import pytest
1010

@@ -75,6 +75,13 @@ class Test(ModelBase, table_name="table", primary_key="foo"):
7575
ColumnInfo(constraints="REFERENCES foobar"),
7676
ColumnInfo(constraints="BLAH", nullable=True),
7777
]
78+
any: Annotated[Any, ColumnInfo(type="TEXT")]
79+
any_not_null: Annotated[Any, ColumnInfo(type="TEXT", nullable=False)]
80+
obj: Annotated[object, ColumnInfo(type="TEXT")]
81+
obj_not_null: Annotated[object, ColumnInfo(type="TEXT", nullable=False)]
82+
combined_nullable: Annotated[Union[int, Any], ColumnInfo(type="INTEGER")]
83+
null_jsonb: Annotated[Optional[dict], ColumnInfo(type="JSONB")]
84+
not_null_jsonb: Annotated[dict, ColumnInfo(type="JSONB")]
7885

7986
assert list(Test.create_table_sql()) == [
8087
'CREATE TABLE IF NOT EXISTS "table" ('
@@ -83,6 +90,13 @@ class Test(ModelBase, table_name="table", primary_key="foo"):
8390
'"baz" UUID, '
8491
'"quux" INTEGER NOT NULL REFERENCES foobar, '
8592
'"quuux" INTEGER REFERENCES foobar BLAH, '
93+
'"any" TEXT, '
94+
'"any_not_null" TEXT NOT NULL, '
95+
'"obj" TEXT, '
96+
'"obj_not_null" TEXT NOT NULL, '
97+
'"combined_nullable" INTEGER, '
98+
'"null_jsonb" JSONB, '
99+
'"not_null_jsonb" JSONB NOT NULL, '
86100
'PRIMARY KEY ("foo"))'
87101
]
88102

0 commit comments

Comments
 (0)