Skip to content

Commit 662add1

Browse files
febus982claude
andcommitted
Extract shared primary key utility function
Consolidates duplicated PK inspection logic from BaseRepository._model_pk() and result_presenters._pk_from_result_object() into a single get_model_pk_name() function in common.py. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent e6dee27 commit 662add1

4 files changed

Lines changed: 23 additions & 21 deletions

File tree

sqlalchemy_bind_manager/_repository/base_repository.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
Union,
3232
)
3333

34-
from sqlalchemy import asc, desc, func, inspect, select
34+
from sqlalchemy import asc, desc, func, select
3535
from sqlalchemy.orm import Mapper, aliased, class_mapper, lazyload
3636
from sqlalchemy.orm.exc import UnmappedClassError
3737
from sqlalchemy.sql import Select
@@ -41,6 +41,7 @@
4141
from .common import (
4242
MODEL,
4343
CursorReference,
44+
get_model_pk_name,
4445
)
4546

4647

@@ -342,11 +343,7 @@ def _model_pk(self) -> str:
342343
343344
:return:
344345
"""
345-
primary_keys = inspect(self._model).primary_key # type: ignore
346-
if len(primary_keys) > 1:
347-
raise NotImplementedError("Composite primary keys are not supported.")
348-
349-
return primary_keys[0].name
346+
return get_model_pk_name(self._model)
350347

351348
def _fail_if_invalid_models(self, objects: Iterable[MODEL]) -> None:
352349
if any(not isinstance(x, self._model) for x in objects):

sqlalchemy_bind_manager/_repository/common.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,29 @@
1818
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
1919
# DEALINGS IN THE SOFTWARE.
2020

21-
from typing import Generic, List, TypeVar, Union
21+
from typing import Generic, List, Type, TypeVar, Union
2222
from uuid import UUID
2323

2424
from pydantic import BaseModel, StrictInt, StrictStr
25+
from sqlalchemy import inspect
2526

2627
MODEL = TypeVar("MODEL")
2728
PRIMARY_KEY = Union[str, int, tuple, dict, UUID]
2829

2930

31+
def get_model_pk_name(model_class: Type) -> str:
32+
"""Retrieves the primary key column name from a SQLAlchemy model class.
33+
34+
:param model_class: A SQLAlchemy model class
35+
:return: The name of the primary key column
36+
:raises NotImplementedError: If the model has composite primary keys
37+
"""
38+
primary_keys = inspect(model_class).primary_key # type: ignore
39+
if len(primary_keys) > 1:
40+
raise NotImplementedError("Composite primary keys are not supported.")
41+
return primary_keys[0].name
42+
43+
3044
class PageInfo(BaseModel):
3145
"""
3246
Paginated query metadata.

sqlalchemy_bind_manager/_repository/result_presenters.py

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -21,15 +21,14 @@
2121
from math import ceil
2222
from typing import List, Union
2323

24-
from sqlalchemy import inspect
25-
2624
from .common import (
2725
MODEL,
2826
CursorPageInfo,
2927
CursorPaginatedResult,
3028
CursorReference,
3129
PageInfo,
3230
PaginatedResult,
31+
get_model_pk_name,
3332
)
3433

3534

@@ -93,7 +92,7 @@ def _build_no_cursor_result(
9392
has_next_page = len(result_items) > items_per_page
9493
if has_next_page:
9594
result_items = result_items[0:items_per_page]
96-
reference_column = _pk_from_result_object(result_items[0])
95+
reference_column = get_model_pk_name(type(result_items[0]))
9796

9897
return CursorPaginatedResult(
9998
items=result_items,
@@ -237,11 +236,3 @@ def build_result(
237236
has_previous_page=has_previous_page,
238237
),
239238
)
240-
241-
242-
def _pk_from_result_object(model) -> str:
243-
primary_keys = inspect(type(model)).primary_key # type: ignore
244-
if len(primary_keys) > 1:
245-
raise NotImplementedError("Composite primary keys are not supported.")
246-
247-
return primary_keys[0].name

tests/repository/result_presenters/test_composite_pk.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,15 @@
22

33
import pytest
44

5-
from sqlalchemy_bind_manager._repository.result_presenters import _pk_from_result_object
5+
from sqlalchemy_bind_manager._repository.common import get_model_pk_name
66

77

88
def test_exception_raised_if_multiple_primary_keys():
99
with (
1010
patch(
11-
"sqlalchemy_bind_manager._repository.result_presenters.inspect",
11+
"sqlalchemy_bind_manager._repository.common.inspect",
1212
return_value=Mock(primary_key=["1", "2"]),
1313
),
1414
pytest.raises(NotImplementedError),
1515
):
16-
_pk_from_result_object("irrelevant")
16+
get_model_pk_name(str)

0 commit comments

Comments
 (0)