Skip to content

Commit 52b3b47

Browse files
committed
Add insert_only to ColumnInfo
1 parent ef5ef1c commit 52b3b47

2 files changed

Lines changed: 175 additions & 8 deletions

File tree

sql_athame/dataclasses.py

Lines changed: 58 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ class ColumnInfo:
4545
constraints: Additional SQL constraints (e.g., 'UNIQUE', 'CHECK (value > 0)')
4646
serialize: Function to transform Python values before database storage
4747
deserialize: Function to transform database values back to Python objects
48+
insert_only: Whether this field should only be set on INSERT, not UPDATE in upsert operations
4849
4950
Example:
5051
>>> from dataclasses import dataclass
@@ -58,6 +59,7 @@ class ColumnInfo:
5859
... name: str
5960
... price: Annotated[float, ColumnInfo(constraints="CHECK (price > 0)")]
6061
... tags: Annotated[list, ColumnInfo(type="JSONB", serialize=json.dumps, deserialize=json.loads)]
62+
... created_at: Annotated[datetime, ColumnInfo(insert_only=True)]
6163
"""
6264

6365
type: Optional[str] = None
@@ -69,6 +71,7 @@ class ColumnInfo:
6971

7072
serialize: Optional[Callable[[Any], Any]] = None
7173
deserialize: Optional[Callable[[Any], Any]] = None
74+
insert_only: Optional[bool] = None
7275

7376
def __post_init__(self, constraints: Union[str, Iterable[str], None]) -> None:
7477
if constraints is not None:
@@ -94,6 +97,7 @@ def merge(a: "ColumnInfo", b: "ColumnInfo") -> "ColumnInfo":
9497
_constraints=(*a._constraints, *b._constraints),
9598
serialize=b.serialize if b.serialize is not None else a.serialize,
9699
deserialize=b.deserialize if b.deserialize is not None else a.deserialize,
100+
insert_only=b.insert_only if b.insert_only is not None else a.insert_only,
97101
)
98102

99103

@@ -113,6 +117,7 @@ class ConcreteColumnInfo:
113117
constraints: Tuple of SQL constraint strings
114118
serialize: Optional serialization function
115119
deserialize: Optional deserialization function
120+
insert_only: Whether this field should only be set on INSERT, not UPDATE
116121
"""
117122

118123
field: Field
@@ -123,6 +128,7 @@ class ConcreteColumnInfo:
123128
constraints: tuple[str, ...]
124129
serialize: Optional[Callable[[Any], Any]] = None
125130
deserialize: Optional[Callable[[Any], Any]] = None
131+
insert_only: bool = False
126132

127133
@staticmethod
128134
def from_column_info(
@@ -156,6 +162,7 @@ def from_column_info(
156162
constraints=info._constraints,
157163
serialize=info.serialize,
158164
deserialize=info.deserialize,
165+
insert_only=bool(info.insert_only),
159166
)
160167

161168
def create_table_string(self) -> str:
@@ -365,6 +372,20 @@ def field_names(cls, *, exclude: FieldNamesSet = ()) -> list[str]:
365372
if ci.field.name not in exclude
366373
]
367374

