Skip to content

Commit b274d59

Browse files
committed
Column info via annotation
1 parent 9b2e99b commit b274d59

4 files changed

Lines changed: 49 additions & 44 deletions

File tree

sql_athame/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
from .base import Fragment, sql
2-
from .dataclasses import ModelBase, model_field, model_field_metadata
2+
from .dataclasses import ColumnInfo, ModelBase

sql_athame/dataclasses.py

Lines changed: 31 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,16 @@
11
import datetime
22
import uuid
33
from collections.abc import AsyncGenerator, Iterable, Mapping
4-
from dataclasses import InitVar, dataclass, field, fields
4+
from dataclasses import Field, InitVar, dataclass, fields
55
from typing import (
6+
Annotated,
67
Any,
78
Callable,
89
Optional,
910
TypeVar,
1011
Union,
12+
get_origin,
13+
get_type_hints,
1114
)
1215

1316
from .base import Fragment, sql
@@ -31,7 +34,7 @@ class ColumnInfo:
3134

3235
constraints: InitVar[Union[str, Iterable[str], None]] = None
3336

34-
def __post_init__(self, constraints: Union[str, Iterable[str], None]):
37+
def __post_init__(self, constraints: Union[str, Iterable[str], None]) -> None:
3538
if self.create_type == "":
3639
self.create_type = self.type
3740
self.type = sql_create_type_map.get(self.type.upper(), self.type)
@@ -40,7 +43,7 @@ def __post_init__(self, constraints: Union[str, Iterable[str], None]):
4043
constraints = (constraints,)
4144
self._constraints = tuple(constraints)
4245

