Skip to content

Commit 1136b74

Browse files
committed
Serialize/deserialize for sql-athame dataclasses
1 parent a467b63 commit 1136b74

2 files changed

Lines changed: 72 additions & 17 deletions

File tree

sql_athame/dataclasses.py

Lines changed: 52 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,9 @@ class ColumnInfo:
3838
_constraints: tuple[str, ...] = ()
3939
constraints: InitVar[Union[str, Iterable[str], None]] = None
4040

41+
serialize: Optional[Callable[[Any], Any]] = None
42+
deserialize: Optional[Callable[[Any], Any]] = None
43+
4144
def __post_init__(self, constraints: Union[str, Iterable[str], None]) -> None:
4245
if constraints is not None:
4346
if type(constraints) is str:
@@ -51,6 +54,8 @@ def merge(a: "ColumnInfo", b: "ColumnInfo") -> "ColumnInfo":
5154
create_type=b.create_type if b.create_type is not None else a.create_type,
5255
nullable=b.nullable if b.nullable is not None else a.nullable,
5356
_constraints=(*a._constraints, *b._constraints),
57+
serialize=b.serialize if b.serialize is not None else a.serialize,
58+
deserialize=b.deserialize if b.deserialize is not None else a.deserialize,
5459
)
5560

5661

@@ -62,6 +67,8 @@ class ConcreteColumnInfo:
6267
create_type: str
6368
nullable: bool
6469
constraints: tuple[str, ...]
70+
serialize: Optional[Callable[[Any], Any]] = None
71+
deserialize: Optional[Callable[[Any], Any]] = None
6572

6673
@staticmethod
6774
def from_column_info(
@@ -80,6 +87,8 @@ def from_column_info(
8087
create_type=info.create_type,
8188
nullable=bool(info.nullable),
8289
constraints=info._constraints,
90+
serialize=info.serialize,
91+
deserialize=info.deserialize,
8392
)
8493

8594
def create_table_string(self) -> str:
@@ -90,6 +99,11 @@ def create_table_string(self) -> str:
9099
)
91100
return " ".join(parts)
92101

102+
def maybe_serialize(self, value: Any) -> Any:
103+
if self.serialize:
104+
return self.serialize(value)
105+
return value
106+
93107

94108
NULLABLE_TYPES = (type(None), Any, object)
95109

