Skip to content

Commit a467b63

Browse files
committed
Refactor column_info
1 parent 69ce167 commit a467b63

2 files changed

Lines changed: 52 additions & 57 deletions

File tree

sql_athame/dataclasses.py

Lines changed: 50 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,8 @@ class ColumnInfo:
3434
type: Optional[str] = None
3535
create_type: Optional[str] = None
3636
nullable: Optional[bool] = None
37-
_constraints: tuple[str, ...] = ()
3837

38+
_constraints: tuple[str, ...] = ()
3939
constraints: InitVar[Union[str, Iterable[str], None]] = None
4040

4141
def __post_init__(self, constraints: Union[str, Iterable[str], None]) -> None:
@@ -56,20 +56,26 @@ def merge(a: "ColumnInfo", b: "ColumnInfo") -> "ColumnInfo":
5656

5757
@dataclass
5858
class ConcreteColumnInfo:
59+
field: Field
60+
type_hint: type
5961
type: str
6062
create_type: str
6163
nullable: bool
6264
constraints: tuple[str, ...]
6365

6466
@staticmethod
65-
def from_column_info(name: str, *args: ColumnInfo) -> "ConcreteColumnInfo":
67+
def from_column_info(
68+
field: Field, type_hint: Any, *args: ColumnInfo
69+
) -> "ConcreteColumnInfo":
6670
info = functools.reduce(ColumnInfo.merge, args, ColumnInfo())
6771
if info.create_type is None and info.type is not None:
6872
info.create_type = info.type
6973
info.type = sql_create_type_map.get(info.type.upper(), info.type)
7074
if type(info.type) is not str or type(info.create_type) is not str:
71-
raise ValueError(f"Missing SQL type for column {name!r}")
75+
raise ValueError(f"Missing SQL type for column {field.name!r}")
7276
return ConcreteColumnInfo(
77+
field=field,
78+
type_hint=type_hint,
7379
type=info.type,
7480
create_type=info.create_type,
7581
nullable=bool(info.nullable),
@@ -108,7 +114,7 @@ def split_nullable(typ: type) -> tuple[bool, type]:
108114
}
109115

110116

