Skip to content

Commit 66be1c4

Browse files
authored
Merge pull request #28 from bdowning/executemany
executemany, {fetch,cursor}_from, switch to uv
2 parents d627642 + a174ca9 commit 66be1c4

11 files changed

Lines changed: 1281 additions & 1329 deletions

File tree

.bumpversion.cfg

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
[bumpversion]
2-
current_version = 0.4.0-alpha-10
2+
current_version = 0.4.0-alpha-11
33
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(-(?P<release>.*)-(?P<build>\d+))?
44
serialize =
55
{major}.{minor}.{patch}-{release}-{build}

.github/workflows/publish.yml

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,19 @@ jobs:
99
publish:
1010
runs-on: ubuntu-latest
1111
env:
12-
POETRY_PYPI_TOKEN_PYPI: ${{ secrets.POETRY_PYPI_TOKEN_PYPI }}
12+
UV_PUBLISH_TOKEN: ${{ secrets.POETRY_PYPI_TOKEN_PYPI }}
1313
steps:
1414
- uses: actions/checkout@v4
15-
- run: pipx install poetry
15+
16+
- uses: astral-sh/setup-uv@v5
17+
with:
18+
version: "0.6.6"
19+
enable-cache: true
20+
python-version: '3.12'
21+
1622
- uses: actions/setup-python@v5
1723
with:
1824
python-version: '3.12'
19-
cache: poetry
20-
- run: poetry build
21-
- run: poetry publish
25+
26+
- run: uv build
27+
- run: uv publish

