|
1 | 1 | import datetime |
| 2 | +import functools |
2 | 3 | import uuid |
3 | 4 | from collections.abc import AsyncGenerator, Iterable, Mapping |
4 | 5 | from dataclasses import Field, InitVar, dataclass, fields |
|
29 | 30 |
|
30 | 31 | @dataclass |
31 | 32 | class ColumnInfo: |
32 | | - type: str |
33 | | - create_type: str = "" |
34 | | - nullable: bool = False |
| 33 | + type: Optional[str] = None |
| 34 | + create_type: Optional[str] = None |
| 35 | + nullable: Optional[bool] = None |
35 | 36 | _constraints: tuple[str, ...] = () |
36 | 37 |
|
37 | 38 | constraints: InitVar[Union[str, Iterable[str], None]] = None |
38 | 39 |
|
39 | 40 | def __post_init__(self, constraints: Union[str, Iterable[str], None]) -> None: |
40 | | - if self.create_type == "": |
41 | | - self.create_type = self.type |
42 | | - self.type = sql_create_type_map.get(self.type.upper(), self.type) |
43 | 41 | if constraints is not None: |
44 | 42 | if type(constraints) is str: |
45 | 43 | constraints = (constraints,) |
46 | 44 | self._constraints = tuple(constraints) |
47 | 45 |
|
| 46 | + @staticmethod |
| 47 | + def merge(a: "ColumnInfo", b: "ColumnInfo") -> "ColumnInfo": |
| 48 | + return ColumnInfo( |
| 49 | + type=b.type if b.type is not None else a.type, |
| 50 | + create_type=b.create_type if b.create_type is not None else a.create_type, |
| 51 | + nullable=b.nullable if b.nullable is not None else a.nullable, |
| 52 | + _constraints=(*a._constraints, *b._constraints), |
| 53 | + ) |
| 54 | + |
| 55 | + |
| 56 | +@dataclass |
| 57 | +class ConcreteColumnInfo: |
| 58 | + type: str |
| 59 | + create_type: str |
| 60 | + nullable: bool |
| 61 | + constraints: tuple[str, ...] |
| 62 | + |
| 63 | + @staticmethod |
| 64 | + def from_column_info(name: str, *args: ColumnInfo) -> "ConcreteColumnInfo": |
| 65 | + info = functools.reduce(ColumnInfo.merge, args, ColumnInfo()) |
| 66 | + if info.create_type is None and info.type is not None: |
| 67 | + info.create_type = info.type |
| 68 | + info.type = sql_create_type_map.get(info.type.upper(), info.type) |
| 69 | + if type(info.type) is not str or type(info.create_type) is not str: |
| 70 | + raise ValueError(f"Missing SQL type for column {name!r}") |
| 71 | + return ConcreteColumnInfo( |
| 72 | + type=info.type, |
| 73 | + create_type=info.create_type, |
| 74 | + nullable=bool(info.nullable), |
| 75 | + constraints=info._constraints, |
| 76 | + ) |
| 77 | + |
48 | 78 | def create_table_string(self) -> str: |
49 | 79 | parts = ( |
50 | 80 | self.create_type, |
51 | 81 | *(() if self.nullable else ("NOT NULL",)), |
52 | | - *self._constraints, |
| 82 | + *self.constraints, |
53 | 83 | ) |
54 | 84 | return " ".join(parts) |
55 | 85 |
|
@@ -86,7 +116,7 @@ def create_table_string(self) -> str: |
86 | 116 |
|
87 | 117 |
|
88 | 118 | class ModelBase: |
89 | | - _column_info: Optional[dict[str, ColumnInfo]] |
| 119 | + _column_info: Optional[dict[str, ConcreteColumnInfo]] |
90 | 120 | _cache: dict[tuple, Any] |
91 | 121 | table_name: str |
92 | 122 | primary_key_names: tuple[str, ...] |
@@ -138,19 +168,23 @@ def type_hints(cls) -> dict[str, type]: |
138 | 168 | return cls._type_hints |
139 | 169 |
|
140 | 170 | @classmethod |
141 | | - def column_info_for_field(cls, field: Field) -> ColumnInfo: |
| 171 | + def column_info_for_field(cls, field: Field) -> ConcreteColumnInfo: |
142 | 172 | type_info = cls.type_hints()[field.name] |
143 | 173 | base_type = type_info |
144 | 174 | if get_origin(type_info) is Annotated: |
145 | 175 | base_type = type_info.__origin__ # type: ignore |
| 176 | + info = [] |
| 177 | + if base_type in sql_type_map: |
| 178 | + _type, nullable = sql_type_map[base_type] |
| 179 | + info.append(ColumnInfo(type=_type, nullable=nullable)) |
| 180 | + if get_origin(type_info) is Annotated: |
146 | 181 | for md in type_info.__metadata__: # type: ignore |
147 | 182 | if isinstance(md, ColumnInfo): |
148 | | - return md |
149 | | - type, nullable = sql_type_map[base_type] |
150 | | - return ColumnInfo(type=type, nullable=nullable) |
| 183 | + info.append(md) |
| 184 | + return ConcreteColumnInfo.from_column_info(field.name, *info) |
151 | 185 |
|
152 | 186 | @classmethod |
153 | | - def column_info(cls, column: str) -> ColumnInfo: |
| 187 | + def column_info(cls, column: str) -> ConcreteColumnInfo: |
154 | 188 | try: |
155 | 189 | return cls._column_info[column] # type: ignore |
156 | 190 | except AttributeError: |
|
0 commit comments