@@ -34,8 +34,8 @@ class ColumnInfo:
3434 type : Optional [str ] = None
3535 create_type : Optional [str ] = None
3636 nullable : Optional [bool ] = None
37- _constraints : tuple [str , ...] = ()
3837
38+ _constraints : tuple [str , ...] = ()
3939 constraints : InitVar [Union [str , Iterable [str ], None ]] = None
4040
4141 def __post_init__ (self , constraints : Union [str , Iterable [str ], None ]) -> None :
@@ -56,20 +56,26 @@ def merge(a: "ColumnInfo", b: "ColumnInfo") -> "ColumnInfo":
5656
5757@dataclass
5858class ConcreteColumnInfo :
59+ field : Field
60+ type_hint : type
5961 type : str
6062 create_type : str
6163 nullable : bool
6264 constraints : tuple [str , ...]
6365
6466 @staticmethod
65- def from_column_info (name : str , * args : ColumnInfo ) -> "ConcreteColumnInfo" :
67+ def from_column_info (
68+ field : Field , type_hint : Any , * args : ColumnInfo
69+ ) -> "ConcreteColumnInfo" :
6670 info = functools .reduce (ColumnInfo .merge , args , ColumnInfo ())
6771 if info .create_type is None and info .type is not None :
6872 info .create_type = info .type
6973 info .type = sql_create_type_map .get (info .type .upper (), info .type )
7074 if type (info .type ) is not str or type (info .create_type ) is not str :
71- raise ValueError (f"Missing SQL type for column { name !r} " )
75+ raise ValueError (f"Missing SQL type for column { field . name !r} " )
7276 return ConcreteColumnInfo (
77+ field = field ,
78+ type_hint = type_hint ,
7379 type = info .type ,
7480 create_type = info .create_type ,
7581 nullable = bool (info .nullable ),
@@ -108,7 +114,7 @@ def split_nullable(typ: type) -> tuple[bool, type]:
108114}
109115
110116
111- sql_type_map : dict [Any , str ] = {
117+ sql_type_map : dict [type , str ] = {
112118 bool : "BOOLEAN" ,
113119 bytes : "BYTEA" ,
114120 datetime .date : "DATE" ,
@@ -125,12 +131,11 @@ def split_nullable(typ: type) -> tuple[bool, type]:
125131
126132
127133class ModelBase :
128- _column_info : Optional [ dict [str , ConcreteColumnInfo ] ]
134+ _column_info : dict [str , ConcreteColumnInfo ]
129135 _cache : dict [tuple , Any ]
130136 table_name : str
131137 primary_key_names : tuple [str , ...]
132138 array_safe_insert : bool
133- _type_hints : dict [str , type ]
134139
135140 def __init_subclass__ (
136141 cls ,
@@ -153,13 +158,6 @@ def __init_subclass__(
153158 else :
154159 cls .primary_key_names = tuple (primary_key )
155160
156- @classmethod
157- def _fields (cls ):
158- # wrapper to ignore typing weirdness: 'Argument 1 to "fields"
159- # has incompatible type "..."; expected "DataclassInstance |
160- # type[DataclassInstance]"'
161- return fields (cls ) # type: ignore
162-
163161 @classmethod
164162 def _cached (cls , key : tuple , thunk : Callable [[], U ]) -> U :
165163 try :
@@ -169,38 +167,31 @@ def _cached(cls, key: tuple, thunk: Callable[[], U]) -> U:
169167 return cls ._cache [key ]
170168
171169 @classmethod
172- def type_hints (cls ) -> dict [str , type ]:
173- try :
174- return cls ._type_hints
175- except AttributeError :
176- cls ._type_hints = get_type_hints (cls , include_extras = True )
177- return cls ._type_hints
178-
179- @classmethod
180- def column_info_for_field (cls , field : Field ) -> ConcreteColumnInfo :
181- type_info = cls .type_hints ()[field .name ]
182- base_type = type_info
170+ def column_info_for_field (cls , field : Field , type_hint : type ) -> ConcreteColumnInfo :
171+ base_type = type_hint
183172 metadata = []
184- if get_origin (type_info ) is Annotated :
185- base_type , * metadata = get_args (type_info )
173+ if get_origin (type_hint ) is Annotated :
174+ base_type , * metadata = get_args (type_hint )
186175 nullable , base_type = split_nullable (base_type )
187176 info = [ColumnInfo (nullable = nullable )]
188177 if base_type in sql_type_map :
189178 info .append (ColumnInfo (type = sql_type_map [base_type ]))
190179 for md in metadata :
191180 if isinstance (md , ColumnInfo ):
192181 info .append (md )
193- return ConcreteColumnInfo .from_column_info (field . name , * info )
182+ return ConcreteColumnInfo .from_column_info (field , type_hint , * info )
194183
195184 @classmethod
196- def column_info (cls , column : str ) -> ConcreteColumnInfo :
185+ def column_info (cls ) -> dict [ str , ConcreteColumnInfo ] :
197186 try :
198- return cls ._column_info [ column ] # type: ignore
187+ return cls ._column_info
199188 except AttributeError :
189+ type_hints = get_type_hints (cls , include_extras = True )
200190 cls ._column_info = {
201- f .name : cls .column_info_for_field (f ) for f in cls ._fields ()
191+ f .name : cls .column_info_for_field (f , type_hints [f .name ])
192+ for f in fields (cls ) # type: ignore
202193 }
203- return cls ._column_info [ column ]
194+ return cls ._column_info
204195
205196 @classmethod
206197 def table_name_sql (cls , * , prefix : Optional [str ] = None ) -> Fragment :
@@ -212,7 +203,11 @@ def primary_key_names_sql(cls, *, prefix: Optional[str] = None) -> list[Fragment
212203
213204 @classmethod
214205 def field_names (cls , * , exclude : FieldNamesSet = ()) -> list [str ]:
215- return [f .name for f in cls ._fields () if f .name not in exclude ]
206+ return [
207+ ci .field .name
208+ for ci in cls .column_info ().values ()
209+ if ci .field .name not in exclude
210+ ]
216211
217212 @classmethod
218213 def field_names_sql (
@@ -231,9 +226,9 @@ def _get_field_values_fn(
231226 ) -> Callable [[T ], list [Any ]]:
232227 env : dict [str , Any ] = {}
233228 func = ["def get_field_values(self): return [" ]
234- for f in cls ._fields ():
235- if f .name not in exclude :
236- func .append (f"self.{ f .name } ," )
229+ for ci in cls .column_info (). values ():
230+ if ci . field .name not in exclude :
231+ func .append (f"self.{ ci . field .name } ," )
237232 func += ["]" ]
238233 exec (" " .join (func ), env )
239234 return env ["get_field_values" ]
@@ -256,19 +251,15 @@ def field_values_sql(
256251 else :
257252 return [sql .value (value ) for value in self .field_values ()]
258253
259- @classmethod
260- def from_tuple (
261- cls : type [T ], tup : tuple , * , offset : int = 0 , exclude : FieldNamesSet = ()
262- ) -> T :
263- names = (f .name for f in cls ._fields () if f .name not in exclude )
264- kwargs = {name : tup [offset ] for offset , name in enumerate (names , start = offset )}
265- return cls (** kwargs )
266-
267254 @classmethod
268255 def from_dict (
269256 cls : type [T ], dct : dict [str , Any ], * , exclude : FieldNamesSet = ()
270257 ) -> T :
271- names = {f .name for f in cls ._fields () if f .name not in exclude }
258+ names = {
259+ ci .field .name
260+ for ci in cls .column_info ().values ()
261+ if ci .field .name not in exclude
262+ }
272263 kwargs = {k : v for k , v in dct .items () if k in names }
273264 return cls (** kwargs )
274265
@@ -283,10 +274,10 @@ def create_table_sql(cls) -> Fragment:
283274 entries = [
284275 sql (
285276 "{} {}" ,
286- sql .identifier (f .name ),
287- sql .literal (cls . column_info ( f . name ) .create_table_string ()),
277+ sql .identifier (ci . field .name ),
278+ sql .literal (ci .create_table_string ()),
288279 )
289- for f in cls ._fields ()
280+ for ci in cls .column_info (). values ()
290281 ]
291282 if cls .primary_key_names :
292283 entries += [sql ("PRIMARY KEY ({})" , sql .list (cls .primary_key_names_sql ()))]
@@ -428,10 +419,11 @@ def delete_multiple_sql(cls: type[T], rows: Iterable[T]) -> Fragment:
428419 pks = sql .list (sql .identifier (pk ) for pk in cls .primary_key_names ),
429420 ).compile (),
430421 )
422+ column_info = cls .column_info ()
431423 return cached (
432424 unnest = sql .unnest (
433425 (row .primary_key () for row in rows ),
434- (cls . column_info ( pk ) .type for pk in cls .primary_key_names ),
426+ (column_info [ pk ] .type for pk in cls .primary_key_names ),
435427 ),
436428 )
437429
@@ -451,10 +443,11 @@ def insert_multiple_sql(cls: type[T], rows: Iterable[T]) -> Fragment:
451443 fields = sql .list (cls .field_names_sql ()),
452444 ).compile (),
453445 )
446+ column_info = cls .column_info ()
454447 return cached (
455448 unnest = sql .unnest (
456449 (row .field_values () for row in rows ),
457- (cls . column_info ( name ) .type for name in cls .field_names ()),
450+ (column_info [ name ] .type for name in cls .field_names ()),
458451 ),
459452 )
460453
@@ -545,9 +538,9 @@ def _get_equal_ignoring_fn(
545538 ) -> Callable [[T , T ], bool ]:
546539 env : dict [str , Any ] = {}
547540 func = ["def equal_ignoring(a, b):" ]
548- for f in cls ._fields ():
549- if f .name not in ignore :
550- func .append (f" if a.{ f . name } != b.{ f .name } : return False" )
541+ for ci in cls .column_info (). values ():
542+ if ci . field .name not in ignore :
543+ func .append (f" if a.{ ci . field . name } != b.{ ci . field .name } : return False" )
551544 func += [" return True" ]
552545 exec ("\n " .join (func ), env )
553546 return env ["equal_ignoring" ]
@@ -603,9 +596,11 @@ def _get_differences_ignoring_fn(
603596 "def differences_ignoring(a, b):" ,
604597 " diffs = []" ,
605598 ]
606- for f in cls ._fields ():
607- if f .name not in ignore :
608- func .append (f" if a.{ f .name } != b.{ f .name } : diffs.append({ f .name !r} )" )
599+ for ci in cls .column_info ().values ():
600+ if ci .field .name not in ignore :
601+ func .append (
602+ f" if a.{ ci .field .name } != b.{ ci .field .name } : diffs.append({ ci .field .name !r} )"
603+ )
609604 func += [" return diffs" ]
610605 exec ("\n " .join (func ), env )
611606 return env ["differences_ignoring" ]
0 commit comments