Skip to content

Commit f4619eb

Browse files
sheinbergonagronholmpre-commit-ci[bot]
authored
Support native python enum generation (#446)
* support native python enum generation * support native python enum generation * Update CHANGES.rst with PR credit Added PR credit for enum generation feature. * Remove irrational logic * Remove Logic * test improvements * Remove unneeded test * Reinstate synthetic enum generation * CHANGES.rst improvements * Update CHANGES.rst Co-authored-by: Alex Grönholm <alex.gronholm@nextday.fi> * Update src/sqlacodegen/generators.py Co-authored-by: Alex Grönholm <alex.gronholm@nextday.fi> * Update src/sqlacodegen/generators.py Co-authored-by: Alex Grönholm <alex.gronholm@nextday.fi> * Update src/sqlacodegen/generators.py Co-authored-by: Alex Grönholm <alex.gronholm@nextday.fi> * Update src/sqlacodegen/generators.py Co-authored-by: Alex Grönholm <alex.gronholm@nextday.fi> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * PR Fixes * Fix RST formatting * PR Fixes * PR Fixes * PR Fixes * Revert PR Fix * Update CHANGES.rst Co-authored-by: Alex Grönholm <alex.gronholm@nextday.fi> * PR Fixes * Fix CHANGES.rst * Rework CHANGES.rst * Update src/sqlacodegen/generators.py Co-authored-by: Alex Grönholm <alex.gronholm@nextday.fi> * Update src/sqlacodegen/generators.py Co-authored-by: Alex Grönholm <alex.gronholm@nextday.fi> * PR Fixes * Minor cleanup * PR Fixes * PR Fixes * Update CHANGES.rst Co-authored-by: Alex Grönholm <alex.gronholm@nextday.fi> * Reformatted --------- Co-authored-by: Alex Grönholm <alex.gronholm@nextday.fi> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 3a35437 commit f4619eb

7 files changed

Lines changed: 750 additions & 50 deletions

File tree

CHANGES.rst

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,17 @@
11
Version history
22
===============
33

4+
**4.0.0rc1**
5+
6+
- **BACKWARD INCOMPATIBLE** ``TablesGenerator.render_column_type()`` was changed to
7+
receive the ``Column`` object instead of the column type object as its sole argument
8+
- Added Python enum generation for native database ENUM types (e.g., PostgreSQL / MySQL ENUM).
9+
Retained synthetic Python enum generation from CHECK constraints with
10+
IN clauses (e.g., ``column IN ('val1', 'val2', ...)``). Use ``--options nonativeenums`` to
11+
disable enum generation for native database enums. Use ``--options nosyntheticenums`` to
12+
disable enum generation for synthetic database enums (VARCHAR columns with check constraints).
13+
(PR by @sheinbergon)
14+
415
**3.2.0**
516

617
- Dropped support for Python 3.9

README.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,8 @@ values must be delimited by commas, e.g. ``--options noconstraints,nobidi``):
106106
* ``noconstraints``: ignore constraints (foreign key, unique etc.)
107107
* ``nocomments``: ignore table/column comments
108108
* ``noindexes``: ignore indexes
109+
* ``nonativeenums``: don't generate Python enum classes for native database ENUM types (e.g., PostgreSQL ENUM); use plain string mapping instead
110+
* ``nosyntheticenums``: don't generate Python enum classes from CHECK constraints with IN clauses (e.g., ``column IN ('value1', 'value2', ...)``); preserves CHECK constraints as-is
109111
* ``noidsuffix``: prevent the special naming logic for single column many-to-one
110112
and one-to-one relationships (see `Relationship naming logic`_ for details)
111113
* ``include_dialect_options``: render a table's dialect options, such as ``starrocks_partition`` for StarRocks-specific options.

src/sqlacodegen/generators.py

