@@ -38,6 +38,9 @@ class ColumnInfo:
3838 _constraints : tuple [str , ...] = ()
3939 constraints : InitVar [Union [str , Iterable [str ], None ]] = None
4040
41+ serialize : Optional [Callable [[Any ], Any ]] = None
42+ deserialize : Optional [Callable [[Any ], Any ]] = None
43+
4144 def __post_init__ (self , constraints : Union [str , Iterable [str ], None ]) -> None :
4245 if constraints is not None :
4346 if type (constraints ) is str :
@@ -51,6 +54,8 @@ def merge(a: "ColumnInfo", b: "ColumnInfo") -> "ColumnInfo":
5154 create_type = b .create_type if b .create_type is not None else a .create_type ,
5255 nullable = b .nullable if b .nullable is not None else a .nullable ,
5356 _constraints = (* a ._constraints , * b ._constraints ),
57+ serialize = b .serialize if b .serialize is not None else a .serialize ,
58+ deserialize = b .deserialize if b .deserialize is not None else a .deserialize ,
5459 )
5560
5661
@@ -62,6 +67,8 @@ class ConcreteColumnInfo:
6267 create_type : str
6368 nullable : bool
6469 constraints : tuple [str , ...]
70+ serialize : Optional [Callable [[Any ], Any ]] = None
71+ deserialize : Optional [Callable [[Any ], Any ]] = None
6572
6673 @staticmethod
6774 def from_column_info (
@@ -80,6 +87,8 @@ def from_column_info(
8087 create_type = info .create_type ,
8188 nullable = bool (info .nullable ),
8289 constraints = info ._constraints ,
90+ serialize = info .serialize ,
91+ deserialize = info .deserialize ,
8392 )
8493
8594 def create_table_string (self ) -> str :
@@ -90,6 +99,11 @@ def create_table_string(self) -> str:
9099 )
91100 return " " .join (parts )
92101
102+ def maybe_serialize (self , value : Any ) -> Any :
103+ if self .serialize :
104+ return self .serialize (value )
105+ return value
106+
93107
94108NULLABLE_TYPES = (type (None ), Any , object )
95109
@@ -228,7 +242,11 @@ def _get_field_values_fn(
228242 func = ["def get_field_values(self): return [" ]
229243 for ci in cls .column_info ().values ():
230244 if ci .field .name not in exclude :
231- func .append (f"self.{ ci .field .name } ," )
245+ if ci .serialize :
246+ env [f"_ser_{ ci .field .name } " ] = ci .serialize
247+ func .append (f"_ser_{ ci .field .name } (self.{ ci .field .name } ), " )
248+ else :
249+ func .append (f"self.{ ci .field .name } ," )
232250 func += ["]" ]
233251 exec (" " .join (func ), env )
234252 return env ["get_field_values" ]
@@ -252,22 +270,36 @@ def field_values_sql(
252270 return [sql .value (value ) for value in self .field_values ()]
253271
254272 @classmethod
255- def from_dict (
256- cls : type [T ], dct : dict [str , Any ], * , exclude : FieldNamesSet = ()
257- ) -> T :
258- names = {
259- ci .field .name
260- for ci in cls .column_info ().values ()
261- if ci .field .name not in exclude
262- }
263- kwargs = {k : v for k , v in dct .items () if k in names }
264- return cls (** kwargs )
273+ def _get_from_mapping_fn (cls : type [T ]) -> Callable [[Mapping [str , Any ]], T ]:
274+ env : dict [str , Any ] = {"cls" : cls }
275+ func = ["def from_mapping(mapping):" ]
276+ if not any (ci .deserialize for ci in cls .column_info ().values ()):
277+ func .append (" return cls(**mapping)" )
278+ else :
279+ func .append (" deser_dict = dict(mapping)" )
280+ for ci in cls .column_info ().values ():
281+ if ci .deserialize :
282+ env [f"_deser_{ ci .field .name } " ] = ci .deserialize
283+ func .append (f" if { ci .field .name !r} in deser_dict:" )
284+ func .append (
285+ f" deser_dict[{ ci .field .name !r} ] = _deser_{ ci .field .name } (deser_dict[{ ci .field .name !r} ])"
286+ )
287+ func .append (" return cls(**deser_dict)" )
288+ exec ("\n " .join (func ), env )
289+ return env ["from_mapping" ]
290+
291+ @classmethod
292+ def from_mapping (cls : type [T ], mapping : Mapping [str , Any ], / ) -> T :
293+ # KLUDGE nasty but... efficient?
294+ from_mapping_fn = cls ._get_from_mapping_fn ()
295+ cls .from_mapping = from_mapping_fn # type: ignore
296+ return from_mapping_fn (mapping )
265297
266298 @classmethod
267299 def ensure_model (cls : type [T ], row : Union [T , Mapping [str , Any ]]) -> T :
268300 if isinstance (row , cls ):
269301 return row
270- return cls ( ** row )
302+ return cls . from_mapping ( row ) # type: ignore
271303
272304 @classmethod
273305 def create_table_sql (cls ) -> Fragment :
@@ -329,7 +361,7 @@ async def select_cursor(
329361 * cls .select_sql (order_by = order_by , for_update = for_update , where = where ),
330362 prefetch = prefetch ,
331363 ):
332- yield cls ( ** row )
364+ yield cls . from_mapping ( row )
333365
334366 @classmethod
335367 async def select (
@@ -340,19 +372,22 @@ async def select(
340372 where : Where = (),
341373 ) -> list [T ]:
342374 return [
343- cls ( ** row )
375+ cls . from_mapping ( row )
344376 for row in await connection_or_pool .fetch (
345377 * cls .select_sql (order_by = order_by , for_update = for_update , where = where )
346378 )
347379 ]
348380
349381 @classmethod
350382 def create_sql (cls : type [T ], ** kwargs : Any ) -> Fragment :
383+ column_info = cls .column_info ()
351384 return sql (
352385 "INSERT INTO {table} ({fields}) VALUES ({values}) RETURNING {out_fields}" ,
353386 table = cls .table_name_sql (),
354- fields = sql .list (sql .identifier (x ) for x in kwargs .keys ()),
355- values = sql .list (sql .value (x ) for x in kwargs .values ()),
387+ fields = sql .list (sql .identifier (k ) for k in kwargs .keys ()),
388+ values = sql .list (
389+ sql .value (column_info [k ].maybe_serialize (v )) for k , v in kwargs .items ()
390+ ),
356391 out_fields = sql .list (cls .field_names_sql ()),
357392 )
358393
@@ -361,7 +396,7 @@ async def create(
361396 cls : type [T ], connection_or_pool : Union [Connection , Pool ], ** kwargs : Any
362397 ) -> T :
363398 row = await connection_or_pool .fetchrow (* cls .create_sql (** kwargs ))
364- return cls ( ** row )
399+ return cls . from_mapping ( row )
365400
366401 def insert_sql (self , exclude : FieldNamesSet = ()) -> Fragment :
367402 cached = self ._cached (
0 commit comments