111-
sql_type_map: dict[Any, str] = {
117+
sql_type_map: dict[type, str] = {
112118
bool: "BOOLEAN",
113119
bytes: "BYTEA",
114120
datetime.date: "DATE",
@@ -125,12 +131,11 @@ def split_nullable(typ: type) -> tuple[bool, type]:
125131

126132

127133
class ModelBase:
128-
_column_info: Optional[dict[str, ConcreteColumnInfo]]
134+
_column_info: dict[str, ConcreteColumnInfo]
129135
_cache: dict[tuple, Any]
130136
table_name: str
131137
primary_key_names: tuple[str, ...]
132138
array_safe_insert: bool
133-
_type_hints: dict[str, type]
134139

135140
def __init_subclass__(
136141
cls,
@@ -153,13 +158,6 @@ def __init_subclass__(
153158
else:
154159
cls.primary_key_names = tuple(primary_key)
155160

156-
@classmethod
157-
def _fields(cls):
158-
# wrapper to ignore typing weirdness: 'Argument 1 to "fields"
159-
# has incompatible type "..."; expected "DataclassInstance |
160-
# type[DataclassInstance]"'
161-
return fields(cls) # type: ignore
162-
163161
@classmethod
164162
def _cached(cls, key: tuple, thunk: Callable[[], U]) -> U:
165163
try:
@@ -169,38 +167,31 @@ def _cached(cls, key: tuple, thunk: Callable[[], U]) -> U:
169167
return cls._cache[key]
170168

171169
@classmethod
172-
def type_hints(cls) -> dict[str, type]:
173-
try:
174-
return cls._type_hints
175-
except AttributeError:
176-
cls._type_hints = get_type_hints(cls, include_extras=True)
177-
return cls._type_hints
178-
179-
@classmethod
180-
def column_info_for_field(cls, field: Field) -> ConcreteColumnInfo:
181-
type_info = cls.type_hints()[field.name]
182-
base_type = type_info
170+
def column_info_for_field(cls, field: Field, type_hint: type) -> ConcreteColumnInfo:
171+
base_type = type_hint
183172
metadata = []
184-
if get_origin(type_info) is Annotated:
185-
base_type, *metadata = get_args(type_info)
173+
if get_origin(type_hint) is Annotated:
174+
base_type, *metadata = get_args(type_hint)
186175
nullable, base_type = split_nullable(base_type)
187176
info = [ColumnInfo(nullable=nullable)]
188177
if base_type in sql_type_map:
189178
info.append(ColumnInfo(type=sql_type_map[base_type]))
190179
for md in metadata:
191180
if isinstance(md, ColumnInfo):
192181
info.append(md)
193-
return ConcreteColumnInfo.from_column_info(field.name, *info)
182+
return ConcreteColumnInfo.from_column_info(field, type_hint, *info)
194183

195184
@classmethod
196-
def column_info(cls, column: str) -> ConcreteColumnInfo:
185+
def column_info(cls) -> dict[str, ConcreteColumnInfo]:
197186
try:
198-
return cls._column_info[column] # type: ignore
187+
return cls._column_info
199188
except AttributeError:
189+
type_hints = get_type_hints(cls, include_extras=True)
200190
cls._column_info = {
201-
f.name: cls.column_info_for_field(f) for f in cls._fields()
191+
f.name: cls.column_info_for_field(f, type_hints[f.name])
192+
for f in fields(cls) # type: ignore
202193
}
203-
return cls._column_info[column]
194+
return cls._column_info
204195

205196
@classmethod
206197
def table_name_sql(cls, *, prefix: Optional[str] = None) -> Fragment:
@@ -212,7 +203,11 @@ def primary_key_names_sql(cls, *, prefix: Optional[str] = None) -> list[Fragment
212203

213204
@classmethod
214205
def field_names(cls, *, exclude: FieldNamesSet = ()) -> list[str]:
215-
return [f.name for f in cls._fields() if f.name not in exclude]
206+
return [
207+
ci.field.name
208+
for ci in cls.column_info().values()
209+
if ci.field.name not in exclude
210+
]
216211

217212
@classmethod
218213
def field_names_sql(
@@ -231,9 +226,9 @@ def _get_field_values_fn(
231226
) -> Callable[[T], list[Any]]:
232227
env: dict[str, Any] = {}
233228
func = ["def get_field_values(self): return ["]
234-
for f in cls._fields():
235-
if f.name not in exclude:
236-
func.append(f"self.{f.name},")
229+
for ci in cls.column_info().values():
230+
if ci.field.name not in exclude:
231+
func.append(f"self.{ci.field.name},")
237232
func += ["]"]
238233
exec(" ".join(func), env)
239234
return env["get_field_values"]
@@ -256,19 +251,15 @@ def field_values_sql(
256251
else:
257252
return [sql.value(value) for value in self.field_values()]
258253

259-
@classmethod
260-
def from_tuple(
261-
cls: type[T], tup: tuple, *, offset: int = 0, exclude: FieldNamesSet = ()
262-
) -> T:
263-
names = (f.name for f in cls._fields() if f.name not in exclude)
264-
kwargs = {name: tup[offset] for offset, name in enumerate(names, start=offset)}
265-
return cls(**kwargs)
266-
267254
@classmethod
268255
def from_dict(
269256
cls: type[T], dct: dict[str, Any], *, exclude: FieldNamesSet = ()
270257
) -> T:
271-
names = {f.name for f in cls._fields() if f.name not in exclude}
258+
names = {
259+
ci.field.name
260+
for ci in cls.column_info().values()
261+
if ci.field.name not in exclude
262+
}
272263
kwargs = {k: v for k, v in dct.items() if k in names}
273264
return cls(**kwargs)
274265

@@ -283,10 +274,10 @@ def create_table_sql(cls) -> Fragment:
283274
entries = [
284275
sql(
285276
"{} {}",
286-
sql.identifier(f.name),
287-
sql.literal(cls.column_info(f.name).create_table_string()),
277+
sql.identifier(ci.field.name),
278+
sql.literal(ci.create_table_string()),
288279
)
289-
for f in cls._fields()
280+
for ci in cls.column_info().values()
290281
]
291282
if cls.primary_key_names:
292283
entries += [sql("PRIMARY KEY ({})", sql.list(cls.primary_key_names_sql()))]
@@ -428,10 +419,11 @@ def delete_multiple_sql(cls: type[T], rows: Iterable[T]) -> Fragment:
428419
pks=sql.list(sql.identifier(pk) for pk in cls.primary_key_names),
429420
).compile(),
430421
)
422+
column_info = cls.column_info()
431423
return cached(
432424
unnest=sql.unnest(
433425
(row.primary_key() for row in rows),
434-
(cls.column_info(pk).type for pk in cls.primary_key_names),
426+
(column_info[pk].type for pk in cls.primary_key_names),
435427
),
436428
)
437429

@@ -451,10 +443,11 @@ def insert_multiple_sql(cls: type[T], rows: Iterable[T]) -> Fragment:
451443
fields=sql.list(cls.field_names_sql()),
452444
).compile(),
453445
)
446+
column_info = cls.column_info()
454447
return cached(
455448
unnest=sql.unnest(
456449
(row.field_values() for row in rows),
457-
(cls.column_info(name).type for name in cls.field_names()),
450+
(column_info[name].type for name in cls.field_names()),
458451
),
459452
)
460453

@@ -545,9 +538,9 @@ def _get_equal_ignoring_fn(
545538
) -> Callable[[T, T], bool]:
546539
env: dict[str, Any] = {}
547540
func = ["def equal_ignoring(a, b):"]
548-
for f in cls._fields():
549-
if f.name not in ignore:
550-
func.append(f" if a.{f.name} != b.{f.name}: return False")
541+
for ci in cls.column_info().values():
542+
if ci.field.name not in ignore:
543+
func.append(f" if a.{ci.field.name} != b.{ci.field.name}: return False")
551544
func += [" return True"]
552545
exec("\n".join(func), env)
553546
return env["equal_ignoring"]
@@ -603,9 +596,11 @@ def _get_differences_ignoring_fn(
603596
"def differences_ignoring(a, b):",
604597
" diffs = []",
605598
]
606-
for f in cls._fields():
607-
if f.name not in ignore:
608-
func.append(f" if a.{f.name} != b.{f.name}: diffs.append({f.name!r})")
599+
for ci in cls.column_info().values():
600+
if ci.field.name not in ignore:
601+
func.append(
602+
f" if a.{ci.field.name} != b.{ci.field.name}: diffs.append({ci.field.name!r})"
603+
)
609604
func += [" return diffs"]
610605
exec("\n".join(func), env)
611606
return env["differences_ignoring"]

tests/test_dataclasses.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -135,8 +135,8 @@ class Test(ModelBase, table_name="table", primary_key="id"):
135135
foo: int
136136
bar: str
137137

138-
assert Test.column_info("id").type == "INTEGER"
139-
assert Test.column_info("id").create_type == "SERIAL"
138+
assert Test.column_info()["id"].type == "INTEGER"
139+
assert Test.column_info()["id"].create_type == "SERIAL"
140140
assert list(Test.create_table_sql()) == [
141141
'CREATE TABLE IF NOT EXISTS "table" ('
142142
'"id" SERIAL NOT NULL, '

0 commit comments

Comments
 (0)