55import sys
66from abc import ABCMeta , abstractmethod
77from collections import defaultdict
8- from collections .abc import Collection , Iterable , Sequence
8+ from collections .abc import Collection , Iterable , Mapping , Sequence
99from dataclasses import dataclass
1010from importlib import import_module
1111from inspect import Parameter
@@ -1001,10 +1001,12 @@ def __init__(
10011001 * ,
10021002 indentation : str = " " ,
10031003 base_class_name : str = "Base" ,
1004+ explicit_foreign_keys : bool = False ,
10041005 ):
10051006 super ().__init__ (metadata , bind , options , indentation = indentation )
10061007 self .base_class_name : str = base_class_name
10071008 self .inflect_engine = inflect .engine ()
1009+ self .explicit_foreign_keys = explicit_foreign_keys
10081010
10091011 def generate_base (self ) -> None :
10101012 self .base = Base (
@@ -1626,6 +1628,33 @@ def render_column_attribute(self, column_attr: ColumnAttribute) -> str:
16261628 return f"{ column_attr .name } : Mapped[{ rendered_column_python_type } ] = { rendered_column } "
16271629
16281630 def render_relationship (self , relationship : RelationshipAttribute ) -> str :
1631+ kwargs = self .render_relationship_arguments (relationship )
1632+ annotation = self .render_relationship_annotation (relationship )
1633+ rendered_relationship = render_callable (
1634+ "relationship" , repr (relationship .target .name ), kwargs = kwargs
1635+ )
1636+ return f"{ relationship .name } : Mapped[{ annotation } ] = { rendered_relationship } "
1637+
1638+ def render_relationship_annotation (
1639+ self , relationship : RelationshipAttribute
1640+ ) -> str :
1641+ match relationship .type :
1642+ case RelationshipType .ONE_TO_MANY :
1643+ return f"list[{ relationship .target .name !r} ]"
1644+ case RelationshipType .ONE_TO_ONE | RelationshipType .MANY_TO_ONE :
1645+ if relationship .constraint and any (
1646+ col .nullable for col in relationship .constraint .columns
1647+ ):
1648+ self .add_literal_import ("typing" , "Optional" )
1649+ return f"Optional[{ relationship .target .name !r} ]"
1650+ else :
1651+ return f"'{ relationship .target .name } '"
1652+ case RelationshipType .MANY_TO_MANY :
1653+ return f"list[{ relationship .target .name !r} ]"
1654+
1655+ def render_relationship_arguments (
1656+ self , relationship : RelationshipAttribute
1657+ ) -> Mapping [str , Any ]:
16291658 def render_column_attrs (column_attrs : list [ColumnAttribute ]) -> str :
16301659 rendered = []
16311660 for attr in column_attrs :
@@ -1641,7 +1670,7 @@ def render_foreign_keys(column_attrs: list[ColumnAttribute]) -> str:
16411670 render_as_string = False
16421671 # Assume that column_attrs are all in relationship.source or none
16431672 for attr in column_attrs :
1644- if attr .model is relationship .source :
1673+ if not self . explicit_foreign_keys and attr .model is relationship .source :
16451674 rendered .append (attr .name )
16461675 else :
16471676 rendered .append (f"{ attr .model .name } .{ attr .name } " )
@@ -1697,33 +1726,7 @@ def render_join(terms: list[JoinType]) -> str:
16971726 if relationship .backref :
16981727 kwargs ["back_populates" ] = repr (relationship .backref .name )
16991728
1700- rendered_relationship = render_callable (
1701- "relationship" , repr (relationship .target .name ), kwargs = kwargs
1702- )
1703-
1704- relationship_type : str
1705- if relationship .type == RelationshipType .ONE_TO_MANY :
1706- relationship_type = f"list['{ relationship .target .name } ']"
1707- elif relationship .type in (
1708- RelationshipType .ONE_TO_ONE ,
1709- RelationshipType .MANY_TO_ONE ,
1710- ):
1711- relationship_type = f"'{ relationship .target .name } '"
1712- if relationship .constraint and any (
1713- col .nullable for col in relationship .constraint .columns
1714- ):
1715- self .add_literal_import ("typing" , "Optional" )
1716- relationship_type = f"Optional[{ relationship_type } ]"
1717- elif relationship .type == RelationshipType .MANY_TO_MANY :
1718- relationship_type = f"list['{ relationship .target .name } ']"
1719- else :
1720- self .add_literal_import ("typing" , "Any" )
1721- relationship_type = "Any"
1722-
1723- return (
1724- f"{ relationship .name } : Mapped[{ relationship_type } ] "
1725- f"= { rendered_relationship } "
1726- )
1729+ return kwargs
17271730
17281731
17291732class DataclassGenerator (DeclarativeGenerator ):
@@ -1778,6 +1781,7 @@ def __init__(
17781781 options ,
17791782 indentation = indentation ,
17801783 base_class_name = base_class_name ,
1784+ explicit_foreign_keys = True ,
17811785 )
17821786
17831787 @property
@@ -1858,34 +1862,26 @@ def render_column_attribute(self, column_attr: ColumnAttribute) -> str:
18581862 return f"{ column_attr .name } : { rendered_column_python_type } = { rendered_field } "
18591863
18601864 def render_relationship (self , relationship : RelationshipAttribute ) -> str :
1861- rendered = super ().render_relationship (relationship ).partition (" = " )[2 ]
1862- args = self .render_relationship_args (rendered )
1863- kwargs : dict [str , Any ] = {}
1864- annotation = repr (relationship .target .name )
1865+ kwargs = self .render_relationship_arguments (relationship )
1866+ annotation = self .render_relationship_annotation (relationship )
1867+
1868+ native_kwargs : dict [str , Any ] = {}
1869+ non_native_kwargs : dict [str , Any ] = {}
1870+ for key , value in kwargs .items ():
1871+ # The following keyword arguments are natively supported in Relationship
1872+ if key in ("back_populates" , "cascade_delete" , "passive_deletes" ):
1873+ native_kwargs [key ] = value
1874+ else :
1875+ non_native_kwargs [key ] = value
18651876
1866- if relationship . type in (
1867- RelationshipType . ONE_TO_MANY ,
1868- RelationshipType . MANY_TO_MANY ,
1869- ):
1870- annotation = f"list[ { annotation } ]"
1871- else :
1872- self . add_literal_import ( "typing" , "Optional" )
1873- annotation = f"Optional[ { annotation } ]"
1877+ if non_native_kwargs :
1878+ native_kwargs [ "sa_relationship_kwargs" ] = (
1879+ "{"
1880+ + ", " . join (
1881+ f" { key !r } : { value } " for key , value in non_native_kwargs . items ()
1882+ )
1883+ + "}"
1884+ )
18741885
1875- rendered_field = render_callable ("Relationship" , * args , kwargs = kwargs )
1886+ rendered_field = render_callable ("Relationship" , kwargs = native_kwargs )
18761887 return f"{ relationship .name } : { annotation } = { rendered_field } "
1877-
1878- def render_relationship_args (self , arguments : str ) -> list [str ]:
1879- argument_list = arguments .split ("," )
1880- # delete ')' and ' ' from args
1881- argument_list [- 1 ] = argument_list [- 1 ][:- 1 ]
1882- argument_list = [argument [1 :] for argument in argument_list ]
1883-
1884- rendered_args : list [str ] = []
1885- for arg in argument_list :
1886- if "back_populates" in arg :
1887- rendered_args .append (arg )
1888- if "uselist=False" in arg :
1889- rendered_args .append ("sa_relationship_kwargs={'uselist': False}" )
1890-
1891- return rendered_args
0 commit comments