Skip to content

Commit 71e1272

Browse files
authored
Merge pull request #21 from bdowning/updates
Accept and merge multiple ColumnInfo, make type optional
2 parents 7ff2075 + c942b2f commit 71e1272

4 files changed

Lines changed: 68 additions & 15 deletions

File tree

.bumpversion.cfg

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
[bumpversion]
2-
current_version = 0.4.0-alpha-6
2+
current_version = 0.4.0-alpha-7
33
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(-(?P<release>.*)-(?P<build>\d+))?
44
serialize =
55
{major}.{minor}.{patch}-{release}-{build}

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "sql-athame"
3-
version = "0.4.0-alpha-6"
3+
version = "0.4.0-alpha-7"
44
description = "Python tool for slicing and dicing SQL"
55
authors = ["Brian Downing <bdowning@lavos.net>"]
66
license = "MIT"

sql_athame/dataclasses.py

Lines changed: 47 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import datetime
2+
import functools
23
import uuid
34
from collections.abc import AsyncGenerator, Iterable, Mapping
45
from dataclasses import Field, InitVar, dataclass, fields
@@ -29,27 +30,56 @@
2930

3031
@dataclass
3132
class ColumnInfo:
32-
type: str
33-
create_type: str = ""
34-
nullable: bool = False
33+
type: Optional[str] = None
34+
create_type: Optional[str] = None
35+
nullable: Optional[bool] = None
3536
_constraints: tuple[str, ...] = ()
3637

3738
constraints: InitVar[Union[str, Iterable[str], None]] = None
3839

3940
def __post_init__(self, constraints: Union[str, Iterable[str], None]) -> None:
40-
if self.create_type == "":
41-
self.create_type = self.type
42-
self.type = sql_create_type_map.get(self.type.upper(), self.type)
4341
if constraints is not None:
4442
if type(constraints) is str:
4543
constraints = (constraints,)
4644
self._constraints = tuple(constraints)
4745

46+
@staticmethod
47+
def merge(a: "ColumnInfo", b: "ColumnInfo") -> "ColumnInfo":
48+
return ColumnInfo(
49+
type=b.type if b.type is not None else a.type,
50+
create_type=b.create_type if b.create_type is not None else a.create_type,
51+
nullable=b.nullable if b.nullable is not None else a.nullable,
52+
_constraints=(*a._constraints, *b._constraints),
53+
)
54+
55+
56+
@dataclass
57+
class ConcreteColumnInfo:
58+
type: str
59+
create_type: str
60+
nullable: bool
61+
constraints: tuple[str, ...]
62+
63+
@staticmethod
64+
def from_column_info(name: str, *args: ColumnInfo) -> "ConcreteColumnInfo":
65+
info = functools.reduce(ColumnInfo.merge, args, ColumnInfo())
66+
if info.create_type is None and info.type is not None:
67+
info.create_type = info.type
68+
info.type = sql_create_type_map.get(info.type.upper(), info.type)
69+
if type(info.type) is not str or type(info.create_type) is not str:
70+
raise ValueError(f"Missing SQL type for column {name!r}")
71+
return ConcreteColumnInfo(
72+
type=info.type,
73+
create_type=info.create_type,
74+
nullable=bool(info.nullable),
75+
constraints=info._constraints,
76+
)
77+
4878
def create_table_string(self) -> str:
4979
parts = (
5080
self.create_type,
5181
*(() if self.nullable else ("NOT NULL",)),
52-
*self._constraints,
82+
*self.constraints,
5383
)
5484
return " ".join(parts)
5585

@@ -86,7 +116,7 @@ def create_table_string(self) -> str:
86116

87117

88118
class ModelBase:
89-
_column_info: Optional[dict[str, ColumnInfo]]
119+
_column_info: Optional[dict[str, ConcreteColumnInfo]]
90120
_cache: dict[tuple, Any]
91121
table_name: str
92122
primary_key_names: tuple[str, ...]
@@ -138,19 +168,23 @@ def type_hints(cls) -> dict[str, type]:
138168
return cls._type_hints
139169

140170
@classmethod
141-
def column_info_for_field(cls, field: Field) -> ColumnInfo:
171+
def column_info_for_field(cls, field: Field) -> ConcreteColumnInfo:
142172
type_info = cls.type_hints()[field.name]
143173
base_type = type_info
144174
if get_origin(type_info) is Annotated:
145175
base_type = type_info.__origin__ # type: ignore
176+
info = []
177+
if base_type in sql_type_map:
178+
_type, nullable = sql_type_map[base_type]
179+
info.append(ColumnInfo(type=_type, nullable=nullable))
180+
if get_origin(type_info) is Annotated:
146181
for md in type_info.__metadata__: # type: ignore
147182
if isinstance(md, ColumnInfo):
148-
return md
149-
type, nullable = sql_type_map[base_type]
150-
return ColumnInfo(type=type, nullable=nullable)
183+
info.append(md)
184+
return ConcreteColumnInfo.from_column_info(field.name, *info)
151185

152186
@classmethod
153-
def column_info(cls, column: str) -> ColumnInfo:
187+
def column_info(cls, column: str) -> ConcreteColumnInfo:
154188
try:
155189
return cls._column_info[column] # type: ignore
156190
except AttributeError:

tests/test_dataclasses.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
from dataclasses import dataclass
77
from typing import Annotated, Optional
88

9+
import pytest
10+
911
from sql_athame import sql
1012
from sql_athame.dataclasses import ColumnInfo, ModelBase
1113

@@ -67,16 +69,33 @@ class Test(ModelBase, table_name="table", primary_key="foo"):
6769
foo: int
6870
bar: str
6971
baz: Optional[uuid.UUID]
72+
quux: Annotated[int, ColumnInfo(constraints="REFERENCES foobar")]
73+
quuux: Annotated[
74+
int,
75+
ColumnInfo(constraints="REFERENCES foobar"),
76+
ColumnInfo(constraints="BLAH", nullable=True),
77+
]
7078

7179
assert list(Test.create_table_sql()) == [
7280
'CREATE TABLE IF NOT EXISTS "table" ('
7381
'"foo" INTEGER NOT NULL, '
7482
'"bar" TEXT NOT NULL, '
7583
'"baz" UUID, '
84+
'"quux" INTEGER NOT NULL REFERENCES foobar, '
85+
'"quuux" INTEGER REFERENCES foobar BLAH, '
7686
'PRIMARY KEY ("foo"))'
7787
]
7888

7989

90+
def test_modelclass_missing_type():
91+
@dataclass
92+
class Test(ModelBase, table_name="table", primary_key="foo"):
93+
foo: dict
94+
95+
with pytest.raises(ValueError, match="Missing SQL type for column 'foo'"):
96+
Test.create_table_sql()
97+
98+
8099
def test_upsert():
81100
@dataclass
82101
class Test(ModelBase, table_name="table", primary_key="id"):

0 commit comments

Comments
 (0)