11import datetime
22import uuid
3- from collections .abc import AsyncGenerator , Iterable , Iterator , Mapping
4- from dataclasses import dataclass , field , fields
3+ from collections .abc import AsyncGenerator , Iterable , Mapping
4+ from dataclasses import Field , InitVar , dataclass , fields
55from typing import (
6+ Annotated ,
67 Any ,
78 Callable ,
89 Optional ,
910 TypeVar ,
1011 Union ,
12+ get_origin ,
13+ get_type_hints ,
1114)
1215
16+ from typing_extensions import TypeAlias
17+
1318from .base import Fragment , sql
1419
15- Where = Union [Fragment , Iterable [Fragment ]]
20+ Where : TypeAlias = Union [Fragment , Iterable [Fragment ]]
1621# KLUDGE to avoid a string argument being valid
17- SequenceOfStrings = Union [list [str ], tuple [str , ...]]
18- FieldNames = SequenceOfStrings
19- FieldNamesSet = Union [SequenceOfStrings , set [str ]]
22+ SequenceOfStrings : TypeAlias = Union [list [str ], tuple [str , ...]]
23+ FieldNames : TypeAlias = SequenceOfStrings
24+ FieldNamesSet : TypeAlias = Union [SequenceOfStrings , set [str ]]
2025
21- Connection = Any
22- Pool = Any
26+ Connection : TypeAlias = Any
27+ Pool : TypeAlias = Any
2328
2429
2530@dataclass
2631class ColumnInfo :
2732 type : str
28- create_type : str
29- constraints : tuple [ str , ...]
30-
31- def create_table_string ( self ):
32- return " " . join (( self . create_type , * self . constraints ))
33-
34-
35- def model_field_metadata (
36- type : str , constraints : Union [ str , Iterable [ str ]] = ()
37- ) -> dict [ str , Any ]:
38- if isinstance ( constraints , str ) :
39- constraints = (constraints ,)
40- info = ColumnInfo (
41- sql_create_type_map . get ( type . upper (), type ), type , tuple (constraints )
42- )
43- return { "sql_athame" : info }
44-
45-
46- def model_field (
47- * , type : str , constraints : Union [ str , Iterable [ str ]] = (), ** kwargs : Any
48- ) -> Any :
49- return field ( ** kwargs , metadata = model_field_metadata ( type , constraints ) )
33+ create_type : str = ""
34+ nullable : bool = False
35+ _constraints : tuple [ str , ...] = ()
36+
37+ constraints : InitVar [ Union [ str , Iterable [ str ], None ]] = None
38+
39+ def __post_init__ ( self , constraints : Union [ str , Iterable [ str ], None ]) -> None :
40+ if self . create_type == "" :
41+ self . create_type = self . type
42+ self . type = sql_create_type_map . get ( self . type . upper (), self . type )
43+ if constraints is not None :
44+ if type (constraints ) is str :
45+ constraints = ( constraints ,)
46+ self . _constraints = tuple (constraints )
47+
48+ def create_table_string ( self ) -> str :
49+ parts = (
50+ self . create_type ,
51+ * (() if self . nullable else ( "NOT NULL" ,)),
52+ * self . _constraints ,
53+ )
54+ return " " . join ( parts )
5055
5156
5257sql_create_type_map = {
@@ -56,43 +61,37 @@ def model_field(
5661}
5762
5863
59- sql_type_map = {
60- Optional [bool ]: ("BOOLEAN" ,),
61- Optional [bytes ]: ("BYTEA" ,),
62- Optional [datetime .date ]: ("DATE" ,),
63- Optional [datetime .datetime ]: ("TIMESTAMP" ,),
64- Optional [float ]: ("DOUBLE PRECISION" ,),
65- Optional [int ]: ("INTEGER" ,),
66- Optional [str ]: ("TEXT" ,),
67- Optional [uuid .UUID ]: ("UUID" ,),
68- bool : ("BOOLEAN" , "NOT NULL" ),
69- bytes : ("BYTEA" , "NOT NULL" ),
70- datetime .date : ("DATE" , "NOT NULL" ),
71- datetime .datetime : ("TIMESTAMP" , "NOT NULL" ),
72- float : ("DOUBLE PRECISION" , "NOT NULL" ),
73- int : ("INTEGER" , "NOT NULL" ),
74- str : ("TEXT" , "NOT NULL" ),
75- uuid .UUID : ("UUID" , "NOT NULL" ),
64+ sql_type_map : dict [ Any , tuple [ str , bool ]] = {
65+ Optional [bool ]: ("BOOLEAN" , True ),
66+ Optional [bytes ]: ("BYTEA" , True ),
67+ Optional [datetime .date ]: ("DATE" , True ),
68+ Optional [datetime .datetime ]: ("TIMESTAMP" , True ),
69+ Optional [float ]: ("DOUBLE PRECISION" , True ),
70+ Optional [int ]: ("INTEGER" , True ),
71+ Optional [str ]: ("TEXT" , True ),
72+ Optional [uuid .UUID ]: ("UUID" , True ),
73+ bool : ("BOOLEAN" , False ),
74+ bytes : ("BYTEA" , False ),
75+ datetime .date : ("DATE" , False ),
76+ datetime .datetime : ("TIMESTAMP" , False ),
77+ float : ("DOUBLE PRECISION" , False ),
78+ int : ("INTEGER" , False ),
79+ str : ("TEXT" , False ),
80+ uuid .UUID : ("UUID" , False ),
7681}
7782
7883
79- def column_info_for_field (field ):
80- if "sql_athame" in field .metadata :
81- return field .metadata ["sql_athame" ]
82- type , * constraints = sql_type_map [field .type ]
83- return ColumnInfo (type , type , tuple (constraints ))
84-
85-
8684T = TypeVar ("T" , bound = "ModelBase" )
8785U = TypeVar ("U" )
8886
8987
90- class ModelBase ( Mapping [ str , Any ]) :
88+ class ModelBase :
9189 _column_info : Optional [dict [str , ColumnInfo ]]
9290 _cache : dict [tuple , Any ]
9391 table_name : str
9492 primary_key_names : tuple [str , ...]
9593 array_safe_insert : bool
94+ _type_hints : dict [str , type ]
9695
9796 def __init_subclass__ (
9897 cls ,
@@ -130,27 +129,34 @@ def _cached(cls, key: tuple, thunk: Callable[[], U]) -> U:
130129 cls ._cache [key ] = thunk ()
131130 return cls ._cache [key ]
132131
133- def keys (self ):
134- return self .field_names ()
135-
136- def __getitem__ (self , key : str ) -> Any :
137- return getattr (self , key )
138-
139- def __iter__ (self ) -> Iterator [Any ]:
140- return iter (self .keys ())
141-
142- def __len__ (self ) -> int :
143- return len (self .keys ())
132+ @classmethod
133+ def type_hints (cls ) -> dict [str , type ]:
134+ try :
135+ return cls ._type_hints
136+ except AttributeError :
137+ cls ._type_hints = get_type_hints (cls , include_extras = True )
138+ return cls ._type_hints
144139
145- def get (self , key : str , default : Any = None ) -> Any :
146- return getattr (self , key , default )
140+ @classmethod
141+ def column_info_for_field (cls , field : Field ) -> ColumnInfo :
142+ type_info = cls .type_hints ()[field .name ]
143+ base_type = type_info
144+ if get_origin (type_info ) is Annotated :
145+ base_type = type_info .__origin__ # type: ignore
146+ for md in type_info .__metadata__ : # type: ignore
147+ if isinstance (md , ColumnInfo ):
148+ return md
149+ type , nullable = sql_type_map [base_type ]
150+ return ColumnInfo (type = type , nullable = nullable )
147151
148152 @classmethod
149153 def column_info (cls , column : str ) -> ColumnInfo :
150154 try :
151155 return cls ._column_info [column ] # type: ignore
152156 except AttributeError :
153- cls ._column_info = {f .name : column_info_for_field (f ) for f in cls ._fields ()}
157+ cls ._column_info = {
158+ f .name : cls .column_info_for_field (f ) for f in cls ._fields ()
159+ }
154160 return cls ._column_info [column ]
155161
156162 @classmethod
0 commit comments