Skip to content

Commit de5adcc

Browse files
Properly support explict enum name/schema (#461)
* Properly support explict enum name/schema * Update src/sqlacodegen/generators.py Co-authored-by: Alex Grönholm <alex.gronholm@nextday.fi> * Update src/sqlacodegen/generators.py Co-authored-by: Alex Grönholm <alex.gronholm@nextday.fi> --------- Co-authored-by: Alex Grönholm <alex.gronholm@nextday.fi>
1 parent c35e4a2 commit de5adcc

5 files changed

Lines changed: 226 additions & 16 deletions

File tree

CHANGES.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
11
Version history
22
===============
33

4+
**4.0.1**
5+
6+
- Fix enum column definitions to explicitly include schema and name if reflected
7+
via SQLAlchemy's Metadata (pr by @sheinbergon)
8+
49
**4.0.0**
510

611
- **BACKWARD INCOMPATIBLE** API changes (for those who customize code generation by

src/sqlacodegen/generators.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -550,7 +550,14 @@ def render_column_type(self, column: Column[Any]) -> str:
550550
):
551551
# Import SQLAlchemy Enum (will be handled in collect_imports)
552552
self.add_import(Enum)
553-
return f"Enum({enum_class_name}, values_callable=lambda cls: [member.value for member in cls])"
553+
extra_kwargs = ""
554+
if column_type.name is not None:
555+
extra_kwargs += f", name={column_type.name!r}"
556+
557+
if column_type.schema is not None:
558+
extra_kwargs += f", schema={column_type.schema!r}"
559+
560+
return f"Enum({enum_class_name}, values_callable=lambda cls: [member.value for member in cls]{extra_kwargs})"
554561

555562
args = []
556563
kwargs: dict[str, Any] = {}
@@ -562,7 +569,14 @@ def render_column_type(self, column: Column[Any]) -> str:
562569
):
563570
self.add_import(ARRAY)
564571
self.add_import(Enum)
565-
rendered_enum = f"Enum({enum_class_name}, values_callable=lambda cls: [member.value for member in cls])"
572+
extra_kwargs = ""
573+
if column_type.item_type.name is not None:
574+
extra_kwargs += f", name={column_type.item_type.name!r}"
575+
576+
if column_type.item_type.schema is not None:
577+
extra_kwargs += f", schema={column_type.item_type.schema!r}"
578+
579+
rendered_enum = f"Enum({enum_class_name}, values_callable=lambda cls: [member.value for member in cls]{extra_kwargs})"
566580
if column_type.dimensions is not None:
567581
kwargs["dimensions"] = repr(column_type.dimensions)
568582

tests/test_generator_declarative.py

Lines changed: 87 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2508,14 +2508,14 @@ class Accounts(Base):
25082508
__tablename__ = 'accounts'
25092509
25102510
id: Mapped[int] = mapped_column(Integer, primary_key=True)
2511-
status: Mapped[StatusEnum] = mapped_column(Enum(StatusEnum, values_callable=lambda cls: [member.value for member in cls]), nullable=False)
2511+
status: Mapped[StatusEnum] = mapped_column(Enum(StatusEnum, values_callable=lambda cls: [member.value for member in cls], name='status_enum'), nullable=False)
25122512
25132513
25142514
class Users(Base):
25152515
__tablename__ = 'users'
25162516
25172517
id: Mapped[int] = mapped_column(Integer, primary_key=True)
2518-
status: Mapped[StatusEnum] = mapped_column(Enum(StatusEnum, values_callable=lambda cls: [member.value for member in cls]), nullable=False)
2518+
status: Mapped[StatusEnum] = mapped_column(Enum(StatusEnum, values_callable=lambda cls: [member.value for member in cls], name='status_enum'), nullable=False)
25192519
""",
25202520
)
25212521

@@ -2851,7 +2851,7 @@ class Users(Base):
28512851
__tablename__ = 'users'
28522852
28532853
id: Mapped[int] = mapped_column(Integer, primary_key=True)
2854-
roles: Mapped[list[RoleEnum]] = mapped_column(ARRAY(Enum(RoleEnum, values_callable=lambda cls: [member.value for member in cls])), nullable=False)
2854+
roles: Mapped[list[RoleEnum]] = mapped_column(ARRAY(Enum(RoleEnum, values_callable=lambda cls: [member.value for member in cls], name='role_enum')), nullable=False)
28552855
""",
28562856
)
28572857