Lines changed: 183 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,8 @@ class TablesGenerator(CodeGenerator):
123123
"noindexes",
124124
"noconstraints",
125125
"nocomments",
126+
"nonativeenums",
127+
"nosyntheticenums",
126128
"include_dialect_options",
127129
"keep_dialect_types",
128130
}
@@ -148,6 +150,11 @@ def __init__(
148150
# Keep dialect-specific types instead of adapting to generic SQLAlchemy types
149151
self.keep_dialect_types: bool = "keep_dialect_types" in self.options
150152

153+
# Track Python enum classes: maps (table_name, column_name) -> enum_class_name
154+
self.enum_classes: dict[tuple[str, str], str] = {}
155+
# Track enum values: maps enum_class_name -> list of values
156+
self.enum_values: dict[str, list[str]] = {}
157+
151158
@property
152159
def views_supported(self) -> bool:
153160
return True
@@ -192,19 +199,22 @@ def generate(self) -> str:
192199
models: list[Model] = self.generate_models()
193200

194201
# Render module level variables
195-
variables = self.render_module_variables(models)
196-
if variables:
202+
if variables := self.render_module_variables(models):
197203
sections.append(variables + "\n")
198204

205+
# Render enum classes
206+
if enum_classes := self.render_enum_classes():
207+
sections.append(enum_classes + "\n")
208+
199209
# Render models
200-
rendered_models = self.render_models(models)
201-
if rendered_models:
210+
if rendered_models := self.render_models(models):
202211
sections.append(rendered_models)
203212

204213
# Render collected imports
205214
groups = self.group_imports()
206-
imports = "\n\n".join("\n".join(line for line in group) for group in groups)
207-
if imports:
215+
if imports := "\n\n".join(
216+
"\n".join(line for line in group) for group in groups
217+
):
208218
sections.insert(0, imports)
209219

210220
return "\n\n".join(sections) + "\n"
@@ -467,7 +477,7 @@ def render_column(
467477
# Render the column type if there are no foreign keys on it or any of them
468478
# points back to itself
469479
if not dedicated_fks or any(fk.column is column for fk in dedicated_fks):
470-
args.append(self.render_column_type(column.type))
480+
args.append(self.render_column_type(column))
471481

472482
for fk in dedicated_fks:
473483
args.append(self.render_constraint(fk))
@@ -528,10 +538,21 @@ def render_column_callable(self, is_table: bool, *args: Any, **kwargs: Any) -> s
528538
else:
529539
return render_callable("mapped_column", *args, kwargs=kwargs)
530540

531-
def render_column_type(self, coltype: TypeEngine[Any]) -> str:
541+
def render_column_type(self, column: Column[Any]) -> str:
542+
column_type = column.type
543+
# Check if this is an enum column with a Python enum class
544+
if isinstance(column_type, Enum) and column is not None:
545+
if enum_class_name := self.enum_classes.get(
546+
(column.table.name, column.name)
547+
):
548+
# Import SQLAlchemy Enum (will be handled in collect_imports)
549+
self.add_import(Enum)
550+
# Return the Python enum class as the type parameter
551+
return f"Enum({enum_class_name})"
552+
532553
args = []
533554
kwargs: dict[str, Any] = {}
534-
sig = inspect.signature(coltype.__class__.__init__)
555+
sig = inspect.signature(column_type.__class__.__init__)
535556
defaults = {param.name: param.default for param in sig.parameters.values()}
536557
missing = object()
537558
use_kwargs = False
@@ -543,7 +564,7 @@ def render_column_type(self, coltype: TypeEngine[Any]) -> str:
543564
use_kwargs = True
544565
continue
545566

546-
value = getattr(coltype, param.name, missing)
567+
value = getattr(column_type, param.name, missing)
547568

548569
if isinstance(value, (JSONB, JSON)):
549570
# Remove astext_type if it's the default
@@ -577,28 +598,28 @@ def render_column_type(self, coltype: TypeEngine[Any]) -> str:
577598
),
578599
None,
579600
)
580-
if vararg and hasattr(coltype, vararg):
581-
varargs_repr = [repr(arg) for arg in getattr(coltype, vararg)]
601+
if vararg and hasattr(column_type, vararg):
602+
varargs_repr = [repr(arg) for arg in getattr(column_type, vararg)]
582603
args.extend(varargs_repr)
583604

584605
# These arguments cannot be autodetected from the Enum initializer
585-
if isinstance(coltype, Enum):
606+
if isinstance(column_type, Enum):
586607
for colname in "name", "schema":
587-
if (value := getattr(coltype, colname)) is not None:
608+
if (value := getattr(column_type, colname)) is not None:
588609
kwargs[colname] = repr(value)
589610

590-
if isinstance(coltype, (JSONB, JSON)):
611+
if isinstance(column_type, (JSONB, JSON)):
591612
# Remove astext_type if it's the default
592613
if (
593-
isinstance(coltype.astext_type, Text)
594-
and coltype.astext_type.length is None
614+
isinstance(column_type.astext_type, Text)
615+
and column_type.astext_type.length is None
595616
):
596617
del kwargs["astext_type"]
597618

