|
13 | 13 | from keyword import iskeyword |
14 | 14 | from pprint import pformat |
15 | 15 | from textwrap import indent |
16 | | -from typing import Any, ClassVar |
| 16 | +from typing import Any, ClassVar, Literal, cast |
17 | 17 |
|
18 | 18 | import inflect |
19 | 19 | import sqlalchemy |
|
38 | 38 | TypeDecorator, |
39 | 39 | UniqueConstraint, |
40 | 40 | ) |
41 | | -from sqlalchemy.dialects.postgresql import JSONB |
| 41 | +from sqlalchemy.dialects.postgresql import DOMAIN, JSONB |
42 | 42 | from sqlalchemy.engine import Connection, Engine |
43 | 43 | from sqlalchemy.exc import CompileError |
44 | 44 | from sqlalchemy.sql.elements import TextClause |
@@ -228,6 +228,8 @@ def collect_imports_for_column(self, column: Column[Any]) -> None: |
228 | 228 | or column.type.astext_type.length is not None |
229 | 229 | ): |
230 | 230 | self.add_import(column.type.astext_type) |
| 231 | + elif isinstance(column.type, DOMAIN): |
| 232 | + self.add_import(column.type.data_type.__class__) |
231 | 233 |
|
232 | 234 | if column.default: |
233 | 235 | self.add_import(column.default) |
@@ -375,7 +377,7 @@ def render_table(self, table: Table) -> str: |
375 | 377 |
|
376 | 378 | args.append(self.render_constraint(constraint)) |
377 | 379 |
|
378 | | - for index in sorted(table.indexes, key=lambda i: i.name): |
| 380 | + for index in sorted(table.indexes, key=lambda i: cast(str, i.name)): |
379 | 381 | # One-column indexes should be rendered as index=True on columns |
380 | 382 | if len(index.columns) > 1 or not uses_default_name(index): |
381 | 383 | args.append(self.render_index(index)) |
@@ -467,7 +469,7 @@ def render_column( |
467 | 469 |
|
468 | 470 | if isinstance(column.server_default, DefaultClause): |
469 | 471 | kwargs["server_default"] = render_callable( |
470 | | - "text", repr(column.server_default.arg.text) |
| 472 | + "text", repr(cast(TextClause, column.server_default.arg).text) |
471 | 473 | ) |
472 | 474 | elif isinstance(column.server_default, Computed): |
473 | 475 | expression = str(column.server_default.sqltext) |
@@ -514,12 +516,18 @@ def render_column_type(self, coltype: object) -> str: |
514 | 516 |
|
515 | 517 | value = getattr(coltype, param.name, missing) |
516 | 518 | default = defaults.get(param.name, missing) |
| 519 | + if isinstance(value, TextClause): |
| 520 | + self.add_literal_import("sqlalchemy", "text") |
| 521 | + rendered_value = render_callable("text", repr(value.text)) |
| 522 | + else: |
| 523 | + rendered_value = repr(value) |
| 524 | + |
517 | 525 | if value is missing or value == default: |
518 | 526 | use_kwargs = True |
519 | 527 | elif use_kwargs: |
520 | | - kwargs[param.name] = repr(value) |
| 528 | + kwargs[param.name] = rendered_value |
521 | 529 | else: |
522 | | - args.append(repr(value)) |
| 530 | + args.append(rendered_value) |
523 | 531 |
|
524 | 532 | vararg = next( |
525 | 533 | ( |
@@ -1072,6 +1080,7 @@ def generate_relationship_name( |
1072 | 1080 | preferred_name = column_names[0][:-3] |
1073 | 1081 |
|
1074 | 1082 | if "use_inflect" in self.options: |
| 1083 | + inflected_name: str | Literal[False] |
1075 | 1084 | if relationship.type in ( |
1076 | 1085 | RelationshipType.ONE_TO_MANY, |
1077 | 1086 | RelationshipType.MANY_TO_MANY, |
@@ -1166,7 +1175,7 @@ def render_table_args(self, table: Table) -> str: |
1166 | 1175 | args.append(self.render_constraint(constraint)) |
1167 | 1176 |
|
1168 | 1177 | # Render indexes |
1169 | | - for index in sorted(table.indexes, key=lambda i: i.name): |
| 1178 | + for index in sorted(table.indexes, key=lambda i: cast(str, i.name)): |
1170 | 1179 | if len(index.columns) > 1 or not uses_default_name(index): |
1171 | 1180 | args.append(self.render_index(index)) |
1172 | 1181 |
|
|
0 commit comments