@@ -2927,7 +2927,7 @@ class Users(Base):
29272927
__tablename__ = 'users'
29282928
29292929
id: Mapped[int] = mapped_column(Integer, primary_key=True)
2930-
roles: Mapped[Optional[list[RoleEnum]]] = mapped_column(ARRAY(Enum(RoleEnum, values_callable=lambda cls: [member.value for member in cls])))
2930+
roles: Mapped[Optional[list[RoleEnum]]] = mapped_column(ARRAY(Enum(RoleEnum, values_callable=lambda cls: [member.value for member in cls], name='role_enum')))
29312931
""",
29322932
)
29332933

@@ -2965,7 +2965,7 @@ class Items(Base):
29652965
__tablename__ = 'items'
29662966
29672967
id: Mapped[int] = mapped_column(Integer, primary_key=True)
2968-
tag_matrix: Mapped[list[list[TagEnum]]] = mapped_column(ARRAY(Enum(TagEnum, values_callable=lambda cls: [member.value for member in cls]), dimensions=2), nullable=False)
2968+
tag_matrix: Mapped[list[list[TagEnum]]] = mapped_column(ARRAY(Enum(TagEnum, values_callable=lambda cls: [member.value for member in cls], name='tag_enum'), dimensions=2), nullable=False)
29692969
""",
29702970
)
29712971