43-
def create_table_string(self):
46+
def create_table_string(self) -> str:
4447
parts = (
4548
self.create_type,
4649
*(() if self.nullable else ("NOT NULL",)),
@@ -49,30 +52,14 @@ def create_table_string(self):
4952
return " ".join(parts)
5053

5154

52-
def model_field_metadata(
53-
type: str, nullable: bool = False, constraints: Union[str, Iterable[str]] = ()
54-
) -> dict[str, Any]:
55-
if isinstance(constraints, str):
56-
constraints = (constraints,)
57-
info = ColumnInfo(type=type, nullable=nullable, constraints=constraints)
58-
59-
return {"sql_athame": info}
60-
61-
62-
def model_field(
63-
*, type: str, constraints: Union[str, Iterable[str]] = (), **kwargs: Any
64-
) -> Any:
65-
return field(**kwargs, metadata=model_field_metadata(type, constraints))
66-
67-
6855
sql_create_type_map = {
6956
"BIGSERIAL": "BIGINT",
7057
"SERIAL": "INTEGER",
7158
"SMALLSERIAL": "SMALLINT",
7259
}
7360

7461

75-
sql_type_map: dict[type, tuple[str, bool]] = {
62+
sql_type_map: dict[Any, tuple[str, bool]] = {
7663
Optional[bool]: ("BOOLEAN", True),
7764
Optional[bytes]: ("BYTEA", True),
7865
Optional[datetime.date]: ("DATE", True),
@@ -92,13 +79,6 @@ def model_field(
9279
}
9380

9481

95-
def column_info_for_field(field):
96-
if "sql_athame" in field.metadata:
97-
return field.metadata["sql_athame"]
98-
type, nullable = sql_type_map[field.type]
99-
return ColumnInfo(type=type, nullable=nullable)
100-
101-
10282
T = TypeVar("T", bound="ModelBase")
10383
U = TypeVar("U")
10484

@@ -109,6 +89,7 @@ class ModelBase:
10989
table_name: str
11090
primary_key_names: tuple[str, ...]
11191
array_safe_insert: bool
92+
_type_hints: dict[str, type]
11293

11394
def __init_subclass__(
11495
cls,
@@ -146,12 +127,34 @@ def _cached(cls, key: tuple, thunk: Callable[[], U]) -> U:
146127
cls._cache[key] = thunk()
147128
return cls._cache[key]
148129

130+
@classmethod
131+
def type_hints(cls) -> dict[str, type]:
132+
try:
133+
return cls._type_hints
134+
except AttributeError:
135+
cls._type_hints = get_type_hints(cls, include_extras=True)
136+
return cls._type_hints
137+
138+
@classmethod
139+
def column_info_for_field(cls, field: Field) -> ColumnInfo:
140+
type_info = cls.type_hints()[field.name]
141+
base_type = type_info
142+
if get_origin(type_info) is Annotated:
143+
base_type = type_info.__origin__ # type: ignore
144+
for md in type_info.__metadata__: # type: ignore
145+
if isinstance(md, ColumnInfo):
146+
return md
147+
type, nullable = sql_type_map[base_type]
148+
return ColumnInfo(type=type, nullable=nullable)
149+
149150
@classmethod
150151
def column_info(cls, column: str) -> ColumnInfo:
151152
try:
152153
return cls._column_info[column] # type: ignore
153154
except AttributeError:
154-
cls._column_info = {f.name: column_info_for_field(f) for f in cls._fields()}
155+
cls._column_info = {
156+
f.name: cls.column_info_for_field(f) for f in cls._fields()
157+
}
155158
return cls._column_info[column]
156159

157160
@classmethod

tests/test_asyncpg.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,13 @@
33
import os
44
from dataclasses import dataclass, field
55
from datetime import datetime
6-
from typing import Optional
6+
from typing import Annotated, Optional
77

88
import asyncpg
99
import pytest
10+
from typing_extensions import TypeAlias
1011

11-
from sql_athame import ModelBase, model_field_metadata, sql
12+
from sql_athame import ColumnInfo, ModelBase, sql
1213

1314

1415
@pytest.fixture(autouse=True)
@@ -137,7 +138,7 @@ class Test(
137138
insert_multiple_mode="array_safe",
138139
):
139140
id: int
140-
a: list[int] = field(metadata=model_field_metadata(type="INT[]"))
141+
a: Annotated[list[int], ColumnInfo(type="INT[]")]
141142
b: str
142143

143144
await conn.execute(*Test.create_table_sql())
@@ -260,10 +261,13 @@ class Test(ModelBase, table_name="test", primary_key=("id1", "id2")):
260261
]
261262

262263

264+
Serial: TypeAlias = Annotated[int, ColumnInfo(type="SERIAL")]
265+
266+
263267
async def test_serial(conn):
264268
@dataclass
265269
class Test(ModelBase, table_name="table", primary_key="id"):
266-
id: int = field(metadata=model_field_metadata(type="SERIAL"))
270+
id: Serial
267271
foo: int
268272
bar: str
269273

@@ -279,10 +283,8 @@ class Test(ModelBase, table_name="table", primary_key="id"):
279283
async def test_unnest_json(conn):
280284
@dataclass
281285
class Test(ModelBase, table_name="table", primary_key="id"):
282-
id: int = field(metadata=model_field_metadata(type="SERIAL"))
283-
json: Optional[list] = field(
284-
metadata=model_field_metadata(type="JSONB", nullable=True)
285-
)
286+
id: Serial
287+
json: Annotated[Optional[list], ColumnInfo(type="JSONB", nullable=True)]
286288

287289
await conn.set_type_codec(
288290
"jsonb", encoder=json.dumps, decoder=json.loads, schema="pg_catalog"
@@ -307,7 +309,7 @@ class Test(ModelBase, table_name="table", primary_key="id"):
307309
async def test_unnest_empty(conn):
308310
@dataclass
309311
class Test(ModelBase, table_name="table", primary_key="id"):
310-
id: int = field(metadata=model_field_metadata(type="SERIAL"))
312+
id: Serial
311313

312314
await conn.execute(*Test.create_table_sql())
313315

tests/test_dataclasses.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,16 @@
11
import uuid
2-
from dataclasses import dataclass, field
3-
from typing import Optional
2+
from dataclasses import dataclass
3+
from typing import Annotated, Optional
44

55
from sql_athame import sql
6-
from sql_athame.dataclasses import ModelBase, model_field_metadata
6+
from sql_athame.dataclasses import ColumnInfo, ModelBase
77

88

99
def test_modelclass():
1010
@dataclass
1111
class Test(ModelBase, table_name="table"):
12-
foo: int = field(metadata=model_field_metadata(type="INTEGER"))
13-
bar: str = field(default="hi", metadata=model_field_metadata(type="TEXT"))
12+
foo: int
13+
bar: str = "hi"
1414

1515
t = Test(42)
1616

@@ -94,7 +94,7 @@ class Test(ModelBase, table_name="table", primary_key="id"):
9494
def test_serial():
9595
@dataclass
9696
class Test(ModelBase, table_name="table", primary_key="id"):
97-
id: int = field(metadata=model_field_metadata(type="SERIAL"))
97+
id: Annotated[int, ColumnInfo(type="SERIAL")]
9898
foo: int
9999
bar: str
100100

0 commit comments

Comments
 (0)