@@ -123,6 +123,8 @@ class TablesGenerator(CodeGenerator):
123123 "noindexes" ,
124124 "noconstraints" ,
125125 "nocomments" ,
126+ "nonativeenums" ,
127+ "nosyntheticenums" ,
126128 "include_dialect_options" ,
127129 "keep_dialect_types" ,
128130 }
@@ -148,6 +150,11 @@ def __init__(
148150 # Keep dialect-specific types instead of adapting to generic SQLAlchemy types
149151 self .keep_dialect_types : bool = "keep_dialect_types" in self .options
150152
153+ # Track Python enum classes: maps (table_name, column_name) -> enum_class_name
154+ self .enum_classes : dict [tuple [str , str ], str ] = {}
155+ # Track enum values: maps enum_class_name -> list of values
156+ self .enum_values : dict [str , list [str ]] = {}
157+
151158 @property
152159 def views_supported (self ) -> bool :
153160 return True
@@ -192,19 +199,22 @@ def generate(self) -> str:
192199 models : list [Model ] = self .generate_models ()
193200
194201 # Render module level variables
195- variables = self .render_module_variables (models )
196- if variables :
202+ if variables := self .render_module_variables (models ):
197203 sections .append (variables + "\n " )
198204
205+ # Render enum classes
206+ if enum_classes := self .render_enum_classes ():
207+ sections .append (enum_classes + "\n " )
208+
199209 # Render models
200- rendered_models = self .render_models (models )
201- if rendered_models :
210+ if rendered_models := self .render_models (models ):
202211 sections .append (rendered_models )
203212
204213 # Render collected imports
205214 groups = self .group_imports ()
206- imports = "\n \n " .join ("\n " .join (line for line in group ) for group in groups )
207- if imports :
215+ if imports := "\n \n " .join (
216+ "\n " .join (line for line in group ) for group in groups
217+ ):
208218 sections .insert (0 , imports )
209219
210220 return "\n \n " .join (sections ) + "\n "
@@ -467,7 +477,7 @@ def render_column(
467477 # Render the column type if there are no foreign keys on it or any of them
468478 # points back to itself
469479 if not dedicated_fks or any (fk .column is column for fk in dedicated_fks ):
470- args .append (self .render_column_type (column . type ))
480+ args .append (self .render_column_type (column ))
471481
472482 for fk in dedicated_fks :
473483 args .append (self .render_constraint (fk ))
@@ -528,10 +538,21 @@ def render_column_callable(self, is_table: bool, *args: Any, **kwargs: Any) -> s
528538 else :
529539 return render_callable ("mapped_column" , * args , kwargs = kwargs )
530540
531- def render_column_type (self , coltype : TypeEngine [Any ]) -> str :
541+ def render_column_type (self , column : Column [Any ]) -> str :
542+ column_type = column .type
543+ # Check if this is an enum column with a Python enum class
544+ if isinstance (column_type , Enum ) and column is not None :
545+ if enum_class_name := self .enum_classes .get (
546+ (column .table .name , column .name )
547+ ):
548+ # Import SQLAlchemy Enum (will be handled in collect_imports)
549+ self .add_import (Enum )
550+ # Return the Python enum class as the type parameter
551+ return f"Enum({ enum_class_name } )"
552+
532553 args = []
533554 kwargs : dict [str , Any ] = {}
534- sig = inspect .signature (coltype .__class__ .__init__ )
555+ sig = inspect .signature (column_type .__class__ .__init__ )
535556 defaults = {param .name : param .default for param in sig .parameters .values ()}
536557 missing = object ()
537558 use_kwargs = False
@@ -543,7 +564,7 @@ def render_column_type(self, coltype: TypeEngine[Any]) -> str:
543564 use_kwargs = True
544565 continue
545566
546- value = getattr (coltype , param .name , missing )
567+ value = getattr (column_type , param .name , missing )
547568
548569 if isinstance (value , (JSONB , JSON )):
549570 # Remove astext_type if it's the default
@@ -577,28 +598,28 @@ def render_column_type(self, coltype: TypeEngine[Any]) -> str:
577598 ),
578599 None ,
579600 )
580- if vararg and hasattr (coltype , vararg ):
581- varargs_repr = [repr (arg ) for arg in getattr (coltype , vararg )]
601+ if vararg and hasattr (column_type , vararg ):
602+ varargs_repr = [repr (arg ) for arg in getattr (column_type , vararg )]
582603 args .extend (varargs_repr )
583604
584605 # These arguments cannot be autodetected from the Enum initializer
585- if isinstance (coltype , Enum ):
606+ if isinstance (column_type , Enum ):
586607 for colname in "name" , "schema" :
587- if (value := getattr (coltype , colname )) is not None :
608+ if (value := getattr (column_type , colname )) is not None :
588609 kwargs [colname ] = repr (value )
589610
590- if isinstance (coltype , (JSONB , JSON )):
611+ if isinstance (column_type , (JSONB , JSON )):
591612 # Remove astext_type if it's the default
592613 if (
593- isinstance (coltype .astext_type , Text )
594- and coltype .astext_type .length is None
614+ isinstance (column_type .astext_type , Text )
615+ and column_type .astext_type .length is None
595616 ):
596617 del kwargs ["astext_type" ]
597618
598619 if args or kwargs :
599- return render_callable (coltype .__class__ .__name__ , * args , kwargs = kwargs )
620+ return render_callable (column_type .__class__ .__name__ , * args , kwargs = kwargs )
600621 else :
601- return coltype .__class__ .__name__
622+ return column_type .__class__ .__name__
602623
603624 def render_constraint (self , constraint : Constraint | ForeignKey ) -> str :
604625 def add_fk_options (* opts : Any ) -> None :
@@ -709,6 +730,81 @@ def find_free_name(
709730
710731 return name
711732
733+ def _enum_name_to_class_name (self , enum_name : str ) -> str :
734+ """Convert a database enum name to a Python class name (PascalCase)."""
735+ return "" .join (part .capitalize () for part in enum_name .split ("_" ) if part )
736+
737+ def _create_enum_class (
738+ self , table_name : str , column_name : str , values : list [str ]
739+ ) -> str :
740+ """
741+ Create a Python enum class name and register it.
742+
743+ Returns the enum class name to use in generated code.
744+ """
745+ # Generate enum class name from table and column names
746+ # Convert to PascalCase: user_status -> UserStatus
747+ base_name = "" .join (
748+ part .capitalize ()
749+ for part in table_name .split ("_" ) + column_name .split ("_" )
750+ if part
751+ )
752+
753+ # Ensure uniqueness
754+ enum_class_name = base_name
755+ for counter in count (1 ):
756+ if enum_class_name not in self .enum_values :
757+ break
758+
759+ # Check if it's the same enum (same values)
760+ if self .enum_values [enum_class_name ] == values :
761+ # Reuse existing enum class
762+ return enum_class_name
763+
764+ enum_class_name = f"{ base_name } { counter } "
765+
766+ # Register the new enum class
767+ self .enum_values [enum_class_name ] = values
768+ return enum_class_name
769+
770+ def render_enum_classes (self ) -> str :
771+ """Render Python enum class definitions."""
772+ if not self .enum_values :
773+ return ""
774+
775+ self .add_module_import ("enum" )
776+
777+ enum_defs = []
778+ for enum_class_name , values in sorted (self .enum_values .items ()):
779+ # Create enum members with valid Python identifiers
780+ members = []
781+ for value in values :
782+ # Unescape SQL escape sequences (e.g., \' -> ')
783+ # The value from the CHECK constraint has SQL escaping
784+ unescaped_value = value .replace ("\\ '" , "'" ).replace ("\\ \\ " , "\\ " )
785+
786+ # Create a valid identifier from the enum value
787+ member_name = _re_invalid_identifier .sub ("_" , unescaped_value ).upper ()
788+ if not member_name :
789+ member_name = "EMPTY"
790+ elif member_name [0 ].isdigit ():
791+ member_name = "_" + member_name
792+ elif iskeyword (member_name ):
793+ member_name += "_"
794+ #
795+ # # Re-escape for Python string literal
796+ # python_escaped = unescaped_value.replace("\\", "\\\\").replace(
797+ # "'", "\\'"
798+ # )
799+ members .append (f" { member_name } = { unescaped_value !r} " )
800+
801+ enum_def = f"class { enum_class_name } (str, enum.Enum):\n " + "\n " .join (
802+ members
803+ )
804+ enum_defs .append (enum_def )
805+
806+ return "\n \n \n " .join (enum_defs )
807+
712808 def fix_column_types (self , table : Table ) -> None :
713809 """Adjust the reflected column types."""
714810 # Detect check constraints for boolean and enum columns
@@ -718,34 +814,74 @@ def fix_column_types(self, table: Table) -> None:
718814
719815 # Turn any integer-like column with a CheckConstraint like
720816 # "column IN (0, 1)" into a Boolean
721- match = _re_boolean_check_constraint .match (sqltext )
722- if match :
723- colname_match = _re_column_name .match (match .group (1 ))
724- if colname_match :
817+ if match := _re_boolean_check_constraint .match (sqltext ):
818+ if colname_match := _re_column_name .match (match .group (1 )):
725819 colname = colname_match .group (3 )
726820 table .constraints .remove (constraint )
727821 table .c [colname ].type = Boolean ()
728822 continue
729823
730- # Turn any string-type column with a CheckConstraint like
731- # "column IN (...)" into an Enum
732- match = _re_enum_check_constraint .match (sqltext )
733- if match :
734- colname_match = _re_column_name .match (match .group (1 ))
735- if colname_match :
736- colname = colname_match .group (3 )
737- items = match .group (2 )
738- if isinstance (table .c [colname ].type , String ):
739- table .constraints .remove (constraint )
740- if not isinstance (table .c [colname ].type , Enum ):
741- options = _re_enum_item .findall (items )
742- table .c [colname ].type = Enum (
743- * options , native_enum = False
824+ # Turn VARCHAR columns with CHECK constraints like "column IN ('a', 'b')"
825+ # into synthetic Enum types with Python enum classes
826+ if (
827+ "nosyntheticenums" not in self .options
828+ and (match := _re_enum_check_constraint .match (sqltext ))
829+ and (colname_match := _re_column_name .match (match .group (1 )))
830+ ):
831+ colname = colname_match .group (3 )
832+ items = match .group (2 )
833+ if isinstance (table .c [colname ].type , String ) and not isinstance (
834+ table .c [colname ].type , Enum
835+ ):
836+ options = _re_enum_item .findall (items )
837+ # Create Python enum class
838+ enum_class_name = self ._create_enum_class (
839+ table .name , colname , options
840+ )
841+ self .enum_classes [(table .name , colname )] = enum_class_name
842+ # Convert to Enum type but KEEP the constraint
843+ table .c [colname ].type = Enum (* options , native_enum = False )
844+ continue
845+
846+ for column in table .c :
847+ # Handle native database Enum types (e.g., PostgreSQL ENUM)
848+ if (
849+ "nonativeenums" not in self .options
850+ and isinstance (column .type , Enum )
851+ and column .type .enums
852+ ):
853+ if column .type .name :
854+ # Named enum - create shared enum class if not already created
855+ if (table .name , column .name ) not in self .enum_classes :
856+ # Check if we've already created an enum for this name
857+ existing_class = None
858+ for (t , c ), cls in self .enum_classes .items ():
859+ if cls == self ._enum_name_to_class_name (column .type .name ):
860+ existing_class = cls
861+ break
862+
863+ if existing_class :
864+ enum_class_name = existing_class
865+ else :
866+ # Create new enum class from the enum's name
867+ enum_class_name = self ._enum_name_to_class_name (
868+ column .type .name
869+ )
870+ # Register the enum values if not already registered
871+ if enum_class_name not in self .enum_values :
872+ self .enum_values [enum_class_name ] = list (
873+ column .type .enums
744874 )
745875
746- continue
876+ self .enum_classes [(table .name , column .name )] = enum_class_name
877+ else :
878+ # Unnamed enum - create enum class per column
879+ if (table .name , column .name ) not in self .enum_classes :
880+ enum_class_name = self ._create_enum_class (
881+ table .name , column .name , list (column .type .enums )
882+ )
883+ self .enum_classes [(table .name , column .name )] = enum_class_name
747884
748- for column in table .c :
749885 if not self .keep_dialect_types :
750886 try :
751887 column .type = self .get_adapted_type (column .type )
@@ -1326,6 +1462,14 @@ def get_type_qualifiers() -> tuple[str, TypeEngine[Any], str]:
13261462 return "" .join (pre ), column_type , "]" * post_size
13271463
13281464 def render_python_type (column_type : TypeEngine [Any ]) -> str :
1465+ # Check if this is an enum column with a Python enum class
1466+ if isinstance (column_type , Enum ):
1467+ table_name = column .table .name
1468+ column_name = column .name
1469+ if (table_name , column_name ) in self .enum_classes :
1470+ enum_class_name = self .enum_classes [(table_name , column_name )]
1471+ return enum_class_name
1472+
13291473 if isinstance (column_type , DOMAIN ):
13301474 column_type = column_type .data_type
13311475
0 commit comments