@@ -3043,7 +3043,87 @@ class Users(Base):
30433043
__tablename__ = 'users'
30443044
30453045
id: Mapped[int] = mapped_column(Integer, primary_key=True)
3046-
primary_role: Mapped[RoleEnum] = mapped_column(Enum(RoleEnum, values_callable=lambda cls: [member.value for member in cls]), nullable=False)
3047-
all_roles: Mapped[list[RoleEnum]] = mapped_column(ARRAY(Enum(RoleEnum, values_callable=lambda cls: [member.value for member in cls])), nullable=False)
3046+
primary_role: Mapped[RoleEnum] = mapped_column(Enum(RoleEnum, values_callable=lambda cls: [member.value for member in cls], name='role_enum'), nullable=False)
3047+
all_roles: Mapped[list[RoleEnum]] = mapped_column(ARRAY(Enum(RoleEnum, values_callable=lambda cls: [member.value for member in cls], name='role_enum')), nullable=False)
3048+
""",
3049+
)
3050+
3051+
3052+
def test_enum_named_with_schema(generator: CodeGenerator) -> None:
3053+
Table(
3054+
"my_table",
3055+
generator.metadata,
3056+
Column("id", INTEGER, primary_key=True),
3057+
Column(
3058+
"status",
3059+
SAEnum("active", "inactive", name="status_enum", schema="custom_schema"),
3060+
nullable=False,
3061+
),
3062+
schema="custom_schema",
3063+
)
3064+
3065+
validate_code(
3066+
generator.generate(),
3067+
"""\
3068+
import enum
3069+
3070+
from sqlalchemy import Enum, Integer
3071+
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
3072+
3073+
class Base(DeclarativeBase):
3074+
pass
3075+
3076+
3077+
class StatusEnum(str, enum.Enum):
3078+
ACTIVE = 'active'
3079+
INACTIVE = 'inactive'
3080+
3081+
3082+
class MyTable(Base):
3083+
__tablename__ = 'my_table'
3084+
__table_args__ = {'schema': 'custom_schema'}
3085+
3086+
id: Mapped[int] = mapped_column(Integer, primary_key=True)
3087+
status: Mapped[StatusEnum] = mapped_column(Enum(StatusEnum, values_callable=lambda cls: [member.value for member in cls], name='status_enum', schema='custom_schema'), nullable=False)
3088+
""",
3089+
)
3090+
3091+
3092+
def test_array_enum_named_with_schema(generator: CodeGenerator) -> None:
3093+
Table(
3094+
"my_table",
3095+
generator.metadata,
3096+
Column("id", INTEGER, primary_key=True),
3097+
Column(
3098+
"tags",
3099+
ARRAY(SAEnum("a", "b", name="tag_enum", schema="custom_schema")),
3100+
nullable=False,
3101+
),
3102+
schema="custom_schema",
3103+
)
3104+
3105+
validate_code(
3106+
generator.generate(),
3107+
"""\
3108+
import enum
3109+
3110+
from sqlalchemy import ARRAY, Enum, Integer
3111+
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
3112+
3113+
class Base(DeclarativeBase):
3114+
pass
3115+
3116+
3117+
class TagEnum(str, enum.Enum):
3118+
A = 'a'
3119+
B = 'b'
3120+
3121+
3122+
class MyTable(Base):
3123+
__tablename__ = 'my_table'
3124+
__table_args__ = {'schema': 'custom_schema'}
3125+
3126+
id: Mapped[int] = mapped_column(Integer, primary_key=True)
3127+
tags: Mapped[list[TagEnum]] = mapped_column(ARRAY(Enum(TagEnum, values_callable=lambda cls: [member.value for member in cls], name='tag_enum', schema='custom_schema')), nullable=False)
30483128
""",
30493129
)

tests/test_generator_sqlmodel.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import pytest
44
from _pytest.fixtures import FixtureRequest
5+
from sqlalchemy import Enum as SAEnum
56
from sqlalchemy import Uuid
67
from sqlalchemy.engine import Engine
78
from sqlalchemy.schema import (
@@ -13,7 +14,7 @@
1314
Table,
1415
UniqueConstraint,
1516
)
16-
from sqlalchemy.types import INTEGER, VARCHAR
17+
from sqlalchemy.types import ARRAY, INTEGER, VARCHAR
1718

1819
from sqlacodegen.generators import CodeGenerator, SQLModelGenerator
1920

@@ -329,3 +330,39 @@ class Accounts(SQLModel, table=True):
329330
status: AccountsStatus = Field(sa_column=Column('status', Enum(AccountsStatus, values_callable=lambda cls: [member.value for member in cls]), nullable=False))
330331
""",
331332
)
333+
334+
335+
def test_array_enum_named_with_schema(generator: CodeGenerator) -> None:
336+
Table(
337+
"my_table",
338+
generator.metadata,
339+
Column("id", INTEGER, primary_key=True),
340+
Column(
341+
"tags",
342+
ARRAY(SAEnum("a", "b", name="tag_enum", schema="custom_schema")),
343+
nullable=False,
344+
),
345+
schema="custom_schema",
346+
)
347+
348+
validate_code(
349+
generator.generate(),
350+
"""\
351+
import enum
352+
353+
from sqlalchemy import ARRAY, Column, Enum, Integer
354+
from sqlmodel import Field, SQLModel
355+
356+
class TagEnum(str, enum.Enum):
357+
A = 'a'
358+
B = 'b'
359+
360+
361+
class MyTable(SQLModel, table=True):
362+
__tablename__ = 'my_table'
363+
__table_args__ = {'schema': 'custom_schema'}
364+
365+
id: int = Field(sa_column=Column('id', Integer, primary_key=True))
366+
tags: list[TagEnum] = Field(sa_column=Column('tags', ARRAY(Enum(TagEnum, values_callable=lambda cls: [member.value for member in cls], name='tag_enum', schema='custom_schema')), nullable=False))
367+
""",
368+
)

tests/test_generator_tables.py

