Skip to content

Commit 7ff2075

Browse files
authored
Merge pull request #20 from bdowning/updates
More updates
2 parents b9d492e + 2be7341 commit 7ff2075

7 files changed

Lines changed: 108 additions & 113 deletions

File tree

.bumpversion.cfg

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
[bumpversion]
2-
current_version = 0.4.0-alpha-5
2+
current_version = 0.4.0-alpha-6
33
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(-(?P<release>.*)-(?P<build>\d+))?
44
serialize =
55
{major}.{minor}.{patch}-{release}-{build}

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "sql-athame"
3-
version = "0.4.0-alpha-5"
3+
version = "0.4.0-alpha-6"
44
description = "Python tool for slicing and dicing SQL"
55
authors = ["Brian Downing <bdowning@lavos.net>"]
66
license = "MIT"

sql_athame/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
from .base import Fragment, sql
2-
from .dataclasses import ModelBase, model_field, model_field_metadata
2+
from .dataclasses import ColumnInfo, ModelBase

sql_athame/dataclasses.py

Lines changed: 75 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -1,52 +1,57 @@
11
import datetime
22
import uuid
3-
from collections.abc import AsyncGenerator, Iterable, Iterator, Mapping
4-
from dataclasses import dataclass, field, fields
3+
from collections.abc import AsyncGenerator, Iterable, Mapping
4+
from dataclasses import Field, InitVar, dataclass, fields
55
from typing import (
6+
Annotated,
67
Any,
78
Callable,
89
Optional,
910
TypeVar,
1011
Union,
12+
get_origin,
13+
get_type_hints,
1114
)
1215

16+
from typing_extensions import TypeAlias
17+
1318
from .base import Fragment, sql
1419

15-
Where = Union[Fragment, Iterable[Fragment]]
20+
Where: TypeAlias = Union[Fragment, Iterable[Fragment]]
1621
# KLUDGE to avoid a string argument being valid
17-
SequenceOfStrings = Union[list[str], tuple[str, ...]]
18-
FieldNames = SequenceOfStrings
19-
FieldNamesSet = Union[SequenceOfStrings, set[str]]
22+
SequenceOfStrings: TypeAlias = Union[list[str], tuple[str, ...]]
23+
FieldNames: TypeAlias = SequenceOfStrings
24+
FieldNamesSet: TypeAlias = Union[SequenceOfStrings, set[str]]
2025

21-
Connection = Any
22-
Pool = Any
26+
Connection: TypeAlias = Any
27+
Pool: TypeAlias = Any
2328

2429

2530
@dataclass
2631
class ColumnInfo:
2732
type: str
28-
create_type: str
29-
constraints: tuple[str, ...]
30-
31-
def create_table_string(self):
32-
return " ".join((self.create_type, *self.constraints))
33-
34-
35-
def model_field_metadata(
36-
type: str, constraints: Union[str, Iterable[str]] = ()
37-
) -> dict[str, Any]:
38-
if isinstance(constraints, str):
39-
constraints = (constraints,)
40-
info = ColumnInfo(
41-
sql_create_type_map.get(type.upper(), type), type, tuple(constraints)
42-
)
43-
return {"sql_athame": info}
44-
45-
46-
def model_field(
47-
*, type: str, constraints: Union[str, Iterable[str]] = (), **kwargs: Any
48-
) -> Any:
49-
return field(**kwargs, metadata=model_field_metadata(type, constraints))
33+
create_type: str = ""
34+
nullable: bool = False
35+
_constraints: tuple[str, ...] = ()
36+
37+
constraints: InitVar[Union[str, Iterable[str], None]] = None
38+
39+
def __post_init__(self, constraints: Union[str, Iterable[str], None]) -> None:
40+
if self.create_type == "":
41+
self.create_type = self.type
42+
self.type = sql_create_type_map.get(self.type.upper(), self.type)
43+
if constraints is not None:
44+
if type(constraints) is str:
45+
constraints = (constraints,)
46+
self._constraints = tuple(constraints)
47+
48+
def create_table_string(self) -> str:
49+
parts = (
50+
self.create_type,
51+
*(() if self.nullable else ("NOT NULL",)),
52+
*self._constraints,
53+
)
54+
return " ".join(parts)
5055

5156

