Skip to content

Commit ef5ef1c

Browse files
committed
Add insert_only to upsert
1 parent e2cf4a1 commit ef5ef1c

3 files changed

Lines changed: 85 additions & 2 deletions

File tree

sql_athame/dataclasses.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -855,20 +855,31 @@ def upsert_sql(cls, insert_sql: Fragment, exclude: FieldNamesSet = ()) -> Fragme
855855
return Fragment([insert_sql, cached])
856856

857857
async def upsert(
858-
self, connection_or_pool: Union[Connection, Pool], exclude: FieldNamesSet = ()
858+
self,
859+
connection_or_pool: Union[Connection, Pool],
860+
exclude: FieldNamesSet = (),
861+
insert_only: FieldNamesSet = (),
859862
) -> bool:
860863
"""Insert or update this instance in the database.
861864
862865
Args:
863866
connection_or_pool: Database connection or pool
864867
exclude: Field names to exclude from the UPDATE clause
868+
insert_only: Field names that should only be set on INSERT, not UPDATE
865869
866870
Returns:
867871
True if the record was updated, False if it was inserted
872+
873+
Example:
874+
>>> user = User(id=1, name="Alice", created_at=datetime.now())
875+
>>> # Only set created_at on INSERT, not UPDATE
876+
>>> was_updated = await user.upsert(pool, insert_only={'created_at'})
868877
"""
878+
# Combine exclude and insert_only for the UPDATE clause
879+
update_exclude = set(exclude) | set(insert_only)
869880
query = sql(
870881
"{} RETURNING xmax",
871-
self.upsert_sql(self.insert_sql(exclude=exclude), exclude=exclude),
882+
self.upsert_sql(self.insert_sql(exclude=exclude), exclude=update_exclude),
872883
)
873884
result = await connection_or_pool.fetchrow(*query)
874885
return result["xmax"] != 0

tests/test_asyncpg.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -321,3 +321,44 @@ class Test(ModelBase, table_name="table", primary_key="id"):
321321
await Test.insert_multiple(conn, [])
322322

323323
assert list(await Test.select(conn)) == []
324+
325+
326+
async def test_upsert_insert_only(conn):
327+
@dataclass
328+
class Test(ModelBase, table_name="test_upsert", primary_key="id"):
329+
id: int
330+
name: str
331+
count: int
332+
created_at: str
333+
334+
await conn.execute(*Test.create_table_sql())
335+
336+
# Initial insert
337+
record = Test(1, "Alice", 5, "2023-01-01")
338+
was_updated = await record.upsert(conn)
339+
assert not was_updated # Should be False for initial insert
340+
341+
# Verify record was inserted
342+
result = await Test.select(conn, where=sql("id = {}", 1))
343+
assert len(result) == 1
344+
assert result[0] == record
345+
346+
# Update without insert_only - should update all fields including created_at
347+
updated_record = Test(1, "Alice Updated", 10, "2023-01-02")
348+
was_updated = await updated_record.upsert(conn)
349+
assert was_updated # Should be True for update
350+
351+
result = await Test.select(conn, where=sql("id = {}", 1))
352+
assert len(result) == 1
353+
assert result[0] == updated_record
354+
355+
# Update with insert_only - should not update created_at
356+
final_record = Test(1, "Alice Final", 15, "2023-01-03")
357+
was_updated = await final_record.upsert(conn, insert_only={"created_at"})
358+
assert was_updated # Should be True for update
359+
360+
result = await Test.select(conn, where=sql("id = {}", 1))
361+
assert len(result) == 1
362+
# created_at should still be the old value, other fields should be updated
363+
expected = Test(1, "Alice Final", 15, "2023-01-02") # created_at unchanged
364+
assert result[0] == expected

tests/test_dataclasses.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,37 @@ class Test(ModelBase, table_name="table", primary_key="id"):
166166
]
167167

168168

169+
def test_upsert_insert_only():
170+
@dataclass
171+
class Test(ModelBase, table_name="table", primary_key="id"):
172+
id: int
173+
foo: int
174+
bar: str
175+
created_at: str
176+
177+
t = Test(1, 42, "str", "2023-01-01")
178+
179+
# Test with insert_only parameter - created_at should be excluded from UPDATE
180+
assert list(t.upsert_sql(t.insert_sql(), exclude={"created_at"})) == [
181+
'INSERT INTO "table" ("id", "foo", "bar", "created_at") VALUES ($1, $2, $3, $4) '
182+
'ON CONFLICT ("id") DO UPDATE SET "foo"=EXCLUDED."foo", "bar"=EXCLUDED."bar"',
183+
1,
184+
42,
185+
"str",
186+
"2023-01-01",
187+
]
188+
189+
# Test with both exclude and insert_only-style exclude
190+
assert list(t.upsert_sql(t.insert_sql(), exclude={"bar", "created_at"})) == [
191+
'INSERT INTO "table" ("id", "foo", "bar", "created_at") VALUES ($1, $2, $3, $4) '
192+
'ON CONFLICT ("id") DO UPDATE SET "foo"=EXCLUDED."foo"',
193+
1,
194+
42,
195+
"str",
196+
"2023-01-01",
197+
]
198+
199+
169200
def test_serial():
170201
@dataclass
171202
class Test(ModelBase, table_name="table", primary_key="id"):

0 commit comments

Comments
 (0)