Lines changed: 80 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ class Blah(str, enum.Enum):
7676
7777
t_simple_items = Table(
7878
'simple_items', metadata,
79-
Column('enum', Enum(Blah, values_callable=lambda cls: [member.value for member in cls])),
79+
Column('enum', Enum(Blah, values_callable=lambda cls: [member.value for member in cls], name='blah', schema='someschema')),
8080
Column('bool', Boolean),
8181
Column('vector', VECTOR(3)),
8282
Column('number', Numeric(10, asdecimal=False)),
@@ -309,13 +309,13 @@ class StatusEnum(str, enum.Enum):
309309
t_accounts = Table(
310310
'accounts', metadata,
311311
Column('id', Integer, primary_key=True),
312-
Column('status', Enum(StatusEnum, values_callable=lambda cls: [member.value for member in cls]))
312+
Column('status', Enum(StatusEnum, values_callable=lambda cls: [member.value for member in cls], name='status_enum'))
313313
)
314314
315315
t_users = Table(
316316
'users', metadata,
317317
Column('id', Integer, primary_key=True),
318-
Column('status', Enum(StatusEnum, values_callable=lambda cls: [member.value for member in cls]))
318+
Column('status', Enum(StatusEnum, values_callable=lambda cls: [member.value for member in cls], name='status_enum'))
319319
)
320320
""",
321321
)
@@ -348,7 +348,7 @@ class RoleEnum(str, enum.Enum):
348348
t_users = Table(
349349
'users', metadata,
350350
Column('id', Integer, primary_key=True),
351-
Column('roles', ARRAY(Enum(RoleEnum, values_callable=lambda cls: [member.value for member in cls])))
351+
Column('roles', ARRAY(Enum(RoleEnum, values_callable=lambda cls: [member.value for member in cls], name='role_enum')))
352352
)
353353
""",
354354
)
@@ -386,13 +386,87 @@ class RoleEnum(str, enum.Enum):
386386
t_groups = Table(
387387
'groups', metadata,
388388
Column('id', Integer, primary_key=True),
389-
Column('allowed_roles', ARRAY(Enum(RoleEnum, values_callable=lambda cls: [member.value for member in cls])))
389+
Column('allowed_roles', ARRAY(Enum(RoleEnum, values_callable=lambda cls: [member.value for member in cls], name='role_enum')))
390390
)
391391
392392
t_users = Table(
393393
'users', metadata,
394394
Column('id', Integer, primary_key=True),
395-
Column('roles', ARRAY(Enum(RoleEnum, values_callable=lambda cls: [member.value for member in cls])))
395+
Column('roles', ARRAY(Enum(RoleEnum, values_callable=lambda cls: [member.value for member in cls], name='role_enum')))
396+
)
397+
""",
398+
)
399+
400+
401+
def test_enum_named_with_schema(generator: CodeGenerator) -> None:
402+
Table(
403+
"my_table",
404+
generator.metadata,
405+
Column("id", INTEGER, primary_key=True),
406+
Column(
407+
"status",
408+
SAEnum("active", "inactive", name="status_enum", schema="custom_schema"),
409+
),
410+
schema="custom_schema",
411+
)
412+
413+
validate_code(
414+
generator.generate(),
415+
"""\
416+
import enum
417+
418+
from sqlalchemy import Column, Enum, Integer, MetaData, Table
419+
420+
metadata = MetaData()
421+
422+
423+
class StatusEnum(str, enum.Enum):
424+
ACTIVE = 'active'
425+
INACTIVE = 'inactive'
426+
427+
428+
t_my_table = Table(
429+
'my_table', metadata,
430+
Column('id', Integer, primary_key=True),
431+
Column('status', Enum(StatusEnum, values_callable=lambda cls: [member.value for member in cls], name='status_enum', schema='custom_schema')),
432+
schema='custom_schema'
433+
)
434+
""",
435+
)
436+
437+
438+
def test_array_enum_named_with_schema(generator: CodeGenerator) -> None:
439+
Table(
440+
"my_table",
441+
generator.metadata,
442+
Column("id", INTEGER, primary_key=True),
443+
Column(
444+
"tags",
445+
ARRAY(SAEnum("a", "b", name="tag_enum", schema="custom_schema")),
446+
),
447+
schema="custom_schema",
448+
)
449+
450+
validate_code(
451+
generator.generate(),
452+
"""\
453+
import enum
454+
455+
from sqlalchemy import ARRAY, Column, Enum, Integer, MetaData, Table
456+
457+
metadata = MetaData()
458+
459+
460+
class TagEnum(str, enum.Enum):
461+
A = 'a'
462+
B = 'b'
463+
464+
465+
t_my_table = Table(
466+
'my_table', metadata,
467+
Column('id', Integer, primary_key=True),
468+
Column('tags', ARRAY(Enum(TagEnum, values_callable=lambda cls: [member.value for member in cls], name='tag_enum', schema='custom_schema'))),
469+
schema='custom_schema'
396470
)
397471
""",
398472
)

0 commit comments

Comments
 (0)