Skip to content

Commit 4772a34

Browse files
Support enum arrays (#455)
* Support enum arrays * re-order changes * Update src/sqlacodegen/generators.py Co-authored-by: Alex Grönholm <alex.gronholm@nextday.fi> * Update CHANGES.rst Co-authored-by: Alex Grönholm <alex.gronholm@nextday.fi> * PR Fixes round #1 * PR Fixes round #2 * PR Fixes round #3 * Added a blank line after a control block * Removed duplicate changelog entries --------- Co-authored-by: Alex Grönholm <alex.gronholm@nextday.fi>
1 parent e9bfd34 commit 4772a34

4 files changed

Lines changed: 364 additions & 32 deletions

File tree

CHANGES.rst

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,14 @@
11
Version history
22
===============
33

4-
**UNRELEASED**
4+
**4.0.0rc3**
55

66
- **BACKWARD INCOMPATIBLE** Relationship names changed when multiple FKs or junction tables
77
connect to the same target table. Regenerating models will break existing code.
8+
- Added support for generating Python enum classes for ``ARRAY(Enum(...))`` columns
9+
(e.g., PostgreSQL ``ARRAY(ENUM)``). Supports named/unnamed enums, shared enums across
10+
columns, and multi-dimensional arrays. Respects ``--options nonativeenums``.
11+
(PR by @sheinbergon)
812
- Improved relationship naming: one-to-many uses FK column names (e.g.,
913
``simple_items_parent_container``), many-to-many uses junction table names (e.g.,
1014
``students_enrollments``). Use ``--options nofknames`` to revert to old behavior. (PR by @sheinbergon)

src/sqlacodegen/generators.py

Lines changed: 48 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -554,6 +554,20 @@ def render_column_type(self, column: Column[Any]) -> str:
554554

555555
args = []
556556
kwargs: dict[str, Any] = {}
557+
558+
# Check if this is an ARRAY column with an Enum item type mapped to a Python enum class
559+
if isinstance(column_type, ARRAY) and isinstance(column_type.item_type, Enum):
560+
if enum_class_name := self.enum_classes.get(
561+
(column.table.name, column.name)
562+
):
563+
self.add_import(ARRAY)
564+
self.add_import(Enum)
565+
rendered_enum = f"Enum({enum_class_name}, values_callable=lambda cls: [member.value for member in cls])"
566+
if column_type.dimensions is not None:
567+
kwargs["dimensions"] = repr(column_type.dimensions)
568+
569+
return render_callable("ARRAY", rendered_enum, kwargs=kwargs)
570+
557571
sig = inspect.signature(column_type.__class__.__init__)
558572
defaults = {param.name: param.default for param in sig.parameters.values()}
559573
missing = object()
@@ -809,6 +823,31 @@ def render_enum_classes(self) -> str:
809823

810824
def fix_column_types(self, table: Table) -> None:
811825
"""Adjust the reflected column types."""
826+
827+
def fix_enum_column(col_name: str, enum_type: Enum) -> None:
828+
if (table.name, col_name) in self.enum_classes:
829+
return
830+
831+
if enum_type.name:
832+
existing_class = None
833+
for (_, _), cls in self.enum_classes.items():
834+
if cls == self._enum_name_to_class_name(enum_type.name):
835+
existing_class = cls
836+
break
837+
838+
if existing_class:
839+
enum_class_name = existing_class
840+
else:
841+
enum_class_name = self._enum_name_to_class_name(enum_type.name)
842+
if enum_class_name not in self.enum_values:
843+
self.enum_values[enum_class_name] = list(enum_type.enums)
844+
else:
845+
enum_class_name = self._create_enum_class(
846+
table.name, col_name, list(enum_type.enums)
847+
)
848+
849+
self.enum_classes[(table.name, col_name)] = enum_class_name
850+
812851
# Detect check constraints for boolean and enum columns
813852
for constraint in table.constraints.copy():
814853
if isinstance(constraint, CheckConstraint):
@@ -852,37 +891,16 @@ def fix_column_types(self, table: Table) -> None:
852891
and isinstance(column.type, Enum)
853892
and column.type.enums
854893
):
855-
if column.type.name:
856-
# Named enum - create shared enum class if not already created
857-
if (table.name, column.name) not in self.enum_classes:
858-
# Check if we've already created an enum for this name
859-
existing_class = None
860-
for (t, c), cls in self.enum_classes.items():
861-
if cls == self._enum_name_to_class_name(column.type.name):
862-
existing_class = cls
863-
break
864-
865-
if existing_class:
866-
enum_class_name = existing_class
867-
else:
868-
# Create new enum class from the enum's name
869-
enum_class_name = self._enum_name_to_class_name(
870-
column.type.name
871-
)
872-
# Register the enum values if not already registered
873-
if enum_class_name not in self.enum_values:
874-
self.enum_values[enum_class_name] = list(
875-
column.type.enums
876-
)
894+
fix_enum_column(column.name, column.type)
877895

