Skip to content

Commit 503785e

Browse files
authored
Merge pull request #23 from bdowning/updates
Serialize/deserialize for sql-athame dataclasses
2 parents 69ce167 + 74312f7 commit 503785e

4 files changed

Lines changed: 148 additions & 69 deletions

File tree

.bumpversion.cfg

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
[bumpversion]
2-
current_version = 0.4.0-alpha-8
2+
current_version = 0.4.0-alpha-9
33
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(-(?P<release>.*)-(?P<build>\d+))?
44
serialize =
55
{major}.{minor}.{patch}-{release}-{build}

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "sql-athame"
3-
version = "0.4.0-alpha-8"
3+
version = "0.4.0-alpha-9"
44
description = "Python tool for slicing and dicing SQL"
55
authors = ["Brian Downing <bdowning@lavos.net>"]
66
license = "MIT"

sql_athame/dataclasses.py

Lines changed: 102 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import datetime
22
import functools
3+
import sys
34
import uuid
45
from collections.abc import AsyncGenerator, Iterable, Mapping
56
from dataclasses import Field, InitVar, dataclass, fields
@@ -34,10 +35,13 @@ class ColumnInfo:
3435
type: Optional[str] = None
3536
create_type: Optional[str] = None
3637
nullable: Optional[bool] = None
37-
_constraints: tuple[str, ...] = ()
3838

39+
_constraints: tuple[str, ...] = ()
3940
constraints: InitVar[Union[str, Iterable[str], None]] = None
4041

42+
serialize: Optional[Callable[[Any], Any]] = None
43+
deserialize: Optional[Callable[[Any], Any]] = None
44+
4145
def __post_init__(self, constraints: Union[str, Iterable[str], None]) -> None:
4246
if constraints is not None:
4347
if type(constraints) is str:
@@ -51,29 +55,41 @@ def merge(a: "ColumnInfo", b: "ColumnInfo") -> "ColumnInfo":
5155
create_type=b.create_type if b.create_type is not None else a.create_type,
5256
nullable=b.nullable if b.nullable is not None else a.nullable,
5357
_constraints=(*a._constraints, *b._constraints),
58+
serialize=b.serialize if b.serialize is not None else a.serialize,
59+
deserialize=b.deserialize if b.deserialize is not None else a.deserialize,
5460
)
5561

5662

5763
@dataclass
5864
class ConcreteColumnInfo:
65+
field: Field
66+
type_hint: type
5967
type: str
6068
create_type: str
6169
nullable: bool
6270
constraints: tuple[str, ...]
71+
serialize: Optional[Callable[[Any], Any]] = None
72+
deserialize: Optional[Callable[[Any], Any]] = None
6373

6474
@staticmethod
65-
def from_column_info(name: str, *args: ColumnInfo) -> "ConcreteColumnInfo":
75+
def from_column_info(
76+
field: Field, type_hint: Any, *args: ColumnInfo
77+
) -> "ConcreteColumnInfo":
6678
info = functools.reduce(ColumnInfo.merge, args, ColumnInfo())
6779
if info.create_type is None and info.type is not None:
6880
info.create_type = info.type
6981
info.type = sql_create_type_map.get(info.type.upper(), info.type)
7082
if type(info.type) is not str or type(info.create_type) is not str:
71-
raise ValueError(f"Missing SQL type for column {name!r}")
83+
raise ValueError(f"Missing SQL type for column {field.name!r}")
7284
return ConcreteColumnInfo(
85+
field=field,
86+
type_hint=type_hint,
7387
type=info.type,
7488
create_type=info.create_type,
7589
nullable=bool(info.nullable),
7690
constraints=info._constraints,
91+
serialize=info.serialize,
92+
deserialize=info.deserialize,
7793
)
7894

7995
def create_table_string(self) -> str:
@@ -84,13 +100,24 @@ def create_table_string(self) -> str:
84100
)
85101
return " ".join(parts)
86102

103+
def maybe_serialize(self, value: Any) -> Any:
104+
if self.serialize:
105+
return self.serialize(value)
106+
return value
107+
108+
109+
UNION_TYPES: tuple = (Union,)
110+
if sys.version_info >= (3, 10):
111+
from types import UnionType
112+
113+
UNION_TYPES = (Union, UnionType)
87114

