Skip to content

Commit de3f846

Browse files
authored
Pass keyword arguments to SQLAlchemy relationships in SQLModel relationships (#456)
1 parent 4772a34 commit de3f846

3 files changed

Lines changed: 133 additions & 58 deletions

File tree

CHANGES.rst

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,23 @@
11
Version history
22
===============
33

4+
**UNRELEASED**
5+
6+
- **BACKWARD INCOMPATIBLE** API changes (for those who customize code generation by
7+
subclassing the existing generators):
8+
9+
* Added new optional keyword argument, ``explicit_foreign_keys`` to
10+
``DeclarativeGenerator``, to force foreign keys to be rendered as
11+
``ClassName.attribute_name`` string references
12+
* Removed the ``render_relationship_args()`` method from the SQLModel generator
13+
* Added two new methods for customizing relationship rendering in
14+
``DeclarativeGenerator``:
15+
16+
* ``render_relationship_annotation()``: returns the appropriate type annotation
17+
(without the ``Mapped`` wrapper) for the relationship
18+
* ``render_relationship_arguments()``: returns a dictionary of keyword arguments to
19+
``sqlalchemy.orm.relationship()``
20+
421
**4.0.0rc3**
522

623
- **BACKWARD INCOMPATIBLE** Relationship names changed when multiple FKs or junction tables
@@ -14,6 +31,8 @@ Version history
1431
``students_enrollments``). Use ``--options nofknames`` to revert to old behavior. (PR by @sheinbergon)
1532
- Fixed ``Index`` kwargs (e.g. ``mysql_length``) being ignored during code generation
1633
(PR by @luliangce)
34+
- Fixed the SQLModel generator not adding the ``foreign_keys`` parameters when
35+
generating multiple relationships between the same two tables
1736

1837
**4.0.0rc2**
1938

src/sqlacodegen/generators.py

Lines changed: 53 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import sys
66
from abc import ABCMeta, abstractmethod
77
from collections import defaultdict
8-
from collections.abc import Collection, Iterable, Sequence
8+
from collections.abc import Collection, Iterable, Mapping, Sequence
99
from dataclasses import dataclass
1010
from importlib import import_module
1111
from inspect import Parameter
@@ -1001,10 +1001,12 @@ def __init__(
10011001
*,
10021002
indentation: str = " ",
10031003
base_class_name: str = "Base",
1004+
explicit_foreign_keys: bool = False,
10041005
):
10051006
super().__init__(metadata, bind, options, indentation=indentation)
10061007
self.base_class_name: str = base_class_name
10071008
self.inflect_engine = inflect.engine()
1009+
self.explicit_foreign_keys = explicit_foreign_keys
10081010