598619
if args or kwargs:
599-
return render_callable(coltype.__class__.__name__, *args, kwargs=kwargs)
620+
return render_callable(column_type.__class__.__name__, *args, kwargs=kwargs)
600621
else:
601-
return coltype.__class__.__name__
622+
return column_type.__class__.__name__
602623

603624
def render_constraint(self, constraint: Constraint | ForeignKey) -> str:
604625
def add_fk_options(*opts: Any) -> None:
@@ -709,6 +730,81 @@ def find_free_name(
709730

710731
return name
711732

733+
def _enum_name_to_class_name(self, enum_name: str) -> str:
734+
"""Convert a database enum name to a Python class name (PascalCase)."""
735+
return "".join(part.capitalize() for part in enum_name.split("_") if part)
736+
737+
def _create_enum_class(
738+
self, table_name: str, column_name: str, values: list[str]
739+
) -> str:
740+
"""
741+
Create a Python enum class name and register it.
742+
743+
Returns the enum class name to use in generated code.
744+
"""
745+
# Generate enum class name from table and column names
746+
# Convert to PascalCase: user_status -> UserStatus
747+
base_name = "".join(
748+
part.capitalize()
749+
for part in table_name.split("_") + column_name.split("_")
750+
if part
751+
)
752+
753+
# Ensure uniqueness
754+
enum_class_name = base_name
755+
for counter in count(1):
756+
if enum_class_name not in self.enum_values:
757+
break
758+
759+
# Check if it's the same enum (same values)
760+
if self.enum_values[enum_class_name] == values:
761+
# Reuse existing enum class
762+
return enum_class_name
763+
764+
enum_class_name = f"{base_name}{counter}"
765+
766+
# Register the new enum class
767+
self.enum_values[enum_class_name] = values
768+
return enum_class_name
769+
770+
def render_enum_classes(self) -> str:
    """
    Render Python enum class definitions.

    One ``class Name(str, enum.Enum)`` definition is rendered for every entry
    in ``self.enum_values``, ordered by class name. Member names are derived
    from the values by replacing characters that are invalid in identifiers
    and upper-casing the result; clashing member names are disambiguated with
    a numeric suffix.

    :return: the rendered class definitions separated by two blank lines, or
        an empty string if no enum classes were registered
    """
    if not self.enum_values:
        return ""

    self.add_module_import("enum")

    enum_defs = []
    for enum_class_name, values in sorted(self.enum_values.items()):
        members = []
        used_names: set[str] = set()
        for value in values:
            # Unescape SQL escape sequences (e.g., \' -> ')
            # The value from the CHECK constraint has SQL escaping
            unescaped_value = value.replace("\\'", "'").replace("\\\\", "\\")

            # Create a valid identifier from the enum value. No keyword check
            # is needed: the name is upper-cased and every Python keyword
            # contains at least one lowercase letter.
            member_name = _re_invalid_identifier.sub("_", unescaped_value).upper()
            if not member_name:
                member_name = "EMPTY"
            elif member_name[0].isdigit():
                member_name = "_" + member_name

            # Distinct values may collapse to the same name (e.g. "a-b" and
            # "a b"); a repeated member name would make the generated enum
            # raise TypeError on import, so add a numeric suffix
            unique_name = member_name
            for suffix in count(1):
                if unique_name not in used_names:
                    break

                unique_name = f"{member_name}_{suffix}"

            used_names.add(unique_name)
            members.append(f"    {unique_name} = {unescaped_value!r}")

        # A class statement needs a body even if there are no members
        body = "\n".join(members) or "    pass"
        enum_defs.append(f"class {enum_class_name}(str, enum.Enum):\n" + body)

    return "\n\n\n".join(enum_defs)
807+
712808
def fix_column_types(self, table: Table) -> None:
713809
"""Adjust the reflected column types."""
714810
# Detect check constraints for boolean and enum columns
@@ -718,34 +814,74 @@ def fix_column_types(self, table: Table) -> None:
718814

719815
# Turn any integer-like column with a CheckConstraint like
720816
# "column IN (0, 1)" into a Boolean
721-
match = _re_boolean_check_constraint.match(sqltext)
722-
if match:
723-
colname_match = _re_column_name.match(match.group(1))
724-
if colname_match:
817+
if match := _re_boolean_check_constraint.match(sqltext):
818+
if colname_match := _re_column_name.match(match.group(1)):
725819
colname = colname_match.group(3)
726820
table.constraints.remove(constraint)
727821
table.c[colname].type = Boolean()
728822
continue
729823

730-
# Turn any string-type column with a CheckConstraint like
731-
# "column IN (...)" into an Enum
732-
match = _re_enum_check_constraint.match(sqltext)
733-
if match:
734-
colname_match = _re_column_name.match(match.group(1))
735-
if colname_match:
736-
colname = colname_match.group(3)
737-
items = match.group(2)
738-
if isinstance(table.c[colname].type, String):
739-
table.constraints.remove(constraint)
740-
if not isinstance(table.c[colname].type, Enum):
741-
options = _re_enum_item.findall(items)
742-
table.c[colname].type = Enum(
743-
*options, native_enum=False
824+
# Turn VARCHAR columns with CHECK constraints like "column IN ('a', 'b')"
825+
# into synthetic Enum types with Python enum classes
826+
if (
827+
"nosyntheticenums" not in self.options
828+
and (match := _re_enum_check_constraint.match(sqltext))
829+
and (colname_match := _re_column_name.match(match.group(1)))
830+
):
831+
colname = colname_match.group(3)
832+
items = match.group(2)
833+
if isinstance(table.c[colname].type, String) and not isinstance(
834+
table.c[colname].type, Enum
835+
):
836+
options = _re_enum_item.findall(items)
837+
# Create Python enum class
838+
enum_class_name = self._create_enum_class(
839+
table.name, colname, options
840+
)
841+
self.enum_classes[(table.name, colname)] = enum_class_name
842+
# Convert to Enum type but KEEP the constraint
843+
table.c[colname].type = Enum(*options, native_enum=False)
844+
continue
845+
846+
for column in table.c:
847+
# Handle native database Enum types (e.g., PostgreSQL ENUM)
848+
if (
849+
"nonativeenums" not in self.options
850+
and isinstance(column.type, Enum)
851+
and column.type.enums
852+
):
853+
if column.type.name:
854+
# Named enum - create shared enum class if not already created
855+
if (table.name, column.name) not in self.enum_classes:
856+
# Check if we've already created an enum for this name
857+
existing_class = None
858+
for (t, c), cls in self.enum_classes.items():
859+
if cls == self._enum_name_to_class_name(column.type.name):
860+
existing_class = cls
861+
break
862+
863+
if existing_class:
864+
enum_class_name = existing_class
865+
else:
866+
# Create new enum class from the enum's name
867+
enum_class_name = self._enum_name_to_class_name(
868+
column.type.name
869+
)
870+
# Register the enum values if not already registered
871+
if enum_class_name not in self.enum_values:
872+
self.enum_values[enum_class_name] = list(
873+
column.type.enums
744874
)
745875

746-
continue
876+
self.enum_classes[(table.name, column.name)] = enum_class_name
877+
else:
878+
# Unnamed enum - create enum class per column
879+
if (table.name, column.name) not in self.enum_classes:
880+
enum_class_name = self._create_enum_class(
881+
table.name, column.name, list(column.type.enums)
882+
)
883+
self.enum_classes[(table.name, column.name)] = enum_class_name
747884

748-
for column in table.c:
749885
if not self.keep_dialect_types:
750886
try:
751887
column.type = self.get_adapted_type(column.type)
@@ -1326,6 +1462,14 @@ def get_type_qualifiers() -> tuple[str, TypeEngine[Any], str]:
13261462
return "".join(pre), column_type, "]" * post_size
13271463

13281464
def render_python_type(column_type: TypeEngine[Any]) -> str:
1465+
# Check if this is an enum column with a Python enum class
1466+
if isinstance(column_type, Enum):
1467+
table_name = column.table.name
1468+
column_name = column.name
1469+
if (table_name, column_name) in self.enum_classes:
1470+
enum_class_name = self.enum_classes[(table_name, column_name)]
1471+
return enum_class_name
1472+
13291473
if isinstance(column_type, DOMAIN):
13301474
column_type = column_type.data_type
13311475

src/sqlacodegen/utils.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -210,10 +210,4 @@ def decode_postgresql_sequence(clause: TextClause) -> tuple[str | None, str | No
210210

211211

212212
def get_stdlib_module_names() -> set[str]:
213-
major, minor = sys.version_info.major, sys.version_info.minor
214-
if (major, minor) > (3, 9):
215-
return set(sys.builtin_module_names) | set(sys.stdlib_module_names)
216-
else:
217-
from stdlib_list import stdlib_list
218-
219-
return set(sys.builtin_module_names) | set(stdlib_list(f"{major}.{minor}"))
213+
return set(sys.builtin_module_names) | set(sys.stdlib_module_names)

0 commit comments

Comments
 (0)