878-
self.enum_classes[(table.name, column.name)] = enum_class_name
879-
else:
880-
# Unnamed enum - create enum class per column
881-
if (table.name, column.name) not in self.enum_classes:
882-
enum_class_name = self._create_enum_class(
883-
table.name, column.name, list(column.type.enums)
884-
)
885-
self.enum_classes[(table.name, column.name)] = enum_class_name
896+
# Handle ARRAY columns with Enum item types (e.g., PostgreSQL ARRAY(ENUM))
897+
elif (
898+
"nonativeenums" not in self.options
899+
and isinstance(column.type, ARRAY)
900+
and isinstance(column.type.item_type, Enum)
901+
and column.type.item_type.enums
902+
):
903+
fix_enum_column(column.name, column.type.item_type)
886904

887905
if not self.keep_dialect_types:
888906
try:

tests/test_generator_declarative.py

Lines changed: 232 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2815,3 +2815,235 @@ class Users(Base):
28152815
status: Mapped[UsersStatus] = mapped_column(Enum(UsersStatus, values_callable=lambda cls: [member.value for member in cls]), nullable=False)
28162816
""",
28172817
)
2818+
2819+
2820+
def test_array_enum_named(generator: CodeGenerator) -> None:
2821+
Table(
2822+
"users",
2823+
generator.metadata,
2824+
Column("id", INTEGER, primary_key=True),
2825+
Column(
2826+
"roles",
2827+
ARRAY(SAEnum("admin", "user", "moderator", name="role_enum")),
2828+
nullable=False,
2829+
),
2830+
)
2831+
2832+
validate_code(
2833+
generator.generate(),
2834+
"""\
2835+
import enum
2836+
2837+
from sqlalchemy import ARRAY, Enum, Integer
2838+
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
2839+
2840+
class Base(DeclarativeBase):
2841+
pass
2842+
2843+
2844+
class RoleEnum(str, enum.Enum):
2845+
ADMIN = 'admin'
2846+
USER = 'user'
2847+
MODERATOR = 'moderator'
2848+
2849+
2850+
class Users(Base):
2851+
__tablename__ = 'users'
2852+
2853+
id: Mapped[int] = mapped_column(Integer, primary_key=True)
2854+
roles: Mapped[list[RoleEnum]] = mapped_column(ARRAY(Enum(RoleEnum, values_callable=lambda cls: [member.value for member in cls])), nullable=False)
2855+
""",
2856+
)
2857+
2858+
2859+
def test_array_enum_unnamed(generator: CodeGenerator) -> None:
2860+
Table(
2861+
"users",
2862+
generator.metadata,
2863+
Column("id", INTEGER, primary_key=True),
2864+
Column(
2865+
"roles",
2866+
ARRAY(SAEnum("admin", "user")),
2867+
nullable=False,
2868+
),
2869+
)
2870+
2871+
validate_code(
2872+
generator.generate(),
2873+
"""\
2874+
import enum
2875+
2876+
from sqlalchemy import ARRAY, Enum, Integer
2877+
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
2878+
2879+
class Base(DeclarativeBase):
2880+
pass
2881+
2882+
2883+
class UsersRoles(str, enum.Enum):
2884+
ADMIN = 'admin'
2885+
USER = 'user'
2886+
2887+
2888+
class Users(Base):
2889+
__tablename__ = 'users'
2890+
2891+
id: Mapped[int] = mapped_column(Integer, primary_key=True)
2892+
roles: Mapped[list[UsersRoles]] = mapped_column(ARRAY(Enum(UsersRoles, values_callable=lambda cls: [member.value for member in cls])), nullable=False)
2893+
""",
2894+
)
2895+
2896+
2897+
def test_array_enum_nullable(generator: CodeGenerator) -> None:
2898+
Table(
2899+
"users",
2900+
generator.metadata,
2901+
Column("id", INTEGER, primary_key=True),
2902+
Column(
2903+
"roles",
2904+
ARRAY(SAEnum("admin", "user", name="role_enum")),
2905+
),
2906+
)
2907+
2908+
validate_code(
2909+
generator.generate(),
2910+
"""\
2911+
from typing import Optional
2912+
import enum
2913+
2914+
from sqlalchemy import ARRAY, Enum, Integer
2915+
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
2916+
2917+
class Base(DeclarativeBase):
2918+
pass
2919+
2920+
2921+
class RoleEnum(str, enum.Enum):
2922+
ADMIN = 'admin'
2923+
USER = 'user'
2924+
2925+
2926+
class Users(Base):
2927+
__tablename__ = 'users'
2928+
2929+
id: Mapped[int] = mapped_column(Integer, primary_key=True)
2930+
roles: Mapped[Optional[list[RoleEnum]]] = mapped_column(ARRAY(Enum(RoleEnum, values_callable=lambda cls: [member.value for member in cls])))
2931+
""",
2932+
)
2933+
2934+
2935+
def test_array_enum_with_dimensions(generator: CodeGenerator) -> None:
2936+
Table(
2937+
"items",
2938+
generator.metadata,
2939+
Column("id", INTEGER, primary_key=True),
2940+
Column(
2941+
"tag_matrix",
2942+
ARRAY(SAEnum("a", "b", name="tag_enum"), dimensions=2),
2943+
nullable=False,
2944+
),
2945+
)
2946+
2947+
validate_code(
2948+
generator.generate(),
2949+
"""\
2950+
import enum
2951+
2952+
from sqlalchemy import ARRAY, Enum, Integer
2953+
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
2954+
2955+
class Base(DeclarativeBase):
2956+
pass
2957+
2958+
2959+
class TagEnum(str, enum.Enum):
2960+
A = 'a'
2961+
B = 'b'
2962+
2963+
2964+
class Items(Base):
2965+
__tablename__ = 'items'
2966+
2967+
id: Mapped[int] = mapped_column(Integer, primary_key=True)
2968+
tag_matrix: Mapped[list[list[TagEnum]]] = mapped_column(ARRAY(Enum(TagEnum, values_callable=lambda cls: [member.value for member in cls]), dimensions=2), nullable=False)
2969+
""",
2970+
)
2971+
2972+
2973+
def test_array_enum_nonativeenums_option(generator: CodeGenerator) -> None:
2974+
Table(
2975+
"users",
2976+
generator.metadata,
2977+
Column("id", INTEGER, primary_key=True),
2978+
Column(
2979+
"roles",
2980+
ARRAY(SAEnum("admin", "user", name="role_enum")),
2981+
nullable=False,
2982+
),
2983+
)
2984+
2985+
generator = DeclarativeGenerator(
2986+
generator.metadata, generator.bind, ["nonativeenums"]
2987+
)
2988+
2989+
validate_code(
2990+
generator.generate(),
2991+
"""\
2992+
from sqlalchemy import ARRAY, Enum, Integer
2993+
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
2994+
2995+
class Base(DeclarativeBase):
2996+
pass
2997+
2998+
2999+
class Users(Base):
3000+
__tablename__ = 'users'
3001+
3002+
id: Mapped[int] = mapped_column(Integer, primary_key=True)
3003+
roles: Mapped[list[str]] = mapped_column(ARRAY(Enum('admin', 'user', name='role_enum')), nullable=False)
3004+
""",
3005+
)
3006+
3007+
3008+
def test_array_enum_shared_with_regular_enum(generator: CodeGenerator) -> None:
3009+
Table(
3010+
"users",
3011+
generator.metadata,
3012+
Column("id", INTEGER, primary_key=True),
3013+
Column(
3014+
"primary_role",
3015+
SAEnum("admin", "user", name="role_enum"),
3016+
nullable=False,
3017+
),
3018+
Column(
3019+
"all_roles",
3020+
ARRAY(SAEnum("admin", "user", name="role_enum")),
3021+
nullable=False,
3022+
),
3023+
)
3024+
3025+
validate_code(
3026+
generator.generate(),
3027+
"""\
3028+
import enum
3029+
3030+
from sqlalchemy import ARRAY, Enum, Integer
3031+
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
3032+
3033+
class Base(DeclarativeBase):
3034+
pass
3035+
3036+
3037+
class RoleEnum(str, enum.Enum):
3038+
ADMIN = 'admin'
3039+
USER = 'user'
3040+
3041+
3042+
class Users(Base):
3043+
__tablename__ = 'users'
3044+
3045+
id: Mapped[int] = mapped_column(Integer, primary_key=True)
3046+
primary_role: Mapped[RoleEnum] = mapped_column(Enum(RoleEnum, values_callable=lambda cls: [member.value for member in cls]), nullable=False)
3047+
all_roles: Mapped[list[RoleEnum]] = mapped_column(ARRAY(Enum(RoleEnum, values_callable=lambda cls: [member.value for member in cls])), nullable=False)
3048+
""",
3049+
)

0 commit comments

Comments
 (0)