5257
sql_create_type_map = {
@@ -56,43 +61,37 @@ def model_field(
5661
}
5762

5863

59-
sql_type_map = {
60-
Optional[bool]: ("BOOLEAN",),
61-
Optional[bytes]: ("BYTEA",),
62-
Optional[datetime.date]: ("DATE",),
63-
Optional[datetime.datetime]: ("TIMESTAMP",),
64-
Optional[float]: ("DOUBLE PRECISION",),
65-
Optional[int]: ("INTEGER",),
66-
Optional[str]: ("TEXT",),
67-
Optional[uuid.UUID]: ("UUID",),
68-
bool: ("BOOLEAN", "NOT NULL"),
69-
bytes: ("BYTEA", "NOT NULL"),
70-
datetime.date: ("DATE", "NOT NULL"),
71-
datetime.datetime: ("TIMESTAMP", "NOT NULL"),
72-
float: ("DOUBLE PRECISION", "NOT NULL"),
73-
int: ("INTEGER", "NOT NULL"),
74-
str: ("TEXT", "NOT NULL"),
75-
uuid.UUID: ("UUID", "NOT NULL"),
64+
sql_type_map: dict[Any, tuple[str, bool]] = {
65+
Optional[bool]: ("BOOLEAN", True),
66+
Optional[bytes]: ("BYTEA", True),
67+
Optional[datetime.date]: ("DATE", True),
68+
Optional[datetime.datetime]: ("TIMESTAMP", True),
69+
Optional[float]: ("DOUBLE PRECISION", True),
70+
Optional[int]: ("INTEGER", True),
71+
Optional[str]: ("TEXT", True),
72+
Optional[uuid.UUID]: ("UUID", True),
73+
bool: ("BOOLEAN", False),
74+
bytes: ("BYTEA", False),
75+
datetime.date: ("DATE", False),
76+
datetime.datetime: ("TIMESTAMP", False),
77+
float: ("DOUBLE PRECISION", False),
78+
int: ("INTEGER", False),
79+
str: ("TEXT", False),
80+
uuid.UUID: ("UUID", False),
7681
}
7782

7883

79-
def column_info_for_field(field):
80-
if "sql_athame" in field.metadata:
81-
return field.metadata["sql_athame"]
82-
type, *constraints = sql_type_map[field.type]
83-
return ColumnInfo(type, type, tuple(constraints))
84-
85-
8684
T = TypeVar("T", bound="ModelBase")
8785
U = TypeVar("U")
8886

8987

90-
class ModelBase(Mapping[str, Any]):
88+
class ModelBase:
9189
_column_info: Optional[dict[str, ColumnInfo]]
9290
_cache: dict[tuple, Any]
9391
table_name: str
9492
primary_key_names: tuple[str, ...]
9593
array_safe_insert: bool
94+
_type_hints: dict[str, type]
9695

9796
def __init_subclass__(
9897
cls,
@@ -130,27 +129,34 @@ def _cached(cls, key: tuple, thunk: Callable[[], U]) -> U:
130129
cls._cache[key] = thunk()
131130
return cls._cache[key]
132131

133-
def keys(self):
134-
return self.field_names()
135-
136-
def __getitem__(self, key: str) -> Any:
137-
return getattr(self, key)
138-
139-
def __iter__(self) -> Iterator[Any]:
140-
return iter(self.keys())
141-
142-
def __len__(self) -> int:
143-
return len(self.keys())
132+
@classmethod
133+
def type_hints(cls) -> dict[str, type]:
134+
try:
135+
return cls._type_hints
136+
except AttributeError:
137+
cls._type_hints = get_type_hints(cls, include_extras=True)
138+
return cls._type_hints
144139

145-
def get(self, key: str, default: Any = None) -> Any:
146-
return getattr(self, key, default)
140+
@classmethod
141+
def column_info_for_field(cls, field: Field) -> ColumnInfo:
142+
type_info = cls.type_hints()[field.name]
143+
base_type = type_info
144+
if get_origin(type_info) is Annotated:
145+
base_type = type_info.__origin__ # type: ignore
146+
for md in type_info.__metadata__: # type: ignore
147+
if isinstance(md, ColumnInfo):
148+
return md
149+
type, nullable = sql_type_map[base_type]
150+
return ColumnInfo(type=type, nullable=nullable)
147151

148152
@classmethod
149153
def column_info(cls, column: str) -> ColumnInfo:
150154
try:
151155
return cls._column_info[column] # type: ignore
152156
except AttributeError:
153-
cls._column_info = {f.name: column_info_for_field(f) for f in cls._fields()}
157+
cls._column_info = {
158+
f.name: cls.column_info_for_field(f) for f in cls._fields()
159+
}
154160
return cls._column_info[column]
155161

156162
@classmethod

sql_athame/types.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import dataclasses
22
from typing import TYPE_CHECKING, Any, Union
33

4+
from typing_extensions import TypeAlias
5+
46

57
@dataclasses.dataclass(eq=False)
68
class Placeholder:
@@ -15,8 +17,8 @@ class Slot:
1517
name: str
1618

1719

18-
Part = Union[str, Placeholder, Slot, "Fragment"]
19-
FlatPart = Union[str, Placeholder, Slot]
20+
Part: TypeAlias = Union[str, Placeholder, Slot, "Fragment"]
21+
FlatPart: TypeAlias = Union[str, Placeholder, Slot]
2022

2123
if TYPE_CHECKING:
2224
from .base import Fragment

tests/test_asyncpg.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,19 @@
1+
# ruff: noqa: UP007
2+
3+
from __future__ import annotations
4+
15
import asyncio
26
import json
37
import os
48
from dataclasses import dataclass, field
59
from datetime import datetime
6-
from typing import Optional
10+
from typing import Annotated, Optional
711

812
import asyncpg
913
import pytest
14+
from typing_extensions import TypeAlias
1015

11-
from sql_athame import ModelBase, model_field_metadata, sql
16+
from sql_athame import ColumnInfo, ModelBase, sql
1217

1318

1419
@pytest.fixture(autouse=True)
@@ -137,9 +142,7 @@ class Test(
137142
insert_multiple_mode="array_safe",
138143
):
139144
id: int
140-
a: list[int] = field(
141-
metadata=model_field_metadata(type="INT[]", constraints="NOT NULL")
142-
)
145+
a: Annotated[list[int], ColumnInfo(type="INT[]")]
143146
b: str
144147

145148
await conn.execute(*Test.create_table_sql())
@@ -262,10 +265,13 @@ class Test(ModelBase, table_name="test", primary_key=("id1", "id2")):
262265
]
263266

264267

268+
Serial: TypeAlias = Annotated[int, ColumnInfo(type="SERIAL")]
269+
270+
265271
async def test_serial(conn):
266272
@dataclass
267273
class Test(ModelBase, table_name="table", primary_key="id"):
268-
id: int = field(metadata=model_field_metadata(type="SERIAL"))
274+
id: Serial
269275
foo: int
270276
bar: str
271277

@@ -281,8 +287,8 @@ class Test(ModelBase, table_name="table", primary_key="id"):
281287
async def test_unnest_json(conn):
282288
@dataclass
283289
class Test(ModelBase, table_name="table", primary_key="id"):
284-
id: int = field(metadata=model_field_metadata(type="SERIAL"))
285-
json: Optional[list] = field(metadata=model_field_metadata(type="JSONB"))
290+
id: Serial
291+
json: Annotated[Optional[list], ColumnInfo(type="JSONB", nullable=True)]
286292

287293
await conn.set_type_codec(
288294
"jsonb", encoder=json.dumps, decoder=json.loads, schema="pg_catalog"
@@ -307,7 +313,7 @@ class Test(ModelBase, table_name="table", primary_key="id"):
307313
async def test_unnest_empty(conn):
308314
@dataclass
309315
class Test(ModelBase, table_name="table", primary_key="id"):
310-
id: int = field(metadata=model_field_metadata(type="SERIAL"))
316+
id: Serial
311317

312318
await conn.execute(*Test.create_table_sql())
313319

tests/test_dataclasses.py

Lines changed: 11 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,20 @@
1+
# ruff: noqa: UP007
2+
3+
from __future__ import annotations
4+
15
import uuid
2-
from dataclasses import dataclass, field
3-
from typing import Optional
6+
from dataclasses import dataclass
7+
from typing import Annotated, Optional
48

59
from sql_athame import sql
6-
from sql_athame.dataclasses import ModelBase, model_field_metadata
10+
from sql_athame.dataclasses import ColumnInfo, ModelBase
711

812

913
def test_modelclass():
1014
@dataclass
1115
class Test(ModelBase, table_name="table"):
12-
foo: int = field(
13-
metadata=model_field_metadata(type="INTEGER", constraints="NOT NULL")
14-
)
15-
bar: str = field(
16-
default="hi",
17-
metadata=model_field_metadata(type="TEXT", constraints="NOT NULL"),
18-
)
16+
foo: int
17+
bar: str = "hi"
1918

2019
t = Test(42)
2120

@@ -96,36 +95,18 @@ class Test(ModelBase, table_name="table", primary_key="id"):
9695
]
9796

9897

99-
def test_mapping():
100-
@dataclass
101-
class Test(ModelBase, table_name="table", primary_key="id"):
102-
id: int
103-
foo: int
104-
bar: str
105-
106-
t = Test(1, 2, "foo")
107-
assert t["id"] == 1
108-
assert t["foo"] == 2
109-
assert t["bar"] == "foo"
110-
111-
assert list(t.keys()) == ["id", "foo", "bar"]
112-
113-
assert dict(t) == {"id": 1, "foo": 2, "bar": "foo"}
114-
assert dict(**t) == {"id": 1, "foo": 2, "bar": "foo"}
115-
116-
11798
def test_serial():
11899
@dataclass
119100
class Test(ModelBase, table_name="table", primary_key="id"):
120-
id: int = field(metadata=model_field_metadata(type="SERIAL"))
101+
id: Annotated[int, ColumnInfo(type="SERIAL")]
121102
foo: int
122103
bar: str
123104

124105
assert Test.column_info("id").type == "INTEGER"
125106
assert Test.column_info("id").create_type == "SERIAL"
126107
assert list(Test.create_table_sql()) == [
127108
'CREATE TABLE IF NOT EXISTS "table" ('
128-
'"id" SERIAL, '
109+
'"id" SERIAL NOT NULL, '
129110
'"foo" INTEGER NOT NULL, '
130111
'"bar" TEXT NOT NULL, '
131112
'PRIMARY KEY ("id"))'

0 commit comments

Comments
 (0)