Skip to content

Commit 975fff6

Browse files
committed
Fix upserts with only primary key
1 parent 69a948d commit 975fff6

2 files changed

Lines changed: 60 additions & 9 deletions

File tree

sql_athame/dataclasses.py

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -875,19 +875,31 @@ def upsert_sql(
875875
force_update
876876
) # Remove force_update from manual insert_only too
877877
all_insert_only = manual_insert_only | auto_insert_only
878-
cached = cls._cached(
879-
("upsert_sql", tuple(sorted(all_insert_only))),
880-
lambda: sql(
881-
" ON CONFLICT ({pks}) DO UPDATE SET {assignments}",
878+
879+
def generate_upsert_fragment():
880+
updatable_fields = cls.field_names(
881+
exclude=(*cls.primary_key_names, *all_insert_only)
882+
)
883+
return sql(
884+
" ON CONFLICT ({pks}) DO {action}",
882885
insert_sql=insert_sql,
883886
pks=sql.list(cls.primary_key_names_sql()),
884-
assignments=sql.list(
885-
sql("{field}=EXCLUDED.{field}", field=x)
886-
for x in cls.field_names_sql(
887-
exclude=(*cls.primary_key_names, *all_insert_only)
887+
action=(
888+
sql(
889+
"UPDATE SET {assignments}",
890+
assignments=sql.list(
891+
sql("{field}=EXCLUDED.{field}", field=sql.identifier(field))
892+
for field in updatable_fields
893+
),
888894
)
895+
if updatable_fields
896+
else sql.literal("NOTHING")
889897
),
890-
).flatten(),
898+
).flatten()
899+
900+
cached = cls._cached(
901+
("upsert_sql", tuple(sorted(all_insert_only))),
902+
generate_upsert_fragment,
891903
)
892904
return Fragment([insert_sql, cached])
893905

tests/test_dataclasses.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -440,3 +440,42 @@ class Test(ModelBase, table_name="table", primary_key="id"):
440440
assert '"name"=EXCLUDED."name"' not in partial_query # Excluded manually
441441
assert '"updated_at"=EXCLUDED."updated_at"' in partial_query
442442
assert '"created_at"=EXCLUDED."created_at"' in partial_query # Force updated
443+
444+
445+
def test_primary_key_only_table_upsert():
446+
"""Test that upsert works correctly for tables with only primary key columns."""
447+
448+
@dataclass
449+
class PrimaryKeyOnly(ModelBase, table_name="pk_only", primary_key="id"):
450+
id: uuid.UUID
451+
452+
test_instance = PrimaryKeyOnly(uuid.uuid4())
453+
insert_sql = test_instance.insert_sql()
454+
455+
# Test that upsert generates valid SQL with DO NOTHING
456+
upsert_sql = PrimaryKeyOnly.upsert_sql(insert_sql)
457+
query, params = upsert_sql.query()
458+
459+
# Should contain ON CONFLICT DO NOTHING since there are no updatable fields
460+
assert "ON CONFLICT" in query
461+
assert "DO NOTHING" in query
462+
assert "DO UPDATE SET" not in query
463+
assert len(params) == 1 # Only the ID parameter
464+
465+
# Test with compound primary key (still no other fields)
466+
@dataclass
467+
class CompoundPrimaryKeyOnly(
468+
ModelBase, table_name="compound_pk_only", primary_key=("id1", "id2")
469+
):
470+
id1: int
471+
id2: str
472+
473+
compound_instance = CompoundPrimaryKeyOnly(1, "test")
474+
compound_insert = compound_instance.insert_sql()
475+
compound_upsert = CompoundPrimaryKeyOnly.upsert_sql(compound_insert)
476+
compound_query, compound_params = compound_upsert.query()
477+
478+
assert "ON CONFLICT" in compound_query
479+
assert "DO NOTHING" in compound_query
480+
assert "DO UPDATE SET" not in compound_query
481+
assert len(compound_params) == 2 # Both ID parameters

0 commit comments

Comments
 (0)