@@ -228,7 +242,11 @@ def _get_field_values_fn(
228242
func = ["def get_field_values(self): return ["]
229243
for ci in cls.column_info().values():
230244
if ci.field.name not in exclude:
231-
func.append(f"self.{ci.field.name},")
245+
if ci.serialize:
246+
env[f"_ser_{ci.field.name}"] = ci.serialize
247+
func.append(f"_ser_{ci.field.name}(self.{ci.field.name}), ")
248+
else:
249+
func.append(f"self.{ci.field.name},")
232250
func += ["]"]
233251
exec(" ".join(func), env)
234252
return env["get_field_values"]
@@ -252,22 +270,36 @@ def field_values_sql(
252270
return [sql.value(value) for value in self.field_values()]
253271

254272
@classmethod
255-
def from_dict(
256-
cls: type[T], dct: dict[str, Any], *, exclude: FieldNamesSet = ()
257-
) -> T:
258-
names = {
259-
ci.field.name
260-
for ci in cls.column_info().values()
261-
if ci.field.name not in exclude
262-
}
263-
kwargs = {k: v for k, v in dct.items() if k in names}
264-
return cls(**kwargs)
273+
def _get_from_mapping_fn(cls: type[T]) -> Callable[[Mapping[str, Any]], T]:
274+
env: dict[str, Any] = {"cls": cls}
275+
func = ["def from_mapping(mapping):"]
276+
if not any(ci.deserialize for ci in cls.column_info().values()):
277+
func.append(" return cls(**mapping)")
278+
else:
279+
func.append(" deser_dict = dict(mapping)")
280+
for ci in cls.column_info().values():
281+
if ci.deserialize:
282+
env[f"_deser_{ci.field.name}"] = ci.deserialize
283+
func.append(f" if {ci.field.name!r} in deser_dict:")
284+
func.append(
285+
f" deser_dict[{ci.field.name!r}] = _deser_{ci.field.name}(deser_dict[{ci.field.name!r}])"
286+
)
287+
func.append(" return cls(**deser_dict)")
288+
exec("\n".join(func), env)
289+
return env["from_mapping"]
290+
291+
@classmethod
292+
def from_mapping(cls: type[T], mapping: Mapping[str, Any], /) -> T:
293+
# KLUDGE nasty but... efficient?
294+
from_mapping_fn = cls._get_from_mapping_fn()
295+
cls.from_mapping = from_mapping_fn # type: ignore
296+
return from_mapping_fn(mapping)
265297

266298
@classmethod
267299
def ensure_model(cls: type[T], row: Union[T, Mapping[str, Any]]) -> T:
268300
if isinstance(row, cls):
269301
return row
270-
return cls(**row)
302+
return cls.from_mapping(row) # type: ignore
271303

272304
@classmethod
273305
def create_table_sql(cls) -> Fragment:
@@ -329,7 +361,7 @@ async def select_cursor(
329361
*cls.select_sql(order_by=order_by, for_update=for_update, where=where),
330362
prefetch=prefetch,
331363
):
332-
yield cls(**row)
364+
yield cls.from_mapping(row)
333365

334366
@classmethod
335367
async def select(
@@ -340,19 +372,22 @@ async def select(
340372
where: Where = (),
341373
) -> list[T]:
342374
return [
343-
cls(**row)
375+
cls.from_mapping(row)
344376
for row in await connection_or_pool.fetch(
345377
*cls.select_sql(order_by=order_by, for_update=for_update, where=where)
346378
)
347379
]
348380

349381
@classmethod
350382
def create_sql(cls: type[T], **kwargs: Any) -> Fragment:
383+
column_info = cls.column_info()
351384
return sql(
352385
"INSERT INTO {table} ({fields}) VALUES ({values}) RETURNING {out_fields}",
353386
table=cls.table_name_sql(),
354-
fields=sql.list(sql.identifier(x) for x in kwargs.keys()),
355-
values=sql.list(sql.value(x) for x in kwargs.values()),
387+
fields=sql.list(sql.identifier(k) for k in kwargs.keys()),
388+
values=sql.list(
389+
sql.value(column_info[k].maybe_serialize(v)) for k, v in kwargs.items()
390+
),
356391
out_fields=sql.list(cls.field_names_sql()),
357392
)
358393

@@ -361,7 +396,7 @@ async def create(
361396
cls: type[T], connection_or_pool: Union[Connection, Pool], **kwargs: Any
362397
) -> T:
363398
row = await connection_or_pool.fetchrow(*cls.create_sql(**kwargs))
364-
return cls(**row)
399+
return cls.from_mapping(row)
365400

366401
def insert_sql(self, exclude: FieldNamesSet = ()) -> Fragment:
367402
cached = self._cached(

tests/test_dataclasses.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,3 +152,23 @@ class Test(ModelBase, table_name="table", primary_key="id"):
152152
42,
153153
"foo",
154154
]
155+
156+
157+
def test_serde():
158+
@dataclass
159+
class Test(ModelBase, table_name="table"):
160+
foo: Annotated[
161+
str,
162+
ColumnInfo(serialize=lambda x: x.upper(), deserialize=lambda x: x.lower()),
163+
]
164+
bar: str
165+
166+
assert Test("foo", "bar").field_values() == ["FOO", "bar"]
167+
assert Test.create_sql(foo="foo", bar="bar").query() == (
168+
'INSERT INTO "table" ("foo", "bar") VALUES ($1, $2) RETURNING "foo", "bar"',
169+
["FOO", "bar"],
170+
)
171+
172+
assert Test.from_mapping({"foo": "FOO", "bar": "BAR"}) == Test("foo", "BAR")
173+
# make sure the monkey patching didn't screw things up
174+
assert Test.from_mapping({"foo": "FOO", "bar": "BAR"}) == Test("foo", "BAR")

0 commit comments

Comments
 (0)