.github/workflows/test.yml

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ jobs:
1414
- '3.10'
1515
- '3.11'
1616
- '3.12'
17+
- '3.13'
1718
runs-on: ubuntu-latest
1819
services:
1920
postgres:
@@ -26,16 +27,22 @@ jobs:
2627
PGPORT: 5432
2728
steps:
2829
- uses: actions/checkout@v4
29-
- run: pipx install poetry
30+
31+
- uses: astral-sh/setup-uv@v5
32+
with:
33+
version: "0.6.6"
34+
enable-cache: true
35+
python-version: ${{ matrix.python-version }}
36+
3037
- uses: actions/setup-python@v5
3138
with:
3239
python-version: ${{ matrix.python-version }}
33-
cache: poetry
34-
- run: poetry install
35-
- run: poetry run ruff check
36-
- run: poetry run ruff format --diff
40+
41+
- run: uv sync
42+
- run: uv run ruff check
43+
- run: uv run ruff format --diff
3744
if: success() || failure()
38-
- run: poetry run mypy sql_athame/**.py tests/**.py
45+
- run: uv run mypy sql_athame/**.py tests/**.py
3946
if: success() || failure()
40-
- run: poetry run pytest
47+
- run: uv run pytest
4148
if: success() || failure()

poetry.lock

Lines changed: 0 additions & 1255 deletions
This file was deleted.

poetry.toml

Lines changed: 0 additions & 6 deletions
This file was deleted.

pyproject.toml

Lines changed: 32 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,37 +1,43 @@
1-
[tool.poetry]
1+
[project]
2+
authors = [
3+
{name = "Brian Downing", email = "bdowning@lavos.net"},
4+
]
5+
license = {text = "MIT"}
6+
requires-python = "<4.0,>=3.9"
7+
dependencies = [
8+
"typing-extensions",
9+
]
210
name = "sql-athame"
3-
version = "0.4.0-alpha-10"
11+
version = "0.4.0-alpha-11"
412
description = "Python tool for slicing and dicing SQL"
5-
authors = ["Brian Downing <bdowning@lavos.net>"]
6-
license = "MIT"
713
readme = "README.md"
14+
15+
[project.urls]
816
homepage = "https://github.com/bdowning/sql-athame"
917
repository = "https://github.com/bdowning/sql-athame"
1018

11-
[tool.poetry.extras]
12-
asyncpg = ["asyncpg"]
13-
14-
[tool.poetry.dependencies]
15-
python = "^3.9"
16-
asyncpg = { version = "*", optional = true }
17-
typing-extensions = "*"
19+
[project.optional-dependencies]
20+
asyncpg = [
21+
"asyncpg",
22+
]
1823

19-
[tool.poetry.group.dev.dependencies]
20-
pytest = "*"
21-
mypy = "*"
22-
flake8 = "*"
23-
ipython = "*"
24-
pytest-cov = "*"
25-
bump2version = "*"
26-
asyncpg = "*"
27-
pytest-asyncio = "*"
28-
grip = "*"
29-
SQLAlchemy = "*"
30-
ruff = "*"
24+
[dependency-groups]
25+
dev = [
26+
"SQLAlchemy",
27+
"asyncpg",
28+
"bump2version",
29+
"flake8",
30+
"grip",
31+
"ipython",
32+
"mypy",
33+
"pytest",
34+
"pytest-asyncio",
35+
"pytest-cov",
36+
"ruff",
37+
]
3138

32-
[build-system]
33-
requires = ["poetry>=0.12"]
34-
build-backend = "poetry.masonry.api"
39+
[tool.setuptools]
40+
packages = ["sql_athame"]
3541

3642
[tool.ruff]
3743
target-version = "py39"
@@ -62,7 +68,6 @@ ignore = [
6268
"E501", # line too long
6369
"E721", # type checks, currently broken
6470
"ISC001", # conflicts with ruff format
65-
"PT004", # Fixture `...` does not return anything, add leading underscore
6671
"RET505", # Unnecessary `else` after `return` statement
6772
"RET506", # Unnecessary `else` after `raise` statement
6873
]

run

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,26 +4,26 @@ usage=()
44

55
usage+=(" $0 tests - run tests")
66
tests() {
7-
poetry run pytest "$@"
7+
uv run pytest "$@"
88
lint
99
}
1010

1111
usage+=(" $0 refmt - reformat code")
1212
refmt() {
13-
poetry run ruff check --select I --fix
14-
poetry run ruff format
13+
uv run ruff check --select I --fix
14+
uv run ruff format
1515
}
1616

1717
usage+=(" $0 lint - run linting")
1818
lint() {
19-
poetry run ruff check
20-
poetry run ruff format --diff
21-
poetry run mypy sql_athame/**.py tests/**.py
19+
uv run ruff check
20+
uv run ruff format --diff
21+
uv run mypy sql_athame/**.py tests/**.py
2222
}
2323

2424
usage+=(" $0 bump2version {major|minor|patch} - bump version number")
2525
bump2version() {
26-
poetry run bump2version "$@"
26+
uv run bump2version "$@"
2727
}
2828

2929
cmd=$1

sql_athame/dataclasses.py

Lines changed: 85 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ class ModelBase:
157157
_cache: dict[tuple, Any]
158158
table_name: str
159159
primary_key_names: tuple[str, ...]
160-
array_safe_insert: bool
160+
insert_multiple_mode: str
161161

162162
def __init_subclass__(
163163
cls,
@@ -169,12 +169,9 @@ def __init_subclass__(
169169
):
170170
cls._cache = {}
171171
cls.table_name = table_name
172-
if insert_multiple_mode == "array_safe":
173-
cls.array_safe_insert = True
174-
elif insert_multiple_mode == "unnest":
175-
cls.array_safe_insert = False
176-
else:
172+
if insert_multiple_mode not in ("array_safe", "unnest", "executemany"):
177173
raise ValueError("Unknown `insert_multiple_mode`")
174+
cls.insert_multiple_mode = insert_multiple_mode
178175
if isinstance(primary_key, str):
179176
cls.primary_key_names = (primary_key,)
180177
else:
@@ -357,19 +354,37 @@ def select_sql(
357354
return query
358355

359356
@classmethod
360-
async def select_cursor(
357+
async def cursor_from(
358+
cls: type[T],
359+
connection: Connection,
360+
query: Fragment,
361+
prefetch: int = 1000,
362+
) -> AsyncGenerator[T, None]:
363+
async for row in connection.cursor(*query, prefetch=prefetch):
364+
yield cls.from_mapping(row)
365+
366+
@classmethod
367+
def select_cursor(
361368
cls: type[T],
362369
connection: Connection,
363370
order_by: Union[FieldNames, str] = (),
364371
for_update: bool = False,
365372
where: Where = (),
366373
prefetch: int = 1000,
367374
) -> AsyncGenerator[T, None]:
368-
async for row in connection.cursor(
369-
*cls.select_sql(order_by=order_by, for_update=for_update, where=where),
375+
return cls.cursor_from(
376+
connection,
377+
cls.select_sql(order_by=order_by, for_update=for_update, where=where),
370378
prefetch=prefetch,
371-
):
372-
yield cls.from_mapping(row)
379+
)
380+
381+
@classmethod
382+
async def fetch_from(
383+
cls: type[T],
384+
connection_or_pool: Union[Connection, Pool],
385+
query: Fragment,
386+
) -> list[T]:
387+
return [cls.from_mapping(row) for row in await connection_or_pool.fetch(*query)]
373388

374389
@classmethod
375390
async def select(
@@ -379,12 +394,10 @@ async def select(
379394
for_update: bool = False,
380395
where: Where = (),
381396
) -> list[T]:
382-
return [
383-
cls.from_mapping(row)
384-
for row in await connection_or_pool.fetch(
385-
*cls.select_sql(order_by=order_by, for_update=for_update, where=where)
386-
)
387-
]
397+
return await cls.fetch_from(
398+
connection_or_pool,
399+
cls.select_sql(order_by=order_by, for_update=for_update, where=where),
400+
)
388401

389402
@classmethod
390403
def create_sql(cls: type[T], **kwargs: Any) -> Fragment:
@@ -506,6 +519,37 @@ def insert_multiple_array_safe_sql(cls: type[T], rows: Iterable[T]) -> Fragment:
506519
),
507520
)
508521

522+
@classmethod
523+
def insert_multiple_executemany_chunk_sql(
524+
cls: type[T], chunk_size: int
525+
) -> Fragment:
526+
def generate() -> Fragment:
527+
columns = len(cls.column_info())
528+
values = ", ".join(
529+
f"({', '.join(f'${i}' for i in chunk)})"
530+
for chunk in chunked(range(1, columns * chunk_size + 1), columns)
531+
)
532+
return sql(
533+
"INSERT INTO {table} ({fields}) VALUES {values}",
534+
table=cls.table_name_sql(),
535+
fields=sql.list(cls.field_names_sql()),
536+
values=sql.literal(values),
537+
).flatten()
538+
539+
return cls._cached(
540+
("insert_multiple_executemany_chunk", chunk_size),
541+
generate,
542+
)
543+
544+
@classmethod
545+
async def insert_multiple_executemany(
546+
cls: type[T], connection_or_pool: Union[Connection, Pool], rows: Iterable[T]
547+
) -> None:
548+
args = [r.field_values() for r in rows]
549+
query = cls.insert_multiple_executemany_chunk_sql(1).query()[0]
550+
if args:
551+
await connection_or_pool.executemany(query, args)
552+
509553
@classmethod
510554
async def insert_multiple_unnest(
511555
cls: type[T], connection_or_pool: Union[Connection, Pool], rows: Iterable[T]
@@ -527,11 +571,28 @@ async def insert_multiple_array_safe(
527571
async def insert_multiple(
528572
cls: type[T], connection_or_pool: Union[Connection, Pool], rows: Iterable[T]
529573
) -> str:
530-
if cls.array_safe_insert:
574+
if cls.insert_multiple_mode == "executemany":
575+
await cls.insert_multiple_executemany(connection_or_pool, rows)
576+
return "INSERT"
577+
elif cls.insert_multiple_mode == "array_safe":
531578
return await cls.insert_multiple_array_safe(connection_or_pool, rows)
532579
else:
533580
return await cls.insert_multiple_unnest(connection_or_pool, rows)
534581

582+
@classmethod
583+
async def upsert_multiple_executemany(
584+
cls: type[T],
585+
connection_or_pool: Union[Connection, Pool],
586+
rows: Iterable[T],
587+
insert_only: FieldNamesSet = (),
588+
) -> None:
589+
args = [r.field_values() for r in rows]
590+
query = cls.upsert_sql(
591+
cls.insert_multiple_executemany_chunk_sql(1), exclude=insert_only
592+
).query()[0]
593+
if args:
594+
await connection_or_pool.executemany(query, args)
595+
535596
@classmethod
536597
async def upsert_multiple_unnest(
537598
cls: type[T],
@@ -566,7 +627,12 @@ async def upsert_multiple(
566627
rows: Iterable[T],
567628
insert_only: FieldNamesSet = (),
568629
) -> str:
569-
if cls.array_safe_insert:
630+
if cls.insert_multiple_mode == "executemany":
631+
await cls.upsert_multiple_executemany(
632+
connection_or_pool, rows, insert_only=insert_only
633+
)
634+
return "INSERT"
635+
elif cls.insert_multiple_mode == "array_safe":
570636
return await cls.upsert_multiple_array_safe(
571637
connection_or_pool, rows, insert_only=insert_only
572638
)

tests/test_asyncpg.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -133,13 +133,14 @@ class Test(ModelBase, table_name="test", primary_key="id"):
133133
assert db.updated == new.updated
134134

135135

136-
async def test_replace_multiple_arrays(conn):
136+
@pytest.mark.parametrize("insert_multiple_mode", ["array_safe", "executemany"])
137+
async def test_replace_multiple_arrays(conn, insert_multiple_mode):
137138
@dataclass(order=True)
138139
class Test(
139140
ModelBase,
140141
table_name="test",
141142
primary_key="id",
142-
insert_multiple_mode="array_safe",
143+
insert_multiple_mode=insert_multiple_mode,
143144
):
144145
id: int
145146
a: Annotated[list[int], ColumnInfo(type="INT[]")]

tests/test_dataclasses.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,14 @@ class Test(ModelBase, table_name="table"):
5757
"hi",
5858
]
5959

60+
assert list(Test.insert_multiple_executemany_chunk_sql(1)) == [
61+
'INSERT INTO "table" ("foo", "bar") VALUES ($1, $2)'
62+
]
63+
64+
assert list(Test.insert_multiple_executemany_chunk_sql(3)) == [
65+
'INSERT INTO "table" ("foo", "bar") VALUES ($1, $2), ($3, $4), ($5, $6)'
66+
]
67+
6068
assert sql(
6169
"INSERT INTO table ({}) VALUES ({})",
6270
sql(",").join(t.field_names_sql()),

0 commit comments

Comments
 (0)