|
10 | 10 | Optional, |
11 | 11 | TypeVar, |
12 | 12 | Union, |
| 13 | + get_args, |
13 | 14 | get_origin, |
14 | 15 | get_type_hints, |
15 | 16 | ) |
@@ -84,30 +85,38 @@ def create_table_string(self) -> str: |
84 | 85 | return " ".join(parts) |
85 | 86 |
|
86 | 87 |
|
| 88 | +NULLABLE_TYPES = (type(None), Any, object) |
| 89 | + |
| 90 | + |
| 91 | +def split_nullable(typ: type) -> tuple[bool, type]: |
| 92 | + nullable = typ in NULLABLE_TYPES |
| 93 | + if get_origin(typ) is Union: |
| 94 | + args = [] |
| 95 | + for arg in get_args(typ): |
| 96 | + if arg in NULLABLE_TYPES: |
| 97 | + nullable = True |
| 98 | + else: |
| 99 | + args.append(arg) |
| 100 | + return nullable, Union[tuple(args)] # type: ignore |
| 101 | + return nullable, typ |
| 102 | + |
| 103 | + |
87 | 104 | sql_create_type_map = { |
88 | 105 | "BIGSERIAL": "BIGINT", |
89 | 106 | "SERIAL": "INTEGER", |
90 | 107 | "SMALLSERIAL": "SMALLINT", |
91 | 108 | } |
92 | 109 |
|
93 | 110 |
|
94 | | -sql_type_map: dict[Any, tuple[str, bool]] = { |
95 | | - Optional[bool]: ("BOOLEAN", True), |
96 | | - Optional[bytes]: ("BYTEA", True), |
97 | | - Optional[datetime.date]: ("DATE", True), |
98 | | - Optional[datetime.datetime]: ("TIMESTAMP", True), |
99 | | - Optional[float]: ("DOUBLE PRECISION", True), |
100 | | - Optional[int]: ("INTEGER", True), |
101 | | - Optional[str]: ("TEXT", True), |
102 | | - Optional[uuid.UUID]: ("UUID", True), |
103 | | - bool: ("BOOLEAN", False), |
104 | | - bytes: ("BYTEA", False), |
105 | | - datetime.date: ("DATE", False), |
106 | | - datetime.datetime: ("TIMESTAMP", False), |
107 | | - float: ("DOUBLE PRECISION", False), |
108 | | - int: ("INTEGER", False), |
109 | | - str: ("TEXT", False), |
110 | | - uuid.UUID: ("UUID", False), |
| 111 | +sql_type_map: dict[Any, str] = { |
| 112 | + bool: "BOOLEAN", |
| 113 | + bytes: "BYTEA", |
| 114 | + datetime.date: "DATE", |
| 115 | + datetime.datetime: "TIMESTAMP", |
| 116 | + float: "DOUBLE PRECISION", |
| 117 | + int: "INTEGER", |
| 118 | + str: "TEXT", |
| 119 | + uuid.UUID: "UUID", |
111 | 120 | } |
112 | 121 |
|
113 | 122 |
|
@@ -171,16 +180,16 @@ def type_hints(cls) -> dict[str, type]: |
171 | 180 | def column_info_for_field(cls, field: Field) -> ConcreteColumnInfo: |
172 | 181 | type_info = cls.type_hints()[field.name] |
173 | 182 | base_type = type_info |
| 183 | + metadata = [] |
174 | 184 | if get_origin(type_info) is Annotated: |
175 | | - base_type = type_info.__origin__ # type: ignore |
176 | | - info = [] |
| 185 | + base_type, *metadata = get_args(type_info) |
| 186 | + nullable, base_type = split_nullable(base_type) |
| 187 | + info = [ColumnInfo(nullable=nullable)] |
177 | 188 | if base_type in sql_type_map: |
178 | | - _type, nullable = sql_type_map[base_type] |
179 | | - info.append(ColumnInfo(type=_type, nullable=nullable)) |
180 | | - if get_origin(type_info) is Annotated: |
181 | | - for md in type_info.__metadata__: # type: ignore |
182 | | - if isinstance(md, ColumnInfo): |
183 | | - info.append(md) |
| 189 | + info.append(ColumnInfo(type=sql_type_map[base_type])) |
| 190 | + for md in metadata: |
| 191 | + if isinstance(md, ColumnInfo): |
| 192 | + info.append(md) |
184 | 193 | return ConcreteColumnInfo.from_column_info(field.name, *info) |
185 | 194 |
|
186 | 195 | @classmethod |
|
0 commit comments