10091011
def generate_base(self) -> None:
10101012
self.base = Base(
@@ -1626,6 +1628,33 @@ def render_column_attribute(self, column_attr: ColumnAttribute) -> str:
16261628
return f"{column_attr.name}: Mapped[{rendered_column_python_type}] = {rendered_column}"
16271629

16281630
def render_relationship(self, relationship: RelationshipAttribute) -> str:
1631+
kwargs = self.render_relationship_arguments(relationship)
1632+
annotation = self.render_relationship_annotation(relationship)
1633+
rendered_relationship = render_callable(
1634+
"relationship", repr(relationship.target.name), kwargs=kwargs
1635+
)
1636+
return f"{relationship.name}: Mapped[{annotation}] = {rendered_relationship}"
1637+
1638+
def render_relationship_annotation(
1639+
self, relationship: RelationshipAttribute
1640+
) -> str:
1641+
match relationship.type:
1642+
case RelationshipType.ONE_TO_MANY:
1643+
return f"list[{relationship.target.name!r}]"
1644+
case RelationshipType.ONE_TO_ONE | RelationshipType.MANY_TO_ONE:
1645+
if relationship.constraint and any(
1646+
col.nullable for col in relationship.constraint.columns
1647+
):
1648+
self.add_literal_import("typing", "Optional")
1649+
return f"Optional[{relationship.target.name!r}]"
1650+
else:
1651+
return f"'{relationship.target.name}'"
1652+
case RelationshipType.MANY_TO_MANY:
1653+
return f"list[{relationship.target.name!r}]"
1654+
1655+
def render_relationship_arguments(
1656+
self, relationship: RelationshipAttribute
1657+
) -> Mapping[str, Any]:
16291658
def render_column_attrs(column_attrs: list[ColumnAttribute]) -> str:
16301659
rendered = []
16311660
for attr in column_attrs:
@@ -1641,7 +1670,7 @@ def render_foreign_keys(column_attrs: list[ColumnAttribute]) -> str:
16411670
render_as_string = False
16421671
# Assume that column_attrs are all in relationship.source or none
16431672
for attr in column_attrs:
1644-
if attr.model is relationship.source:
1673+
if not self.explicit_foreign_keys and attr.model is relationship.source:
16451674
rendered.append(attr.name)
16461675
else:
16471676
rendered.append(f"{attr.model.name}.{attr.name}")
@@ -1697,33 +1726,7 @@ def render_join(terms: list[JoinType]) -> str:
16971726
if relationship.backref:
16981727
kwargs["back_populates"] = repr(relationship.backref.name)
16991728

1700-
rendered_relationship = render_callable(
1701-
"relationship", repr(relationship.target.name), kwargs=kwargs
1702-
)
1703-
1704-
relationship_type: str
1705-
if relationship.type == RelationshipType.ONE_TO_MANY:
1706-
relationship_type = f"list['{relationship.target.name}']"
1707-
elif relationship.type in (
1708-
RelationshipType.ONE_TO_ONE,
1709-
RelationshipType.MANY_TO_ONE,
1710-
):
1711-
relationship_type = f"'{relationship.target.name}'"
1712-
if relationship.constraint and any(
1713-
col.nullable for col in relationship.constraint.columns
1714-
):
1715-
self.add_literal_import("typing", "Optional")
1716-
relationship_type = f"Optional[{relationship_type}]"
1717-
elif relationship.type == RelationshipType.MANY_TO_MANY:
1718-
relationship_type = f"list['{relationship.target.name}']"
1719-
else:
1720-
self.add_literal_import("typing", "Any")
1721-
relationship_type = "Any"
1722-
1723-
return (
1724-
f"{relationship.name}: Mapped[{relationship_type}] "
1725-
f"= {rendered_relationship}"
1726-
)
1729+
return kwargs
17271730

17281731

17291732
class DataclassGenerator(DeclarativeGenerator):
@@ -1778,6 +1781,7 @@ def __init__(
17781781
options,
17791782
indentation=indentation,
17801783
base_class_name=base_class_name,
1784+
explicit_foreign_keys=True,
17811785
)
17821786

17831787
@property
@@ -1858,34 +1862,26 @@ def render_column_attribute(self, column_attr: ColumnAttribute) -> str:
18581862
return f"{column_attr.name}: {rendered_column_python_type} = {rendered_field}"
18591863

18601864
def render_relationship(self, relationship: RelationshipAttribute) -> str:
1861-
rendered = super().render_relationship(relationship).partition(" = ")[2]
1862-
args = self.render_relationship_args(rendered)
1863-
kwargs: dict[str, Any] = {}
1864-
annotation = repr(relationship.target.name)
1865+
kwargs = self.render_relationship_arguments(relationship)
1866+
annotation = self.render_relationship_annotation(relationship)
1867+
1868+
native_kwargs: dict[str, Any] = {}
1869+
non_native_kwargs: dict[str, Any] = {}
1870+
for key, value in kwargs.items():
1871+
# The following keyword arguments are natively supported in Relationship
1872+
if key in ("back_populates", "cascade_delete", "passive_deletes"):
1873+
native_kwargs[key] = value
1874+
else:
1875+
non_native_kwargs[key] = value
18651876

1866-
if relationship.type in (
1867-
RelationshipType.ONE_TO_MANY,
1868-
RelationshipType.MANY_TO_MANY,
1869-
):
1870-
annotation = f"list[{annotation}]"
1871-
else:
1872-
self.add_literal_import("typing", "Optional")
1873-
annotation = f"Optional[{annotation}]"
1877+
if non_native_kwargs:
1878+
native_kwargs["sa_relationship_kwargs"] = (
1879+
"{"
1880+
+ ", ".join(
1881+
f"{key!r}: {value}" for key, value in non_native_kwargs.items()
1882+
)
1883+
+ "}"
1884+
)
18741885

1875-
rendered_field = render_callable("Relationship", *args, kwargs=kwargs)
1886+
rendered_field = render_callable("Relationship", kwargs=native_kwargs)
18761887
return f"{relationship.name}: {annotation} = {rendered_field}"
1877-
1878-
def render_relationship_args(self, arguments: str) -> list[str]:
1879-
argument_list = arguments.split(",")
1880-
# delete ')' and ' ' from args
1881-
argument_list[-1] = argument_list[-1][:-1]
1882-
argument_list = [argument[1:] for argument in argument_list]
1883-
1884-
rendered_args: list[str] = []
1885-
for arg in argument_list:
1886-
if "back_populates" in arg:
1887-
rendered_args.append(arg)
1888-
if "uselist=False" in arg:
1889-
rendered_args.append("sa_relationship_kwargs={'uselist': False}")
1890-
1891-
return rendered_args

tests/test_generator_sqlmodel.py

Lines changed: 61 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,66 @@ class SimpleGoods(SQLModel, table=True):
142142
)
143143