375+
@classmethod
376+
def insert_only_field_names(cls) -> set[str]:
377+
"""Get set of field names marked as insert_only in ColumnInfo.
378+
379+
Returns:
380+
Set of field names that should only be set on INSERT, not UPDATE
381+
"""
382+
return cls._cached(
383+
("insert_only_field_names",),
384+
lambda: {
385+
ci.field.name for ci in cls.column_info().values() if ci.insert_only
386+
},
387+
)
388+
368389
@classmethod
369390
def field_names_sql(
370391
cls,
@@ -874,9 +895,15 @@ async def upsert(
874895
>>> user = User(id=1, name="Alice", created_at=datetime.now())
875896
>>> # Only set created_at on INSERT, not UPDATE
876897
>>> was_updated = await user.upsert(pool, insert_only={'created_at'})
898+
899+
Note:
900+
Fields marked with ColumnInfo(insert_only=True) are automatically
901+
treated as insert-only and combined with the insert_only parameter.
877902
"""
903+
# Combine auto-detected insert_only fields with manual ones
904+
all_insert_only = self.insert_only_field_names() | set(insert_only)
878905
# Combine exclude and insert_only for the UPDATE clause
879-
update_exclude = set(exclude) | set(insert_only)
906+
update_exclude = set(exclude) | all_insert_only
880907
query = sql(
881908
"{} RETURNING xmax",
882909
self.upsert_sql(self.insert_sql(exclude=exclude), exclude=update_exclude),
@@ -1193,19 +1220,26 @@ async def upsert_multiple(
11931220
11941221
Example:
11951222
>>> await User.upsert_multiple(pool, users, insert_only={'created_at'})
1223+
1224+
Note:
1225+
Fields marked with ColumnInfo(insert_only=True) are automatically
1226+
treated as insert-only and combined with the insert_only parameter.
11961227
"""
1228+
# Combine auto-detected insert_only fields with manual ones
1229+
all_insert_only = cls.insert_only_field_names() | set(insert_only)
1230+
11971231
if cls.insert_multiple_mode == "executemany":
11981232
await cls.upsert_multiple_executemany(
1199-
connection_or_pool, rows, insert_only=insert_only
1233+
connection_or_pool, rows, insert_only=all_insert_only
12001234
)
12011235
return "INSERT"
12021236
elif cls.insert_multiple_mode == "array_safe":
12031237
return await cls.upsert_multiple_array_safe(
1204-
connection_or_pool, rows, insert_only=insert_only
1238+
connection_or_pool, rows, insert_only=all_insert_only
12051239
)
12061240
else:
12071241
return await cls.upsert_multiple_unnest(
1208-
connection_or_pool, rows, insert_only=insert_only
1242+
connection_or_pool, rows, insert_only=all_insert_only
12091243
)
12101244

12111245
@classmethod
@@ -1259,8 +1293,14 @@ async def plan_replace_multiple(
12591293
... conn, new_users, where=sql("department_id = {}", dept_id)
12601294
... )
12611295
>>> print(f"Will create {len(plan.created)}, update {len(plan.updated)}, delete {len(plan.deleted)}")
1296+
1297+
Note:
1298+
Fields marked with ColumnInfo(insert_only=True) are automatically
1299+
treated as insert-only and combined with the insert_only parameter.
12621300
"""
1263-
ignore = sorted(set(ignore) | set(insert_only))
1301+
# Combine auto-detected insert_only fields with manual ones
1302+
all_insert_only = cls.insert_only_field_names() | set(insert_only)
1303+
ignore = sorted(set(ignore) | all_insert_only)
12641304
equal_ignoring = cls._cached(
12651305
("equal_ignoring", tuple(ignore)),
12661306
lambda: cls._get_equal_ignoring_fn(ignore),
@@ -1283,7 +1323,7 @@ async def plan_replace_multiple(
12831323

12841324
created = list(pending.values())
12851325

1286-
return ReplaceMultiplePlan(cls, insert_only, created, updated, deleted)
1326+
return ReplaceMultiplePlan(cls, all_insert_only, created, updated, deleted)
12871327

12881328
@classmethod
12891329
async def replace_multiple(
@@ -1315,6 +1355,10 @@ async def replace_multiple(
13151355
>>> created, updated, deleted = await User.replace_multiple(
13161356
... conn, new_users, where=sql("department_id = {}", dept_id)
13171357
... )
1358+
1359+
Note:
1360+
Fields marked with ColumnInfo(insert_only=True) are automatically
1361+
treated as insert-only and combined with the insert_only parameter.
13181362
"""
13191363
plan = await cls.plan_replace_multiple(
13201364
connection, rows, where=where, ignore=ignore, insert_only=insert_only
@@ -1380,8 +1424,14 @@ async def replace_multiple_reporting_differences(
13801424
... )
13811425
>>> for old, new, fields in updates:
13821426
... print(f"Updated {old.name}: changed {', '.join(fields)}")
1427+
1428+
Note:
1429+
Fields marked with ColumnInfo(insert_only=True) are automatically
1430+
treated as insert-only and combined with the insert_only parameter.
13831431
"""
1384-
ignore = sorted(set(ignore) | set(insert_only))
1432+
# Combine auto-detected insert_only fields with manual ones
1433+
all_insert_only = cls.insert_only_field_names() | set(insert_only)
1434+
ignore = sorted(set(ignore) | all_insert_only)
13851435
differences_ignoring = cls._cached(
13861436
("differences_ignoring", tuple(ignore)),
13871437
lambda: cls._get_differences_ignoring_fn(ignore),
@@ -1410,7 +1460,7 @@ async def replace_multiple_reporting_differences(
14101460
await cls.upsert_multiple(
14111461
connection,
14121462
(*created, *(t[1] for t in updated_triples)),
1413-
insert_only=insert_only,
1463+
insert_only=all_insert_only,
14141464
)
14151465
if deleted:
14161466
await cls.delete_multiple(connection, deleted)

tests/test_dataclasses.py

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -245,3 +245,120 @@ class Test(ModelBase, table_name="table"):
245245
assert Test.from_prepended_mapping(
246246
{"p_foo": "FOO", "p_bar": "BAR", "foo": "not foo", "other": "other"}, "p_"
247247
) == Test("foo", "BAR")
248+
249+
250+
def test_insert_only_column_info():
251+
"""Test that ColumnInfo(insert_only=True) works correctly."""
252+
253+
@dataclass
254+
class Test(ModelBase, table_name="table", primary_key="id"):
255+
id: int
256+
name: str
257+
created_at: Annotated[str, ColumnInfo(insert_only=True)]
258+
updated_at: str
259+
260+
# Test that insert_only fields are detected
261+
insert_only_fields = Test.insert_only_field_names()
262+
assert insert_only_fields == {"created_at"}
263+
264+
# Test that upsert SQL excludes insert_only fields from UPDATE when explicitly passed
265+
test_instance = Test(1, "Alice", "2023-01-01", "2023-01-02")
266+
insert_sql = test_instance.insert_sql()
267+
268+
# Pass the insert_only fields to upsert_sql to exclude them from UPDATE
269+
upsert_sql = Test.upsert_sql(insert_sql, exclude=Test.insert_only_field_names())
270+
271+
# The upsert should include created_at in INSERT but exclude it from UPDATE
272+
upsert_query = upsert_sql.query()[0]
273+
assert "created_at" in upsert_query # Should be in INSERT part
274+
# The UPDATE SET clause should only include name and updated_at
275+
assert '"name"=EXCLUDED."name"' in upsert_query
276+
assert '"updated_at"=EXCLUDED."updated_at"' in upsert_query
277+
assert '"created_at"=EXCLUDED."created_at"' not in upsert_query
278+
279+
280+
def test_insert_only_automatic_handling():
281+
"""Test that the upsert() method automatically handles insert_only fields."""
282+
283+
@dataclass
284+
class Test(ModelBase, table_name="table", primary_key="id"):
285+
id: int
286+
name: str
287+
created_at: Annotated[str, ColumnInfo(insert_only=True)]
288+
updated_at: str
289+
290+
# Test automatic handling by checking the generated SQL from the upsert method
291+
test_instance = Test(1, "Alice", "2023-01-01", "2023-01-02")
292+
293+
# Get the SQL that would be generated for upsert operation
294+
# We simulate what happens inside the upsert method
295+
all_insert_only = test_instance.insert_only_field_names()
296+
insert_sql = test_instance.insert_sql()
297+
upsert_sql = Test.upsert_sql(insert_sql, exclude=all_insert_only)
298+
299+
upsert_query = upsert_sql.query()[0]
300+
301+
# created_at should be excluded from UPDATE clause automatically
302+
assert '"name"=EXCLUDED."name"' in upsert_query
303+
assert '"updated_at"=EXCLUDED."updated_at"' in upsert_query
304+
assert '"created_at"=EXCLUDED."created_at"' not in upsert_query
305+
306+
307+
def test_insert_only_merge_with_manual():
308+
"""Test that ColumnInfo insert_only merges with manual insert_only parameter."""
309+
310+
@dataclass
311+
class Test(ModelBase, table_name="table", primary_key="id"):
312+
id: int
313+
name: str
314+
created_at: Annotated[str, ColumnInfo(insert_only=True)] # Auto insert-only
315+
updated_at: str
316+
version: int
317+
318+
# Verify auto-detected fields
319+
assert Test.insert_only_field_names() == {"created_at"}
320+
321+
# Test combining auto-detected with manual insert_only
322+
test_instance = Test(1, "Alice", "2023-01-01", "2023-01-02", 1)
323+
324+
# Simulate what happens in upsert methods - combine auto and manual
325+
auto_insert_only = Test.insert_only_field_names()
326+
manual_insert_only = {"version"}
327+
all_insert_only = auto_insert_only | manual_insert_only
328+
329+
insert_sql = test_instance.insert_sql()
330+
upsert_sql = Test.upsert_sql(insert_sql, exclude=all_insert_only)
331+
upsert_query = upsert_sql.query()[0]
332+
333+
# Both created_at (auto) and version (manual) should be excluded from UPDATE
334+
assert '"name"=EXCLUDED."name"' in upsert_query
335+
assert '"updated_at"=EXCLUDED."updated_at"' in upsert_query
336+
assert '"created_at"=EXCLUDED."created_at"' not in upsert_query
337+
assert '"version"=EXCLUDED."version"' not in upsert_query
338+
339+
340+
def test_column_info_merge_insert_only():
341+
"""Test that ColumnInfo.merge handles insert_only properly."""
342+
343+
base_info = ColumnInfo(type="TEXT")
344+
insert_only_info = ColumnInfo(insert_only=True)
345+
346+
# Test merging - insert_only should be preserved
347+
merged = ColumnInfo.merge(base_info, insert_only_info)
348+
assert merged.insert_only is True
349+
assert merged.type == "TEXT"
350+
351+
# Test merging the other way
352+
merged2 = ColumnInfo.merge(insert_only_info, base_info)
353+
assert merged2.insert_only is True # Should remain True
354+
assert merged2.type == "TEXT"
355+
356+
# Test with both having insert_only set
357+
both_false = ColumnInfo(type="INTEGER", insert_only=False)
358+
merged3 = ColumnInfo.merge(both_false, insert_only_info)
359+
assert merged3.insert_only is True # True should take precedence
360+
361+
# Test with None (default)
362+
none_info = ColumnInfo(type="BIGINT")
363+
merged4 = ColumnInfo.merge(none_info, insert_only_info)
364+
assert merged4.insert_only is True

0 commit comments

Comments
 (0)