Skip to content

Commit 4697aa5

Browse files
committed
Add insert_multiple_mode="executemany"
1 parent d627642 commit 4697aa5

3 files changed

Lines changed: 69 additions & 10 deletions

File tree

sql_athame/dataclasses.py

Lines changed: 58 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ class ModelBase:
157157
_cache: dict[tuple, Any]
158158
table_name: str
159159
primary_key_names: tuple[str, ...]
160-
array_safe_insert: bool
160+
insert_multiple_mode: str
161161

162162
def __init_subclass__(
163163
cls,
@@ -169,12 +169,9 @@ def __init_subclass__(
169169
):
170170
cls._cache = {}
171171
cls.table_name = table_name
172-
if insert_multiple_mode == "array_safe":
173-
cls.array_safe_insert = True
174-
elif insert_multiple_mode == "unnest":
175-
cls.array_safe_insert = False
176-
else:
172+
if insert_multiple_mode not in ("array_safe", "unnest", "executemany"):
177173
raise ValueError("Unknown `insert_multiple_mode`")
174+
cls.insert_multiple_mode = insert_multiple_mode
178175
if isinstance(primary_key, str):
179176
cls.primary_key_names = (primary_key,)
180177
else:
@@ -506,6 +503,37 @@ def insert_multiple_array_safe_sql(cls: type[T], rows: Iterable[T]) -> Fragment:
506503
),
507504
)
508505

506+
@classmethod
507+
def insert_multiple_executemany_chunk_sql(
508+
cls: type[T], chunk_size: int
509+
) -> Fragment:
510+
def generate() -> Fragment:
511+
columns = len(cls.column_info())
512+
values = ", ".join(
513+
f"({', '.join(f'${i}' for i in chunk)})"
514+
for chunk in chunked(range(1, columns * chunk_size + 1), columns)
515+
)
516+
return sql(
517+
"INSERT INTO {table} ({fields}) VALUES {values}",
518+
table=cls.table_name_sql(),
519+
fields=sql.list(cls.field_names_sql()),
520+
values=sql.literal(values),
521+
).flatten()
522+
523+
return cls._cached(
524+
("insert_multiple_executemany_chunk", chunk_size),
525+
generate,
526+
)
527+
528+
@classmethod
529+
async def insert_multiple_executemany(
530+
cls: type[T], connection_or_pool: Union[Connection, Pool], rows: Iterable[T]
531+
) -> None:
532+
args = [r.field_values() for r in rows]
533+
query = cls.insert_multiple_executemany_chunk_sql(1).query()[0]
534+
if args:
535+
await connection_or_pool.executemany(query, args)
536+
509537
@classmethod
510538
async def insert_multiple_unnest(
511539
cls: type[T], connection_or_pool: Union[Connection, Pool], rows: Iterable[T]
@@ -527,11 +555,28 @@ async def insert_multiple_array_safe(
527555
async def insert_multiple(
528556
cls: type[T], connection_or_pool: Union[Connection, Pool], rows: Iterable[T]
529557
) -> str:
530-
if cls.array_safe_insert:
558+
if cls.insert_multiple_mode == "executemany":
559+
await cls.insert_multiple_executemany(connection_or_pool, rows)
560+
return "INSERT"
561+
elif cls.insert_multiple_mode == "array_safe":
531562
return await cls.insert_multiple_array_safe(connection_or_pool, rows)
532563
else:
533564
return await cls.insert_multiple_unnest(connection_or_pool, rows)
534565

566+
@classmethod
567+
async def upsert_multiple_executemany(
568+
cls: type[T],
569+
connection_or_pool: Union[Connection, Pool],
570+
rows: Iterable[T],
571+
insert_only: FieldNamesSet = (),
572+
) -> None:
573+
args = [r.field_values() for r in rows]
574+
query = cls.upsert_sql(
575+
cls.insert_multiple_executemany_chunk_sql(1), exclude=insert_only
576+
).query()[0]
577+
if args:
578+
await connection_or_pool.executemany(query, args)
579+
535580
@classmethod
536581
async def upsert_multiple_unnest(
537582
cls: type[T],
@@ -566,7 +611,12 @@ async def upsert_multiple(
566611
rows: Iterable[T],
567612
insert_only: FieldNamesSet = (),
568613
) -> str:
569-
if cls.array_safe_insert:
614+
if cls.insert_multiple_mode == "executemany":
615+
await cls.upsert_multiple_executemany(
616+
connection_or_pool, rows, insert_only=insert_only
617+
)
618+
return "INSERT"
619+
elif cls.insert_multiple_mode == "array_safe":
570620
return await cls.upsert_multiple_array_safe(
571621
connection_or_pool, rows, insert_only=insert_only
572622
)

tests/test_asyncpg.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -133,13 +133,14 @@ class Test(ModelBase, table_name="test", primary_key="id"):
133133
assert db.updated == new.updated
134134

135135

136-
async def test_replace_multiple_arrays(conn):
136+
@pytest.mark.parametrize("insert_multiple_mode", ["array_safe", "executemany"])
137+
async def test_replace_multiple_arrays(conn, insert_multiple_mode):
137138
@dataclass(order=True)
138139
class Test(
139140
ModelBase,
140141
table_name="test",
141142
primary_key="id",
142-
insert_multiple_mode="array_safe",
143+
insert_multiple_mode=insert_multiple_mode,
143144
):
144145
id: int
145146
a: Annotated[list[int], ColumnInfo(type="INT[]")]

tests/test_dataclasses.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,14 @@ class Test(ModelBase, table_name="table"):
5757
"hi",
5858
]
5959

60+
assert list(Test.insert_multiple_executemany_chunk_sql(1)) == [
61+
'INSERT INTO "table" ("foo", "bar") VALUES ($1, $2)'
62+
]
63+
64+
assert list(Test.insert_multiple_executemany_chunk_sql(3)) == [
65+
'INSERT INTO "table" ("foo", "bar") VALUES ($1, $2), ($3, $4), ($5, $6)'
66+
]
67+
6068
assert sql(
6169
"INSERT INTO table ({}) VALUES ({})",
6270
sql(",").join(t.field_names_sql()),

0 commit comments

Comments
 (0)