144144

145+
def test_onetomany_multiref(generator: CodeGenerator) -> None:
146+
Table(
147+
"simple_items_multiref",
148+
generator.metadata,
149+
Column("id", INTEGER, primary_key=True),
150+
Column("parent_container_id", INTEGER),
151+
Column("top_container_id", INTEGER, nullable=False),
152+
ForeignKeyConstraint(
153+
["parent_container_id"], ["simple_containers_multiref.id"]
154+
),
155+
ForeignKeyConstraint(["top_container_id"], ["simple_containers_multiref.id"]),
156+
)
157+
Table(
158+
"simple_containers_multiref",
159+
generator.metadata,
160+
Column("id", INTEGER, primary_key=True),
161+
)
162+
163+
validate_code(
164+
generator.generate(),
165+
"""\
166+
from typing import Optional
167+
168+
from sqlalchemy import Column, ForeignKey, Integer
169+
from sqlmodel import Field, Relationship, SQLModel
170+
171+
class SimpleContainersMultiref(SQLModel, table=True):
172+
__tablename__ = 'simple_containers_multiref'
173+
174+
id: int = Field(sa_column=Column('id', Integer, primary_key=True))
175+
176+
simple_items_multiref_parent_container: list['SimpleItemsMultiref'] = \
177+
Relationship(back_populates='parent_container', sa_relationship_kwargs={\
178+
'foreign_keys': '[SimpleItemsMultiref.parent_container_id]'})
179+
simple_items_multiref_top_container: list['SimpleItemsMultiref'] = \
180+
Relationship(back_populates='top_container', sa_relationship_kwargs={'foreign_keys': \
181+
'[SimpleItemsMultiref.top_container_id]'})
182+
183+
184+
class SimpleItemsMultiref(SQLModel, table=True):
185+
__tablename__ = 'simple_items_multiref'
186+
187+
id: int = Field(sa_column=Column('id', Integer, primary_key=True))
188+
top_container_id: int = \
189+
Field(sa_column=Column('top_container_id', \
190+
ForeignKey('simple_containers_multiref.id'), nullable=False))
191+
parent_container_id: Optional[int] = \
192+
Field(default=None, sa_column=Column('parent_container_id', \
193+
ForeignKey('simple_containers_multiref.id')))
194+
195+
parent_container: Optional['SimpleContainersMultiref'] = Relationship(\
196+
back_populates='simple_items_multiref_parent_container', sa_relationship_kwargs={\
197+
'foreign_keys': '[SimpleItemsMultiref.parent_container_id]'})
198+
top_container: 'SimpleContainersMultiref' = Relationship(\
199+
back_populates='simple_items_multiref_top_container', sa_relationship_kwargs={\
200+
'foreign_keys': '[SimpleItemsMultiref.top_container_id]'})
201+
""",
202+
)
203+
204+
145205
def test_onetoone(generator: CodeGenerator) -> None:
146206
Table(
147207
"simple_onetoone",
@@ -167,7 +227,7 @@ class OtherItems(SQLModel, table=True):
167227
id: int = Field(sa_column=Column('id', Integer, primary_key=True))
168228
169229
simple_onetoone: Optional['SimpleOnetoone'] = Relationship(\
170-
sa_relationship_kwargs={'uselist': False}, back_populates='other_item')
230+
back_populates='other_item', sa_relationship_kwargs={'uselist': False})
171231
172232
173233
class SimpleOnetoone(SQLModel, table=True):

0 commit comments

Comments
 (0)