11import datetime
22import functools
3+ import sys
34import uuid
45from collections .abc import AsyncGenerator , Iterable , Mapping
56from dataclasses import Field , InitVar , dataclass , fields
@@ -34,10 +35,13 @@ class ColumnInfo:
3435 type : Optional [str ] = None
3536 create_type : Optional [str ] = None
3637 nullable : Optional [bool ] = None
37- _constraints : tuple [str , ...] = ()
3838
39+ _constraints : tuple [str , ...] = ()
3940 constraints : InitVar [Union [str , Iterable [str ], None ]] = None
4041
42+ serialize : Optional [Callable [[Any ], Any ]] = None
43+ deserialize : Optional [Callable [[Any ], Any ]] = None
44+
4145 def __post_init__ (self , constraints : Union [str , Iterable [str ], None ]) -> None :
4246 if constraints is not None :
4347 if type (constraints ) is str :
@@ -51,29 +55,41 @@ def merge(a: "ColumnInfo", b: "ColumnInfo") -> "ColumnInfo":
5155 create_type = b .create_type if b .create_type is not None else a .create_type ,
5256 nullable = b .nullable if b .nullable is not None else a .nullable ,
5357 _constraints = (* a ._constraints , * b ._constraints ),
58+ serialize = b .serialize if b .serialize is not None else a .serialize ,
59+ deserialize = b .deserialize if b .deserialize is not None else a .deserialize ,
5460 )
5561
5662
5763@dataclass
5864class ConcreteColumnInfo :
65+ field : Field
66+ type_hint : type
5967 type : str
6068 create_type : str
6169 nullable : bool
6270 constraints : tuple [str , ...]
71+ serialize : Optional [Callable [[Any ], Any ]] = None
72+ deserialize : Optional [Callable [[Any ], Any ]] = None
6373
6474 @staticmethod
65- def from_column_info (name : str , * args : ColumnInfo ) -> "ConcreteColumnInfo" :
75+ def from_column_info (
76+ field : Field , type_hint : Any , * args : ColumnInfo
77+ ) -> "ConcreteColumnInfo" :
6678 info = functools .reduce (ColumnInfo .merge , args , ColumnInfo ())
6779 if info .create_type is None and info .type is not None :
6880 info .create_type = info .type
6981 info .type = sql_create_type_map .get (info .type .upper (), info .type )
7082 if type (info .type ) is not str or type (info .create_type ) is not str :
71- raise ValueError (f"Missing SQL type for column { name !r} " )
83+ raise ValueError (f"Missing SQL type for column { field . name !r} " )
7284 return ConcreteColumnInfo (
85+ field = field ,
86+ type_hint = type_hint ,
7387 type = info .type ,
7488 create_type = info .create_type ,
7589 nullable = bool (info .nullable ),
7690 constraints = info ._constraints ,
91+ serialize = info .serialize ,
92+ deserialize = info .deserialize ,
7793 )
7894
7995 def create_table_string (self ) -> str :
@@ -84,13 +100,24 @@ def create_table_string(self) -> str:
84100 )
85101 return " " .join (parts )
86102
103+ def maybe_serialize (self , value : Any ) -> Any :
104+ if self .serialize :
105+ return self .serialize (value )
106+ return value
107+
108+
109+ UNION_TYPES : tuple = (Union ,)
110+ if sys .version_info >= (3 , 10 ):
111+ from types import UnionType
112+
113+ UNION_TYPES = (Union , UnionType )
87114
88115NULLABLE_TYPES = (type (None ), Any , object )
89116
90117
91118def split_nullable (typ : type ) -> tuple [bool , type ]:
92119 nullable = typ in NULLABLE_TYPES
93- if get_origin (typ ) is Union :
120+ if get_origin (typ ) in UNION_TYPES :
94121 args = []
95122 for arg in get_args (typ ):
96123 if arg in NULLABLE_TYPES :
@@ -108,7 +135,7 @@ def split_nullable(typ: type) -> tuple[bool, type]:
108135}
109136
110137
111- sql_type_map : dict [Any , str ] = {
138+ sql_type_map : dict [type , str ] = {
112139 bool : "BOOLEAN" ,
113140 bytes : "BYTEA" ,
114141 datetime .date : "DATE" ,
@@ -125,12 +152,11 @@ def split_nullable(typ: type) -> tuple[bool, type]:
125152
126153
127154class ModelBase :
128- _column_info : Optional [ dict [str , ConcreteColumnInfo ] ]
155+ _column_info : dict [str , ConcreteColumnInfo ]
129156 _cache : dict [tuple , Any ]
130157 table_name : str
131158 primary_key_names : tuple [str , ...]
132159 array_safe_insert : bool
133- _type_hints : dict [str , type ]
134160
135161 def __init_subclass__ (
136162 cls ,
@@ -153,13 +179,6 @@ def __init_subclass__(
153179 else :
154180 cls .primary_key_names = tuple (primary_key )
155181
156- @classmethod
157- def _fields (cls ):
158- # wrapper to ignore typing weirdness: 'Argument 1 to "fields"
159- # has incompatible type "..."; expected "DataclassInstance |
160- # type[DataclassInstance]"'
161- return fields (cls ) # type: ignore
162-
163182 @classmethod
164183 def _cached (cls , key : tuple , thunk : Callable [[], U ]) -> U :
165184 try :
@@ -169,38 +188,31 @@ def _cached(cls, key: tuple, thunk: Callable[[], U]) -> U:
169188 return cls ._cache [key ]
170189
171190 @classmethod
172- def type_hints (cls ) -> dict [str , type ]:
173- try :
174- return cls ._type_hints
175- except AttributeError :
176- cls ._type_hints = get_type_hints (cls , include_extras = True )
177- return cls ._type_hints
178-
179- @classmethod
180- def column_info_for_field (cls , field : Field ) -> ConcreteColumnInfo :
181- type_info = cls .type_hints ()[field .name ]
182- base_type = type_info
191+ def column_info_for_field (cls , field : Field , type_hint : type ) -> ConcreteColumnInfo :
192+ base_type = type_hint
183193 metadata = []
184- if get_origin (type_info ) is Annotated :
185- base_type , * metadata = get_args (type_info )
194+ if get_origin (type_hint ) is Annotated :
195+ base_type , * metadata = get_args (type_hint )
186196 nullable , base_type = split_nullable (base_type )
187197 info = [ColumnInfo (nullable = nullable )]
188198 if base_type in sql_type_map :
189199 info .append (ColumnInfo (type = sql_type_map [base_type ]))
190200 for md in metadata :
191201 if isinstance (md , ColumnInfo ):
192202 info .append (md )
193- return ConcreteColumnInfo .from_column_info (field . name , * info )
203+ return ConcreteColumnInfo .from_column_info (field , type_hint , * info )
194204
195205 @classmethod
196- def column_info (cls , column : str ) -> ConcreteColumnInfo :
206+ def column_info (cls ) -> dict [ str , ConcreteColumnInfo ] :
197207 try :
198- return cls ._column_info [ column ] # type: ignore
208+ return cls ._column_info
199209 except AttributeError :
210+ type_hints = get_type_hints (cls , include_extras = True )
200211 cls ._column_info = {
201- f .name : cls .column_info_for_field (f ) for f in cls ._fields ()
212+ f .name : cls .column_info_for_field (f , type_hints [f .name ])
213+ for f in fields (cls ) # type: ignore
202214 }
203- return cls ._column_info [ column ]
215+ return cls ._column_info
204216
205217 @classmethod
206218 def table_name_sql (cls , * , prefix : Optional [str ] = None ) -> Fragment :
@@ -212,7 +224,11 @@ def primary_key_names_sql(cls, *, prefix: Optional[str] = None) -> list[Fragment
212224
213225 @classmethod
214226 def field_names (cls , * , exclude : FieldNamesSet = ()) -> list [str ]:
215- return [f .name for f in cls ._fields () if f .name not in exclude ]
227+ return [
228+ ci .field .name
229+ for ci in cls .column_info ().values ()
230+ if ci .field .name not in exclude
231+ ]
216232
217233 @classmethod
218234 def field_names_sql (
@@ -231,9 +247,13 @@ def _get_field_values_fn(
231247 ) -> Callable [[T ], list [Any ]]:
232248 env : dict [str , Any ] = {}
233249 func = ["def get_field_values(self): return [" ]
234- for f in cls ._fields ():
235- if f .name not in exclude :
236- func .append (f"self.{ f .name } ," )
250+ for ci in cls .column_info ().values ():
251+ if ci .field .name not in exclude :
252+ if ci .serialize :
253+ env [f"_ser_{ ci .field .name } " ] = ci .serialize
254+ func .append (f"_ser_{ ci .field .name } (self.{ ci .field .name } ), " )
255+ else :
256+ func .append (f"self.{ ci .field .name } ," )
237257 func += ["]" ]
238258 exec (" " .join (func ), env )
239259 return env ["get_field_values" ]
@@ -257,36 +277,46 @@ def field_values_sql(
257277 return [sql .value (value ) for value in self .field_values ()]
258278
259279 @classmethod
260- def from_tuple (
261- cls : type [T ], tup : tuple , * , offset : int = 0 , exclude : FieldNamesSet = ()
262- ) -> T :
263- names = (f .name for f in cls ._fields () if f .name not in exclude )
264- kwargs = {name : tup [offset ] for offset , name in enumerate (names , start = offset )}
265- return cls (** kwargs )
280+ def _get_from_mapping_fn (cls : type [T ]) -> Callable [[Mapping [str , Any ]], T ]:
281+ env : dict [str , Any ] = {"cls" : cls }
282+ func = ["def from_mapping(mapping):" ]
283+ if not any (ci .deserialize for ci in cls .column_info ().values ()):
284+ func .append (" return cls(**mapping)" )
285+ else :
286+ func .append (" deser_dict = dict(mapping)" )
287+ for ci in cls .column_info ().values ():
288+ if ci .deserialize :
289+ env [f"_deser_{ ci .field .name } " ] = ci .deserialize
290+ func .append (f" if { ci .field .name !r} in deser_dict:" )
291+ func .append (
292+ f" deser_dict[{ ci .field .name !r} ] = _deser_{ ci .field .name } (deser_dict[{ ci .field .name !r} ])"
293+ )
294+ func .append (" return cls(**deser_dict)" )
295+ exec ("\n " .join (func ), env )
296+ return env ["from_mapping" ]
266297
267298 @classmethod
268- def from_dict (
269- cls : type [T ], dct : dict [str , Any ], * , exclude : FieldNamesSet = ()
270- ) -> T :
271- names = {f .name for f in cls ._fields () if f .name not in exclude }
272- kwargs = {k : v for k , v in dct .items () if k in names }
273- return cls (** kwargs )
299+ def from_mapping (cls : type [T ], mapping : Mapping [str , Any ], / ) -> T :
300+ # KLUDGE nasty but... efficient?
301+ from_mapping_fn = cls ._get_from_mapping_fn ()
302+ cls .from_mapping = from_mapping_fn # type: ignore
303+ return from_mapping_fn (mapping )
274304
275305 @classmethod
276306 def ensure_model (cls : type [T ], row : Union [T , Mapping [str , Any ]]) -> T :
277307 if isinstance (row , cls ):
278308 return row
279- return cls ( ** row )
309+ return cls . from_mapping ( row ) # type: ignore
280310
281311 @classmethod
282312 def create_table_sql (cls ) -> Fragment :
283313 entries = [
284314 sql (
285315 "{} {}" ,
286- sql .identifier (f .name ),
287- sql .literal (cls . column_info ( f . name ) .create_table_string ()),
316+ sql .identifier (ci . field .name ),
317+ sql .literal (ci .create_table_string ()),
288318 )
289- for f in cls ._fields ()
319+ for ci in cls .column_info (). values ()
290320 ]
291321 if cls .primary_key_names :
292322 entries += [sql ("PRIMARY KEY ({})" , sql .list (cls .primary_key_names_sql ()))]
@@ -338,7 +368,7 @@ async def select_cursor(
338368 * cls .select_sql (order_by = order_by , for_update = for_update , where = where ),
339369 prefetch = prefetch ,
340370 ):
341- yield cls ( ** row )
371+ yield cls . from_mapping ( row )
342372
343373 @classmethod
344374 async def select (
@@ -349,19 +379,22 @@ async def select(
349379 where : Where = (),
350380 ) -> list [T ]:
351381 return [
352- cls ( ** row )
382+ cls . from_mapping ( row )
353383 for row in await connection_or_pool .fetch (
354384 * cls .select_sql (order_by = order_by , for_update = for_update , where = where )
355385 )
356386 ]
357387
358388 @classmethod
359389 def create_sql (cls : type [T ], ** kwargs : Any ) -> Fragment :
390+ column_info = cls .column_info ()
360391 return sql (
361392 "INSERT INTO {table} ({fields}) VALUES ({values}) RETURNING {out_fields}" ,
362393 table = cls .table_name_sql (),
363- fields = sql .list (sql .identifier (x ) for x in kwargs .keys ()),
364- values = sql .list (sql .value (x ) for x in kwargs .values ()),
394+ fields = sql .list (sql .identifier (k ) for k in kwargs .keys ()),
395+ values = sql .list (
396+ sql .value (column_info [k ].maybe_serialize (v )) for k , v in kwargs .items ()
397+ ),
365398 out_fields = sql .list (cls .field_names_sql ()),
366399 )
367400
@@ -370,7 +403,7 @@ async def create(
370403 cls : type [T ], connection_or_pool : Union [Connection , Pool ], ** kwargs : Any
371404 ) -> T :
372405 row = await connection_or_pool .fetchrow (* cls .create_sql (** kwargs ))
373- return cls ( ** row )
406+ return cls . from_mapping ( row )
374407
375408 def insert_sql (self , exclude : FieldNamesSet = ()) -> Fragment :
376409 cached = self ._cached (
@@ -428,10 +461,11 @@ def delete_multiple_sql(cls: type[T], rows: Iterable[T]) -> Fragment:
428461 pks = sql .list (sql .identifier (pk ) for pk in cls .primary_key_names ),
429462 ).compile (),
430463 )
464+ column_info = cls .column_info ()
431465 return cached (
432466 unnest = sql .unnest (
433467 (row .primary_key () for row in rows ),
434- (cls . column_info ( pk ) .type for pk in cls .primary_key_names ),
468+ (column_info [ pk ] .type for pk in cls .primary_key_names ),
435469 ),
436470 )
437471
@@ -451,10 +485,11 @@ def insert_multiple_sql(cls: type[T], rows: Iterable[T]) -> Fragment:
451485 fields = sql .list (cls .field_names_sql ()),
452486 ).compile (),
453487 )
488+ column_info = cls .column_info ()
454489 return cached (
455490 unnest = sql .unnest (
456491 (row .field_values () for row in rows ),
457- (cls . column_info ( name ) .type for name in cls .field_names ()),
492+ (column_info [ name ] .type for name in cls .field_names ()),
458493 ),
459494 )
460495
@@ -545,9 +580,9 @@ def _get_equal_ignoring_fn(
545580 ) -> Callable [[T , T ], bool ]:
546581 env : dict [str , Any ] = {}
547582 func = ["def equal_ignoring(a, b):" ]
548- for f in cls ._fields ():
549- if f .name not in ignore :
550- func .append (f" if a.{ f . name } != b.{ f .name } : return False" )
583+ for ci in cls .column_info (). values ():
584+ if ci . field .name not in ignore :
585+ func .append (f" if a.{ ci . field . name } != b.{ ci . field .name } : return False" )
551586 func += [" return True" ]
552587 exec ("\n " .join (func ), env )
553588 return env ["equal_ignoring" ]
@@ -603,9 +638,11 @@ def _get_differences_ignoring_fn(
603638 "def differences_ignoring(a, b):" ,
604639 " diffs = []" ,
605640 ]
606- for f in cls ._fields ():
607- if f .name not in ignore :
608- func .append (f" if a.{ f .name } != b.{ f .name } : diffs.append({ f .name !r} )" )
641+ for ci in cls .column_info ().values ():
642+ if ci .field .name not in ignore :
643+ func .append (
644+ f" if a.{ ci .field .name } != b.{ ci .field .name } : diffs.append({ ci .field .name !r} )"
645+ )
609646 func += [" return diffs" ]
610647 exec ("\n " .join (func ), env )
611648 return env ["differences_ignoring" ]
0 commit comments