Skip to content

Commit 05a9325

Browse files
authored
Merge pull request #12 from bdowning/safe-array-insert
Add insert_multiple_mode and array safety
2 parents d5f71b1 + 61a4de1 commit 05a9325

6 files changed

Lines changed: 126 additions & 6 deletions

File tree

.bumpversion.cfg

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
[bumpversion]
2-
current_version = 0.4.0-alpha-1
2+
current_version = 0.4.0-alpha-2
33
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(-(?P<release>.*)-(?P<build>\d+))?
44
serialize =
55
{major}.{minor}.{patch}-{release}-{build}

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "sql-athame"
3-
version = "0.4.0-alpha-1"
3+
version = "0.4.0-alpha-2"
44
description = "Python tool for slicing and dicing SQL"
55
authors = ["Brian Downing <bdowning@lavos.net>"]
66
license = "MIT"

run

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ usage=()
44

55
usage+=(" $0 tests - run tests")
66
tests() {
7-
poetry run pytest
7+
poetry run pytest "$@"
88
lint
99
}
1010

setup.cfg

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@ ignore_missing_imports = True
3535
ignore =
3636
# max line length
3737
E501,
38+
# E203 whitespace before ':'; fights with black
39+
E203,
3840
# multiple statements on one line (def); fights with black
3941
E704,
4042
# line break before binary operator; fights with black

sql_athame/dataclasses.py

Lines changed: 70 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -100,12 +100,24 @@ class ModelBase(Mapping[str, Any]):
100100
_cache: Dict[tuple, Any]
101101
table_name: str
102102
primary_key_names: Tuple[str, ...]
103+
array_safe_insert: bool
103104

104105
def __init_subclass__(
105-
cls, *, table_name: str, primary_key: Union[FieldNames, str] = (), **kwargs: Any
106+
cls,
107+
*,
108+
table_name: str,
109+
primary_key: Union[FieldNames, str] = (),
110+
insert_multiple_mode: str = "unnest",
111+
**kwargs: Any,
106112
):
107113
cls._cache = {}
108114
cls.table_name = table_name
115+
if insert_multiple_mode == "array_safe":
116+
cls.array_safe_insert = True
117+
elif insert_multiple_mode == "unnest":
118+
cls.array_safe_insert = False
119+
else:
120+
raise ValueError("Unknown `insert_multiple_mode`")
109121
if isinstance(primary_key, str):
110122
cls.primary_key_names = (primary_key,)
111123
else:
@@ -407,19 +419,69 @@ def insert_multiple_sql(cls: Type[T], rows: Iterable[T]) -> Fragment:
407419
)
408420