88115
NULLABLE_TYPES = (type(None), Any, object)
89116

90117

91118
def split_nullable(typ: type) -> tuple[bool, type]:
92119
nullable = typ in NULLABLE_TYPES
93-
if get_origin(typ) is Union:
120+
if get_origin(typ) in UNION_TYPES:
94121
args = []
95122
for arg in get_args(typ):
96123
if arg in NULLABLE_TYPES:
@@ -108,7 +135,7 @@ def split_nullable(typ: type) -> tuple[bool, type]:
108135
}
109136

110137

111-
sql_type_map: dict[Any, str] = {
138+
sql_type_map: dict[type, str] = {
112139
bool: "BOOLEAN",
113140
bytes: "BYTEA",
114141
datetime.date: "DATE",
@@ -125,12 +152,11 @@ def split_nullable(typ: type) -> tuple[bool, type]:
125152

126153

127154
class ModelBase:
128-
_column_info: Optional[dict[str, ConcreteColumnInfo]]
155+
_column_info: dict[str, ConcreteColumnInfo]
129156
_cache: dict[tuple, Any]
130157
table_name: str
131158
primary_key_names: tuple[str, ...]
132159
array_safe_insert: bool
133-
_type_hints: dict[str, type]
134160

135161
def __init_subclass__(
136162
cls,
@@ -153,13 +179,6 @@ def __init_subclass__(
153179
else:
154180
cls.primary_key_names = tuple(primary_key)
155181

156-
@classmethod
157-
def _fields(cls):
158-
# wrapper to ignore typing weirdness: 'Argument 1 to "fields"
159-
# has incompatible type "..."; expected "DataclassInstance |
160-
# type[DataclassInstance]"'
161-
return fields(cls) # type: ignore
162-
163182
@classmethod
164183
def _cached(cls, key: tuple, thunk: Callable[[], U]) -> U:
165184
try:
@@ -169,38 +188,31 @@ def _cached(cls, key: tuple, thunk: Callable[[], U]) -> U:
169188
return cls._cache[key]
170189

171190
@classmethod
172-
def type_hints(cls) -> dict[str, type]:
173-
try:
174-
return cls._type_hints
175-
except AttributeError:
176-
cls._type_hints = get_type_hints(cls, include_extras=True)
177-
return cls._type_hints
178-
179-
@classmethod
180-
def column_info_for_field(cls, field: Field) -> ConcreteColumnInfo:
181-
type_info = cls.type_hints()[field.name]
182-
base_type = type_info
191+
def column_info_for_field(cls, field: Field, type_hint: type) -> ConcreteColumnInfo:
192+
base_type = type_hint
183193
metadata = []
184-
if get_origin(type_info) is Annotated:
185-
base_type, *metadata = get_args(type_info)
194+
if get_origin(type_hint) is Annotated:
195+
base_type, *metadata = get_args(type_hint)
186196
nullable, base_type = split_nullable(base_type)
187197
info = [ColumnInfo(nullable=nullable)]
188198
if base_type in sql_type_map:
189199
info.append(ColumnInfo(type=sql_type_map[base_type]))
190200
for md in metadata:
191201
if isinstance(md, ColumnInfo):
192202
info.append(md)
193-
return ConcreteColumnInfo.from_column_info(field.name, *info)
203+
return ConcreteColumnInfo.from_column_info(field, type_hint, *info)
194204

195205
@classmethod
196-
def column_info(cls, column: str) -> ConcreteColumnInfo:
206+
def column_info(cls) -> dict[str, ConcreteColumnInfo]:
197207
try:
198-
return cls._column_info[column] # type: ignore
208+
return cls._column_info
199209
except AttributeError:
210+
type_hints = get_type_hints(cls, include_extras=True)
200211
cls._column_info = {
201-
f.name: cls.column_info_for_field(f) for f in cls._fields()
212+
f.name: cls.column_info_for_field(f, type_hints[f.name])
213+
for f in fields(cls) # type: ignore
202214
}
203-
return cls._column_info[column]
215+
return cls._column_info
204216

205217
@classmethod
206218
def table_name_sql(cls, *, prefix: Optional[str] = None) -> Fragment:
@@ -212,7 +224,11 @@ def primary_key_names_sql(cls, *, prefix: Optional[str] = None) -> list[Fragment
212224

213225
@classmethod
214226
def field_names(cls, *, exclude: FieldNamesSet = ()) -> list[str]:
215-
return [f.name for f in cls._fields() if f.name not in exclude]
227+
return [
228+
ci.field.name
229+
for ci in cls.column_info().values()
230+
if ci.field.name not in exclude
231+
]
216232

217233
@classmethod
218234
def field_names_sql(
@@ -231,9 +247,13 @@ def _get_field_values_fn(
231247
) -> Callable[[T], list[Any]]:
232248
env: dict[str, Any] = {}
233249
func = ["def get_field_values(self): return ["]
234-
for f in cls._fields():
235-
if f.name not in exclude:
236-
func.append(f"self.{f.name},")
250+
for ci in cls.column_info().values():
251+
if ci.field.name not in exclude:
252+
if ci.serialize:
253+
env[f"_ser_{ci.field.name}"] = ci.serialize
254+
func.append(f"_ser_{ci.field.name}(self.{ci.field.name}), ")
255+
else:
256+
func.append(f"self.{ci.field.name},")
237257
func += ["]"]
238258
exec(" ".join(func), env)
239259
return env["get_field_values"]
@@ -257,36 +277,46 @@ def field_values_sql(
257277
return [sql.value(value) for value in self.field_values()]
258278

259279
@classmethod
260-
def from_tuple(
261-
cls: type[T], tup: tuple, *, offset: int = 0, exclude: FieldNamesSet = ()
262-
) -> T:
263-
names = (f.name for f in cls._fields() if f.name not in exclude)
264-
kwargs = {name: tup[offset] for offset, name in enumerate(names, start=offset)}
265-
return cls(**kwargs)
280+
def _get_from_mapping_fn(cls: type[T]) -> Callable[[Mapping[str, Any]], T]:
281+
env: dict[str, Any] = {"cls": cls}
282+
func = ["def from_mapping(mapping):"]
283+
if not any(ci.deserialize for ci in cls.column_info().values()):
284+
func.append(" return cls(**mapping)")
285+
else:
286+
func.append(" deser_dict = dict(mapping)")
287+
for ci in cls.column_info().values():
288+
if ci.deserialize:
289+
env[f"_deser_{ci.field.name}"] = ci.deserialize
290+
func.append(f" if {ci.field.name!r} in deser_dict:")
291+
func.append(
292+
f" deser_dict[{ci.field.name!r}] = _deser_{ci.field.name}(deser_dict[{ci.field.name!r}])"
293+
)
294+
func.append(" return cls(**deser_dict)")
295+
exec("\n".join(func), env)
296+
return env["from_mapping"]
266297

267298
@classmethod
268-
def from_dict(
269-
cls: type[T], dct: dict[str, Any], *, exclude: FieldNamesSet = ()
270-
) -> T:
271-
names = {f.name for f in cls._fields() if f.name not in exclude}
272-
kwargs = {k: v for k, v in dct.items() if k in names}
273-
return cls(**kwargs)
299+
def from_mapping(cls: type[T], mapping: Mapping[str, Any], /) -> T:
300+
# KLUDGE nasty but... efficient?
301+
from_mapping_fn = cls._get_from_mapping_fn()
302+
cls.from_mapping = from_mapping_fn # type: ignore
303+
return from_mapping_fn(mapping)
274304

275305
@classmethod
276306
def ensure_model(cls: type[T], row: Union[T, Mapping[str, Any]]) -> T:
277307
if isinstance(row, cls):
278308
return row
279-
return cls(**row)
309+
return cls.from_mapping(row) # type: ignore
280310

281311
@classmethod
282312
def create_table_sql(cls) -> Fragment:
283313
entries = [
284314
sql(
285315
"{} {}",
286-
sql.identifier(f.name),
287-
sql.literal(cls.column_info(f.name).create_table_string()),
316+
sql.identifier(ci.field.name),
317+
sql.literal(ci.create_table_string()),
288318
)
289-
for f in cls._fields()
319+
for ci in cls.column_info().values()
290320
]
291321
if cls.primary_key_names:
292322
entries += [sql("PRIMARY KEY ({})", sql.list(cls.primary_key_names_sql()))]
@@ -338,7 +368,7 @@ async def select_cursor(
338368
*cls.select_sql(order_by=order_by, for_update=for_update, where=where),
339369
prefetch=prefetch,
340370
):
341-
yield cls(**row)
371+
yield cls.from_mapping(row)
342372

343373
@classmethod
344374
async def select(
@@ -349,19 +379,22 @@ async def select(
349379
where: Where = (),
350380
) -> list[T]:
351381
return [
352-
cls(**row)
382+
cls.from_mapping(row)
353383
for row in await connection_or_pool.fetch(
354384
*cls.select_sql(order_by=order_by, for_update=for_update, where=where)
355385
)
356386
]
357387

358388
@classmethod
359389
def create_sql(cls: type[T], **kwargs: Any) -> Fragment:
390+
column_info = cls.column_info()
360391
return sql(
361392
"INSERT INTO {table} ({fields}) VALUES ({values}) RETURNING {out_fields}",
362393
table=cls.table_name_sql(),
363-
fields=sql.list(sql.identifier(x) for x in kwargs.keys()),
364-
values=sql.list(sql.value(x) for x in kwargs.values()),
394+
fields=sql.list(sql.identifier(k) for k in kwargs.keys()),
395+
values=sql.list(
396+
sql.value(column_info[k].maybe_serialize(v)) for k, v in kwargs.items()
397+
),
365398
out_fields=sql.list(cls.field_names_sql()),
366399
)
367400

@@ -370,7 +403,7 @@ async def create(
370403
cls: type[T], connection_or_pool: Union[Connection, Pool], **kwargs: Any
371404
) -> T:
372405
row = await connection_or_pool.fetchrow(*cls.create_sql(**kwargs))
373-
return cls(**row)
406+
return cls.from_mapping(row)
374407

375408
def insert_sql(self, exclude: FieldNamesSet = ()) -> Fragment:
376409
cached = self._cached(
@@ -428,10 +461,11 @@ def delete_multiple_sql(cls: type[T], rows: Iterable[T]) -> Fragment:
428461
pks=sql.list(sql.identifier(pk) for pk in cls.primary_key_names),
429462
).compile(),
430463
)
464+
column_info = cls.column_info()
431465
return cached(
432466
unnest=sql.unnest(
433467
(row.primary_key() for row in rows),
434-
(cls.column_info(pk).type for pk in cls.primary_key_names),
468+
(column_info[pk].type for pk in cls.primary_key_names),
435469
),
436470
)
437471

@@ -451,10 +485,11 @@ def insert_multiple_sql(cls: type[T], rows: Iterable[T]) -> Fragment:
451485
fields=sql.list(cls.field_names_sql()),
452486
).compile(),
453487
)
488+
column_info = cls.column_info()
454489
return cached(
455490
unnest=sql.unnest(
456491
(row.field_values() for row in rows),
457-
(cls.column_info(name).type for name in cls.field_names()),
492+
(column_info[name].type for name in cls.field_names()),
458493
),
459494
)
460495

@@ -545,9 +580,9 @@ def _get_equal_ignoring_fn(
545580
) -> Callable[[T, T], bool]:
546581
env: dict[str, Any] = {}
547582
func = ["def equal_ignoring(a, b):"]
548-
for f in cls._fields():
549-
if f.name not in ignore:
550-
func.append(f" if a.{f.name} != b.{f.name}: return False")
583+
for ci in cls.column_info().values():
584+
if ci.field.name not in ignore:
585+
func.append(f" if a.{ci.field.name} != b.{ci.field.name}: return False")
551586
func += [" return True"]
552587
exec("\n".join(func), env)
553588
return env["equal_ignoring"]
@@ -603,9 +638,11 @@ def _get_differences_ignoring_fn(
603638
"def differences_ignoring(a, b):",
604639
" diffs = []",
605640
]
606-
for f in cls._fields():
607-
if f.name not in ignore:
608-
func.append(f" if a.{f.name} != b.{f.name}: diffs.append({f.name!r})")
641+
for ci in cls.column_info().values():
642+
if ci.field.name not in ignore:
643+
func.append(
644+
f" if a.{ci.field.name} != b.{ci.field.name}: diffs.append({ci.field.name!r})"
645+
)
609646
func += [" return diffs"]
610647
exec("\n".join(func), env)
611648
return env["differences_ignoring"]

0 commit comments

Comments
 (0)