11import datetime
22import uuid
33from collections .abc import AsyncGenerator , Iterable , Mapping
4- from dataclasses import InitVar , dataclass , field , fields
4+ from dataclasses import Field , InitVar , dataclass , fields
55from typing import (
6+ Annotated ,
67 Any ,
78 Callable ,
89 Optional ,
910 TypeVar ,
1011 Union ,
12+ get_origin ,
13+ get_type_hints ,
1114)
1215
1316from .base import Fragment , sql
@@ -31,7 +34,7 @@ class ColumnInfo:
3134
3235 constraints : InitVar [Union [str , Iterable [str ], None ]] = None
3336
34- def __post_init__ (self , constraints : Union [str , Iterable [str ], None ]):
37+ def __post_init__ (self , constraints : Union [str , Iterable [str ], None ]) -> None :
3538 if self .create_type == "" :
3639 self .create_type = self .type
3740 self .type = sql_create_type_map .get (self .type .upper (), self .type )
@@ -40,7 +43,7 @@ def __post_init__(self, constraints: Union[str, Iterable[str], None]):
4043 constraints = (constraints ,)
4144 self ._constraints = tuple (constraints )
4245
43- def create_table_string (self ):
46+ def create_table_string (self ) -> str :
4447 parts = (
4548 self .create_type ,
4649 * (() if self .nullable else ("NOT NULL" ,)),
@@ -49,30 +52,14 @@ def create_table_string(self):
4952 return " " .join (parts )
5053
5154
52- def model_field_metadata (
53- type : str , nullable : bool = False , constraints : Union [str , Iterable [str ]] = ()
54- ) -> dict [str , Any ]:
55- if isinstance (constraints , str ):
56- constraints = (constraints ,)
57- info = ColumnInfo (type = type , nullable = nullable , constraints = constraints )
58-
59- return {"sql_athame" : info }
60-
61-
62- def model_field (
63- * , type : str , constraints : Union [str , Iterable [str ]] = (), ** kwargs : Any
64- ) -> Any :
65- return field (** kwargs , metadata = model_field_metadata (type , constraints ))
66-
67-
6855sql_create_type_map = {
6956 "BIGSERIAL" : "BIGINT" ,
7057 "SERIAL" : "INTEGER" ,
7158 "SMALLSERIAL" : "SMALLINT" ,
7259}
7360
7461
75- sql_type_map : dict [type , tuple [str , bool ]] = {
62+ sql_type_map : dict [Any , tuple [str , bool ]] = {
7663 Optional [bool ]: ("BOOLEAN" , True ),
7764 Optional [bytes ]: ("BYTEA" , True ),
7865 Optional [datetime .date ]: ("DATE" , True ),
@@ -92,13 +79,6 @@ def model_field(
9279}
9380
9481
95- def column_info_for_field (field ):
96- if "sql_athame" in field .metadata :
97- return field .metadata ["sql_athame" ]
98- type , nullable = sql_type_map [field .type ]
99- return ColumnInfo (type = type , nullable = nullable )
100-
101-
10282T = TypeVar ("T" , bound = "ModelBase" )
10383U = TypeVar ("U" )
10484
@@ -109,6 +89,7 @@ class ModelBase:
10989 table_name : str
11090 primary_key_names : tuple [str , ...]
11191 array_safe_insert : bool
92+ _type_hints : dict [str , type ]
11293
11394 def __init_subclass__ (
11495 cls ,
@@ -146,12 +127,34 @@ def _cached(cls, key: tuple, thunk: Callable[[], U]) -> U:
146127 cls ._cache [key ] = thunk ()
147128 return cls ._cache [key ]
148129
130+ @classmethod
131+ def type_hints (cls ) -> dict [str , type ]:
132+ try :
133+ return cls ._type_hints
134+ except AttributeError :
135+ cls ._type_hints = get_type_hints (cls , include_extras = True )
136+ return cls ._type_hints
137+
138+ @classmethod
139+ def column_info_for_field (cls , field : Field ) -> ColumnInfo :
140+ type_info = cls .type_hints ()[field .name ]
141+ base_type = type_info
142+ if get_origin (type_info ) is Annotated :
143+ base_type = type_info .__origin__ # type: ignore
144+ for md in type_info .__metadata__ : # type: ignore
145+ if isinstance (md , ColumnInfo ):
146+ return md
147+ type , nullable = sql_type_map [base_type ]
148+ return ColumnInfo (type = type , nullable = nullable )
149+
149150 @classmethod
150151 def column_info (cls , column : str ) -> ColumnInfo :
151152 try :
152153 return cls ._column_info [column ] # type: ignore
153154 except AttributeError :
154- cls ._column_info = {f .name : column_info_for_field (f ) for f in cls ._fields ()}
155+ cls ._column_info = {
156+ f .name : cls .column_info_for_field (f ) for f in cls ._fields ()
157+ }
155158 return cls ._column_info [column ]
156159
157160 @classmethod
0 commit comments