409421
@classmethod
410-
async def insert_multiple(
422+
def insert_multiple_array_safe_sql(cls: Type[T], rows: Iterable[T]) -> Fragment:
423+
return sql(
424+
"INSERT INTO {table} ({fields}) VALUES {values}",
425+
table=cls.table_name_sql(),
426+
fields=sql.list(cls.field_names_sql()),
427+
values=sql.list(
428+
sql("({})", sql.list(row.field_values_sql(default_none=True)))
429+
for row in rows
430+
),
431+
)
432+
433+
@classmethod
434+
async def insert_multiple_unnest(
411435
cls: Type[T], connection_or_pool: Union[Connection, Pool], rows: Iterable[T]
412436
) -> str:
413437
return await connection_or_pool.execute(*cls.insert_multiple_sql(rows))
414438

415439
@classmethod
416-
async def upsert_multiple(
440+
async def insert_multiple_array_safe(
441+
cls: Type[T], connection_or_pool: Union[Connection, Pool], rows: Iterable[T]
442+
) -> str:
443+
for chunk in chunked(rows, 100):
444+
last = await connection_or_pool.execute(
445+
*cls.insert_multiple_array_safe_sql(chunk)
446+
)
447+
return last
448+
449+
@classmethod
450+
async def insert_multiple(
451+
cls: Type[T], connection_or_pool: Union[Connection, Pool], rows: Iterable[T]
452+
) -> str:
453+
if cls.array_safe_insert:
454+
return await cls.insert_multiple_array_safe(connection_or_pool, rows)
455+
else:
456+
return await cls.insert_multiple_unnest(connection_or_pool, rows)
457+
458+
@classmethod
459+
async def upsert_multiple_unnest(
417460
cls: Type[T], connection_or_pool: Union[Connection, Pool], rows: Iterable[T]
418461
) -> str:
419462
return await connection_or_pool.execute(
420463
*cls.upsert_sql(cls.insert_multiple_sql(rows))
421464
)
422465

466+
@classmethod
467+
async def upsert_multiple_array_safe(
468+
cls: Type[T], connection_or_pool: Union[Connection, Pool], rows: Iterable[T]
469+
) -> str:
470+
for chunk in chunked(rows, 100):
471+
last = await connection_or_pool.execute(
472+
*cls.upsert_sql(cls.insert_multiple_array_safe_sql(chunk))
473+
)
474+
return last
475+
476+
@classmethod
477+
async def upsert_multiple(
478+
cls: Type[T], connection_or_pool: Union[Connection, Pool], rows: Iterable[T]
479+
) -> str:
480+
if cls.array_safe_insert:
481+
return await cls.upsert_multiple_array_safe(connection_or_pool, rows)
482+
else:
483+
return await cls.upsert_multiple_unnest(connection_or_pool, rows)
484+
423485
@classmethod
424486
def _get_equal_ignoring_fn(
425487
cls: Type[T], ignore: FieldNamesSet = ()
@@ -530,3 +592,8 @@ async def replace_multiple_reporting_differences(
530592
await cls.delete_multiple(connection, deleted)
531593

532594
return created, updated_triples, deleted
595+
596+
597+
def chunked(lst, n):
598+
for i in range(0, len(lst), n):
599+
yield lst[i : i + n]

tests/test_asyncpg.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,57 @@ class Test(ModelBase, table_name="test", primary_key="id"):
9191
]
9292

9393

94+
async def test_replace_multiple_arrays(conn):
95+
@dataclass(order=True)
96+
class Test(
97+
ModelBase,
98+
table_name="test",
99+
primary_key="id",
100+
insert_multiple_mode="array_safe",
101+
):
102+
id: int
103+
a: list[int] = field(
104+
metadata=model_field_metadata(type="INT[]", constraints="NOT NULL")
105+
)
106+
b: str
107+
108+
await conn.execute(*Test.create_table_sql())
109+
110+
data = [
111+
Test(1, [1], "foo"),
112+
Test(2, [1, 3, 5], "bar"),
113+
Test(3, [], "quux"),
114+
]
115+
await Test.insert_multiple(conn, data)
116+
117+
c, u, d = await Test.replace_multiple(conn, [], where=[])
118+
assert not c and not u
119+
assert len(d) == 3
120+
assert await Test.select(conn) == []
121+
122+
await Test.insert_multiple(conn, data)
123+
124+
c, u, d = await Test.replace_multiple(conn, [], where=sql("a @> ARRAY[1]"))
125+
assert not c and not u
126+
assert len(d) == 2
127+
assert [x.id for x in await Test.select(conn)] == [3]
128+
129+
await conn.execute("DELETE FROM test")
130+
await Test.insert_multiple(conn, data)
131+
132+
c, u, d = await Test.replace_multiple(
133+
conn, [Test(1, [5], "apples"), Test(4, [6], "fred")], where=sql("a @> ARRAY[1]")
134+
)
135+
assert len(c) == 1
136+
assert len(u) == 1
137+
assert len(d) == 1
138+
assert list(sorted(await Test.select(conn))) == [
139+
Test(1, [5], "apples"),
140+
Test(3, [], "quux"),
141+
Test(4, [6], "fred"),
142+
]
143+
144+
94145
async def test_replace_multiple_reporting_differences(conn):
95146
@dataclass(order=True)
96147
class Test(ModelBase, table_name="test", primary_key="id"):

0 commit comments

Comments
 (0)