Skip to content

Commit 36503ca

Browse files
committed
Make upserts honor insert_only ColumnInfo, add force_update to override
1 parent 52b3b47 commit 36503ca

3 files changed

Lines changed: 226 additions & 40 deletions

File tree

sql_athame/dataclasses.py

Lines changed: 101 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -844,12 +844,18 @@ async def insert(
844844
return await connection_or_pool.execute(*self.insert_sql(exclude))
845845

846846
@classmethod
847-
def upsert_sql(cls, insert_sql: Fragment, exclude: FieldNamesSet = ()) -> Fragment:
847+
def upsert_sql(
848+
cls,
849+
insert_sql: Fragment,
850+
exclude: FieldNamesSet = (),
851+
force_update: FieldNamesSet = (),
852+
) -> Fragment:
848853
"""Generate UPSERT (INSERT ... ON CONFLICT DO UPDATE) SQL.
849854
850855
Args:
851856
insert_sql: Base INSERT statement Fragment
852857
exclude: Field names to exclude from the UPDATE clause
858+
force_update: Field names to force include in UPDATE clause, overriding insert_only settings
853859
854860
Returns:
855861
Fragment containing INSERT ... ON CONFLICT DO UPDATE statement
@@ -858,17 +864,27 @@ def upsert_sql(cls, insert_sql: Fragment, exclude: FieldNamesSet = ()) -> Fragme
858864
>>> insert = user.insert_sql()
859865
>>> list(User.upsert_sql(insert))
860866
['INSERT INTO "users" ("name", "email") VALUES ($1, $2) ON CONFLICT ("id") DO UPDATE SET "name"=EXCLUDED."name", "email"=EXCLUDED."email"', 'Alice', 'alice@example.com']
867+
868+
Note:
869+
Fields marked with ColumnInfo(insert_only=True) are automatically
870+
excluded from the UPDATE clause, unless overridden by force_update.
861871
"""
872+
# Combine exclude parameter with auto-detected insert_only fields, but remove force_update fields
873+
auto_insert_only = cls.insert_only_field_names() - set(force_update)
874+
manual_exclude = set(exclude) - set(
875+
force_update
876+
) # Remove force_update from manual excludes too
877+
all_exclude = manual_exclude | auto_insert_only
862878
cached = cls._cached(
863-
("upsert_sql", tuple(sorted(exclude))),
879+
("upsert_sql", tuple(sorted(all_exclude))),
864880
lambda: sql(
865881
" ON CONFLICT ({pks}) DO UPDATE SET {assignments}",
866882
insert_sql=insert_sql,
867883
pks=sql.list(cls.primary_key_names_sql()),
868884
assignments=sql.list(
869885
sql("{field}=EXCLUDED.{field}", field=x)
870886
for x in cls.field_names_sql(
871-
exclude=(*cls.primary_key_names, *exclude)
887+
exclude=(*cls.primary_key_names, *all_exclude)
872888
)
873889
),
874890
).flatten(),
@@ -880,13 +896,15 @@ async def upsert(
880896
connection_or_pool: Union[Connection, Pool],
881897
exclude: FieldNamesSet = (),
882898
insert_only: FieldNamesSet = (),
899+
force_update: FieldNamesSet = (),
883900
) -> bool:
884901
"""Insert or update this instance in the database.
885902
886903
Args:
887904
connection_or_pool: Database connection or pool
888905
exclude: Field names to exclude from the UPDATE clause
889906
insert_only: Field names that should only be set on INSERT, not UPDATE
907+
force_update: Field names to force include in UPDATE clause, overriding insert_only settings
890908
891909
Returns:
892910
True if the record was updated, False if it was inserted
@@ -895,18 +913,24 @@ async def upsert(
895913
>>> user = User(id=1, name="Alice", created_at=datetime.now())
896914
>>> # Only set created_at on INSERT, not UPDATE
897915
>>> was_updated = await user.upsert(pool, insert_only={'created_at'})
916+
>>> # Force update created_at even if it's marked insert_only in ColumnInfo
917+
>>> was_updated = await user.upsert(pool, force_update={'created_at'})
898918
899919
Note:
900920
Fields marked with ColumnInfo(insert_only=True) are automatically
901-
treated as insert-only and combined with the insert_only parameter.
921+
treated as insert-only and combined with the insert_only parameter,
922+
unless overridden by force_update.
902923
"""
903-
# Combine auto-detected insert_only fields with manual ones
904-
all_insert_only = self.insert_only_field_names() | set(insert_only)
905-
# Combine exclude and insert_only for the UPDATE clause
906-
update_exclude = set(exclude) | all_insert_only
924+
# upsert_sql automatically handles insert_only fields from ColumnInfo
925+
# We only need to combine manual insert_only with exclude for the UPDATE clause
926+
update_exclude = set(exclude) | set(insert_only)
907927
query = sql(
908928
"{} RETURNING xmax",
909-
self.upsert_sql(self.insert_sql(exclude=exclude), exclude=update_exclude),
929+
self.upsert_sql(
930+
self.insert_sql(exclude=exclude),
931+
exclude=update_exclude,
932+
force_update=force_update,
933+
),
910934
)
911935
result = await connection_or_pool.fetchrow(*query)
912936
return result["xmax"] != 0
@@ -1136,17 +1160,21 @@ async def upsert_multiple_executemany(
11361160
connection_or_pool: Union[Connection, Pool],
11371161
rows: Iterable[T],
11381162
insert_only: FieldNamesSet = (),
1163+
force_update: FieldNamesSet = (),
11391164
) -> None:
11401165
"""Bulk upsert using asyncpg's executemany.
11411166
11421167
Args:
11431168
connection_or_pool: Database connection or pool
11441169
rows: Model instances to upsert
11451170
insert_only: Field names that should only be set on INSERT, not UPDATE
1171+
force_update: Field names to force include in UPDATE clause, overriding insert_only settings
11461172
"""
11471173
args = [r.field_values() for r in rows]
11481174
query = cls.upsert_sql(
1149-
cls.insert_multiple_executemany_chunk_sql(1), exclude=insert_only
1175+
cls.insert_multiple_executemany_chunk_sql(1),
1176+
exclude=insert_only,
1177+
force_update=force_update,
11501178
).query()[0]
11511179
if args:
11521180
await connection_or_pool.executemany(query, args)
@@ -1157,19 +1185,25 @@ async def upsert_multiple_unnest(
11571185
connection_or_pool: Union[Connection, Pool],
11581186
rows: Iterable[T],
11591187
insert_only: FieldNamesSet = (),
1188+
force_update: FieldNamesSet = (),
11601189
) -> str:
11611190
"""Bulk upsert using PostgreSQL UNNEST.
11621191
11631192
Args:
11641193
connection_or_pool: Database connection or pool
11651194
rows: Model instances to upsert
11661195
insert_only: Field names that should only be set on INSERT, not UPDATE
1196+
force_update: Field names to force include in UPDATE clause, overriding insert_only settings
11671197
11681198
Returns:
11691199
Result string from the database operation
11701200
"""
11711201
return await connection_or_pool.execute(
1172-
*cls.upsert_sql(cls.insert_multiple_sql(rows), exclude=insert_only)
1202+
*cls.upsert_sql(
1203+
cls.insert_multiple_sql(rows),
1204+
exclude=insert_only,
1205+
force_update=force_update,
1206+
)
11731207
)
11741208

11751209
@classmethod
@@ -1178,6 +1212,7 @@ async def upsert_multiple_array_safe(
11781212
connection_or_pool: Union[Connection, Pool],
11791213
rows: Iterable[T],
11801214
insert_only: FieldNamesSet = (),
1215+
force_update: FieldNamesSet = (),
11811216
) -> str:
11821217
"""Bulk upsert using VALUES syntax with chunking.
11831218
@@ -1188,6 +1223,7 @@ async def upsert_multiple_array_safe(
11881223
connection_or_pool: Database connection or pool
11891224
rows: Model instances to upsert
11901225
insert_only: Field names that should only be set on INSERT, not UPDATE
1226+
force_update: Field names to force include in UPDATE clause, overriding insert_only settings
11911227
11921228
Returns:
11931229
Result string from the last chunk operation
@@ -1196,7 +1232,9 @@ async def upsert_multiple_array_safe(
11961232
for chunk in chunked(rows, 100):
11971233
last = await connection_or_pool.execute(
11981234
*cls.upsert_sql(
1199-
cls.insert_multiple_array_safe_sql(chunk), exclude=insert_only
1235+
cls.insert_multiple_array_safe_sql(chunk),
1236+
exclude=insert_only,
1237+
force_update=force_update,
12001238
)
12011239
)
12021240
return last
@@ -1207,39 +1245,52 @@ async def upsert_multiple(
12071245
connection_or_pool: Union[Connection, Pool],
12081246
rows: Iterable[T],
12091247
insert_only: FieldNamesSet = (),
1248+
force_update: FieldNamesSet = (),
12101249
) -> str:
12111250
"""Bulk upsert (INSERT ... ON CONFLICT DO UPDATE) multiple records.
12121251
12131252
Args:
12141253
connection_or_pool: Database connection or pool
12151254
rows: Model instances to upsert
12161255
insert_only: Field names that should only be set on INSERT, not UPDATE
1256+
force_update: Field names to force include in UPDATE clause, overriding insert_only settings
12171257
12181258
Returns:
12191259
Result string from the database operation
12201260
12211261
Example:
12221262
>>> await User.upsert_multiple(pool, users, insert_only={'created_at'})
1263+
>>> await User.upsert_multiple(pool, users, force_update={'created_at'})
12231264
12241265
Note:
12251266
Fields marked with ColumnInfo(insert_only=True) are automatically
1226-
treated as insert-only and combined with the insert_only parameter.
1267+
treated as insert-only and combined with the insert_only parameter,
1268+
unless overridden by force_update.
12271269
"""
1228-
# Combine auto-detected insert_only fields with manual ones
1229-
all_insert_only = cls.insert_only_field_names() | set(insert_only)
1270+
# upsert_sql automatically handles insert_only fields from ColumnInfo
1271+
# Pass manual insert_only parameter through to the specific implementations
12301272

12311273
if cls.insert_multiple_mode == "executemany":
12321274
await cls.upsert_multiple_executemany(
1233-
connection_or_pool, rows, insert_only=all_insert_only
1275+
connection_or_pool,
1276+
rows,
1277+
insert_only=insert_only,
1278+
force_update=force_update,
12341279
)
12351280
return "INSERT"
12361281
elif cls.insert_multiple_mode == "array_safe":
12371282
return await cls.upsert_multiple_array_safe(
1238-
connection_or_pool, rows, insert_only=all_insert_only
1283+
connection_or_pool,
1284+
rows,
1285+
insert_only=insert_only,
1286+
force_update=force_update,
12391287
)
12401288
else:
12411289
return await cls.upsert_multiple_unnest(
1242-
connection_or_pool, rows, insert_only=all_insert_only
1290+
connection_or_pool,
1291+
rows,
1292+
insert_only=insert_only,
1293+
force_update=force_update,
12431294
)
12441295

12451296
@classmethod
@@ -1272,6 +1323,7 @@ async def plan_replace_multiple(
12721323
where: Where,
12731324
ignore: FieldNamesSet = (),
12741325
insert_only: FieldNamesSet = (),
1326+
force_update: FieldNamesSet = (),
12751327
) -> "ReplaceMultiplePlan[T]":
12761328
"""Plan a replace operation by comparing new data with existing records.
12771329
@@ -1284,6 +1336,7 @@ async def plan_replace_multiple(
12841336
where: WHERE clause to limit which existing records to consider
12851337
ignore: Field names to ignore when comparing records
12861338
insert_only: Field names that should only be set on INSERT, not UPDATE
1339+
force_update: Field names to force include in UPDATE clause, overriding insert_only settings
12871340
12881341
Returns:
12891342
ReplaceMultiplePlan containing the planned operations
@@ -1296,9 +1349,10 @@ async def plan_replace_multiple(
12961349
12971350
Note:
12981351
Fields marked with ColumnInfo(insert_only=True) are automatically
1299-
treated as insert-only and combined with the insert_only parameter.
1352+
treated as insert-only and combined with the insert_only parameter,
1353+
unless overridden by force_update.
13001354
"""
1301-
# Combine auto-detected insert_only fields with manual ones
1355+
# For comparison purposes, combine auto-detected insert_only fields with manual ones
13021356
all_insert_only = cls.insert_only_field_names() | set(insert_only)
13031357
ignore = sorted(set(ignore) | all_insert_only)
13041358
equal_ignoring = cls._cached(
@@ -1323,7 +1377,11 @@ async def plan_replace_multiple(
13231377

13241378
created = list(pending.values())
13251379

1326-
return ReplaceMultiplePlan(cls, all_insert_only, created, updated, deleted)
1380+
# Pass only manual insert_only and force_update to the plan
1381+
# since upsert_multiple handles auto-detected ones
1382+
return ReplaceMultiplePlan(
1383+
cls, insert_only, force_update, created, updated, deleted
1384+
)
13271385

13281386
@classmethod
13291387
async def replace_multiple(
@@ -1334,6 +1392,7 @@ async def replace_multiple(
13341392
where: Where,
13351393
ignore: FieldNamesSet = (),
13361394
insert_only: FieldNamesSet = (),
1395+
force_update: FieldNamesSet = (),
13371396
) -> tuple[list[T], list[T], list[T]]:
13381397
"""Replace records in the database with the provided data.
13391398
@@ -1347,6 +1406,7 @@ async def replace_multiple(
13471406
where: WHERE clause to limit which existing records to consider for replacement
13481407
ignore: Field names to ignore when comparing records
13491408
insert_only: Field names that should only be set on INSERT, not UPDATE
1409+
force_update: Field names to force include in UPDATE clause, overriding insert_only settings
13501410
13511411
Returns:
13521412
Tuple of (created_records, updated_records, deleted_records)
@@ -1358,10 +1418,16 @@ async def replace_multiple(
13581418
13591419
Note:
13601420
Fields marked with ColumnInfo(insert_only=True) are automatically
1361-
treated as insert-only and combined with the insert_only parameter.
1421+
treated as insert-only and combined with the insert_only parameter,
1422+
unless overridden by force_update.
13621423
"""
13631424
plan = await cls.plan_replace_multiple(
1364-
connection, rows, where=where, ignore=ignore, insert_only=insert_only
1425+
connection,
1426+
rows,
1427+
where=where,
1428+
ignore=ignore,
1429+
insert_only=insert_only,
1430+
force_update=force_update,
13651431
)
13661432
await plan.execute(connection)
13671433
return plan.cud
@@ -1401,6 +1467,7 @@ async def replace_multiple_reporting_differences(
14011467
where: Where,
14021468
ignore: FieldNamesSet = (),
14031469
insert_only: FieldNamesSet = (),
1470+
force_update: FieldNamesSet = (),
14041471
) -> tuple[list[T], list[tuple[T, T, list[str]]], list[T]]:
14051472
"""Replace records and report the specific field differences for updates.
14061473
@@ -1413,6 +1480,7 @@ async def replace_multiple_reporting_differences(
14131480
where: WHERE clause to limit which existing records to consider
14141481
ignore: Field names to ignore when comparing records
14151482
insert_only: Field names that should only be set on INSERT, not UPDATE
1483+
force_update: Field names to force include in UPDATE clause, overriding insert_only settings
14161484
14171485
Returns:
14181486
Tuple of (created_records, update_triples, deleted_records)
@@ -1427,9 +1495,10 @@ async def replace_multiple_reporting_differences(
14271495
14281496
Note:
14291497
Fields marked with ColumnInfo(insert_only=True) are automatically
1430-
treated as insert-only and combined with the insert_only parameter.
1498+
treated as insert-only and combined with the insert_only parameter,
1499+
unless overridden by force_update.
14311500
"""
1432-
# Combine auto-detected insert_only fields with manual ones
1501+
# For comparison purposes, combine auto-detected insert_only fields with manual ones
14331502
all_insert_only = cls.insert_only_field_names() | set(insert_only)
14341503
ignore = sorted(set(ignore) | all_insert_only)
14351504
differences_ignoring = cls._cached(
@@ -1460,7 +1529,8 @@ async def replace_multiple_reporting_differences(
14601529
await cls.upsert_multiple(
14611530
connection,
14621531
(*created, *(t[1] for t in updated_triples)),
1463-
insert_only=all_insert_only,
1532+
insert_only=insert_only,
1533+
force_update=force_update,
14641534
)
14651535
if deleted:
14661536
await cls.delete_multiple(connection, deleted)
@@ -1472,6 +1542,7 @@ async def replace_multiple_reporting_differences(
14721542
class ReplaceMultiplePlan(Generic[T]):
14731543
model_class: type[T]
14741544
insert_only: FieldNamesSet
1545+
force_update: FieldNamesSet
14751546
created: list[T]
14761547
updated: list[T]
14771548
deleted: list[T]
@@ -1493,7 +1564,10 @@ async def execute_upserts(self, connection: Connection) -> None:
14931564
"""
14941565
if self.created or self.updated:
14951566
await self.model_class.upsert_multiple(
1496-
connection, (*self.created, *self.updated), insert_only=self.insert_only
1567+
connection,
1568+
(*self.created, *self.updated),
1569+
insert_only=self.insert_only,
1570+
force_update=self.force_update,
14971571
)
14981572

14991573
async def execute_deletes(self, connection: Connection) -> None:

0 commit comments

Comments
 (0)