Skip to content

Commit a2918b9

Browse files
committed
Add insert_only fields to upsert/replace_multiple
1 parent 7115e29 commit a2918b9

2 files changed

Lines changed: 67 additions & 11 deletions

File tree

sql_athame/dataclasses.py

Lines changed: 34 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -458,31 +458,46 @@ async def insert_multiple(
458458

459459
@classmethod
460460
async def upsert_multiple_unnest(
461-
cls: Type[T], connection_or_pool: Union[Connection, Pool], rows: Iterable[T]
461+
cls: Type[T],
462+
connection_or_pool: Union[Connection, Pool],
463+
rows: Iterable[T],
464+
insert_only: FieldNamesSet = (),
462465
) -> str:
463466
return await connection_or_pool.execute(
464-
*cls.upsert_sql(cls.insert_multiple_sql(rows))
467+
*cls.upsert_sql(cls.insert_multiple_sql(rows), exclude=insert_only)
465468
)
466469

467470
@classmethod
468471
async def upsert_multiple_array_safe(
469-
cls: Type[T], connection_or_pool: Union[Connection, Pool], rows: Iterable[T]
472+
cls: Type[T],
473+
connection_or_pool: Union[Connection, Pool],
474+
rows: Iterable[T],
475+
insert_only: FieldNamesSet = (),
470476
) -> str:
471477
last = ""
472478
for chunk in chunked(rows, 100):
473479
last = await connection_or_pool.execute(
474-
*cls.upsert_sql(cls.insert_multiple_array_safe_sql(chunk))
480+
*cls.upsert_sql(
481+
cls.insert_multiple_array_safe_sql(chunk), exclude=insert_only
482+
)
475483
)
476484
return last
477485

478486
@classmethod
479487
async def upsert_multiple(
480-
cls: Type[T], connection_or_pool: Union[Connection, Pool], rows: Iterable[T]
488+
cls: Type[T],
489+
connection_or_pool: Union[Connection, Pool],
490+
rows: Iterable[T],
491+
insert_only: FieldNamesSet = (),
481492
) -> str:
482493
if cls.array_safe_insert:
483-
return await cls.upsert_multiple_array_safe(connection_or_pool, rows)
494+
return await cls.upsert_multiple_array_safe(
495+
connection_or_pool, rows, insert_only=insert_only
496+
)
484497
else:
485-
return await cls.upsert_multiple_unnest(connection_or_pool, rows)
498+
return await cls.upsert_multiple_unnest(
499+
connection_or_pool, rows, insert_only=insert_only
500+
)
486501

487502
@classmethod
488503
def _get_equal_ignoring_fn(
@@ -505,9 +520,11 @@ async def replace_multiple(
505520
*,
506521
where: Where,
507522
ignore: FieldNamesSet = (),
523+
insert_only: FieldNamesSet = (),
508524
) -> Tuple[List[T], List[T], List[T]]:
525+
ignore = sorted(set(ignore) | set(insert_only))
509526
equal_ignoring = cls._cached(
510-
("equal_ignoring", tuple(sorted(ignore))),
527+
("equal_ignoring", tuple(ignore)),
511528
lambda: cls._get_equal_ignoring_fn(ignore),
512529
)
513530
pending = {row.primary_key(): row for row in map(cls.ensure_model, rows)}
@@ -529,7 +546,9 @@ async def replace_multiple(
529546
created = list(pending.values())
530547

531548
if created or updated:
532-
await cls.upsert_multiple(connection, (*created, *updated))
549+
await cls.upsert_multiple(
550+
connection, (*created, *updated), insert_only=insert_only
551+
)
533552
if deleted:
534553
await cls.delete_multiple(connection, deleted)
535554

@@ -561,9 +580,11 @@ async def replace_multiple_reporting_differences(
561580
*,
562581
where: Where,
563582
ignore: FieldNamesSet = (),
583+
insert_only: FieldNamesSet = (),
564584
) -> Tuple[List[T], List[Tuple[T, T, List[str]]], List[T]]:
585+
ignore = sorted(set(ignore) | set(insert_only))
565586
differences_ignoring = cls._cached(
566-
("differences_ignoring", tuple(sorted(ignore))),
587+
("differences_ignoring", tuple(ignore)),
567588
lambda: cls._get_differences_ignoring_fn(ignore),
568589
)
569590

@@ -588,7 +609,9 @@ async def replace_multiple_reporting_differences(
588609

589610
if created or updated_triples:
590611
await cls.upsert_multiple(
591-
connection, (*created, *(t[1] for t in updated_triples))
612+
connection,
613+
(*created, *(t[1] for t in updated_triples)),
614+
insert_only=insert_only,
592615
)
593616
if deleted:
594617
await cls.delete_multiple(connection, deleted)

tests/test_asyncpg.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1+
import asyncio
12
import json
23
import os
34
from dataclasses import dataclass, field
5+
from datetime import datetime
46
from typing import Optional
57

68
import asyncpg
@@ -93,6 +95,37 @@ class Test(ModelBase, table_name="test", primary_key="id"):
9395
]
9496

9597

98+
async def test_replace_multiple_ignore_insert_only(conn):
99+
@dataclass(order=True)
100+
class Test(ModelBase, table_name="test", primary_key="id"):
101+
id: int
102+
a: int
103+
created: datetime = field(default_factory=datetime.utcnow)
104+
updated: datetime = field(default_factory=datetime.utcnow)
105+
106+
await conn.execute(*Test.create_table_sql())
107+
108+
data = [Test(1, 1), Test(2, 1), Test(3, 2)]
109+
await Test.insert_multiple(conn, data)
110+
111+
await asyncio.sleep(0.1)
112+
new_data = [Test(1, 1), Test(2, 4), Test(3, 2)]
113+
c, u, d = await Test.replace_multiple(
114+
conn, new_data, where=[], ignore=["updated"], insert_only=["created"]
115+
)
116+
assert not c and not d
117+
assert len(u) == 1
118+
119+
db_data = await Test.select(conn, order_by="id")
120+
assert data[0] == db_data[0]
121+
assert data[2] == db_data[2]
122+
orig = data[1]
123+
new = new_data[1]
124+
db = db_data[1]
125+
assert db.created == orig.created
126+
assert db.updated == new.updated
127+
128+
96129
async def test_replace_multiple_arrays(conn):
97130
@dataclass(order=True)
98131
class Test(

0 commit comments

Comments
 (0)