Skip to content

Commit 2dbae85

Browse files
committed
Add support for SQLAlchemyFactory and generic ModelFactory types
1 parent 8aa4206 commit 2dbae85

4 files changed

Lines changed: 28 additions & 3 deletions

File tree

src/community_of_python_flake8_plugin/checks/dataclass_config.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,9 +54,9 @@ def is_pydantic_model(class_node: ast.ClassDef) -> bool:
5454
def is_model_factory(class_node: ast.ClassDef) -> bool:
5555
"""Check if class inherits from ModelFactory."""
5656
for base_class in class_node.bases:
57-
if isinstance(base_class, ast.Name) and base_class.id == "ModelFactory":
57+
if isinstance(base_class, ast.Name) and base_class.id in {"ModelFactory", "SQLAlchemyFactory"}:
5858
return True
59-
if isinstance(base_class, ast.Attribute) and base_class.attr == "ModelFactory":
59+
if isinstance(base_class, ast.Attribute) and base_class.attr in {"ModelFactory", "SQLAlchemyFactory"}:
6060
return True
6161
return False
6262

src/community_of_python_flake8_plugin/constants.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@
108108
"ensure",
109109
"submit",
110110
"clear",
111+
"undo",
111112
}
112113

113114
SCALAR_ANNOTATIONS: typing.Final = {"int", "str", "float", "bool", "bytes", "complex"}
@@ -116,5 +117,5 @@
116117

117118
ALLOWED_STDLIB_FROM_IMPORTS: typing.Final = {"collections.abc"}
118119

119-
FINAL_CLASS_EXCLUDED_BASES: typing.Final = {"BaseModel", "RootModel", "ModelFactory"}
120+
FINAL_CLASS_EXCLUDED_BASES: typing.Final = {"BaseModel", "RootModel", "ModelFactory", "SQLAlchemyFactory"}
120121
MAX_IMPORT_NAMES: typing.Final = 2

src/community_of_python_flake8_plugin/utils.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,12 @@ def check_inherits_from_bases(class_definition: ast.ClassDef, base_classes: set[
1717
return True
1818
if isinstance(base_class, ast.Attribute) and base_class.attr in base_classes:
1919
return True
20+
# Handle generic types like ModelFactory[SomeType]
21+
if isinstance(base_class, ast.Subscript):
22+
if isinstance(base_class.value, ast.Name) and base_class.value.id in base_classes:
23+
return True
24+
if isinstance(base_class.value, ast.Attribute) and base_class.value.attr in base_classes:
25+
return True
2026
return False
2127

2228

tests/test_plugin.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,24 @@ def test_type_annotation_validations(input_source: str, expected_output: list[st
187187
"from polyfactory.factories.pydantic_factory import ModelFactory\nclass MyFactory(ModelFactory):\n def calculator(self): pass",
188188
["COP012"],
189189
),
190+
# No violation: ModelFactory generic methods should be exempt from COP009
191+
(
192+
"from polyfactory.factories.pydantic_factory import ModelFactory\nimport some_module\n"
193+
"class MyFactory(ModelFactory[some_module.SomeClass]):\n def calculator(self): pass",
194+
["COP012"],
195+
),
196+
# No violation: ModelFactory classmethod should be exempt from COP009
197+
(
198+
"from polyfactory.factories.pydantic_factory import ModelFactory\n"
199+
"class MyFactory(ModelFactory):\n @classmethod\n def create(cls): pass",
200+
["COP012"],
201+
),
202+
# No violation: ModelFactory generic classmethod should be exempt from COP009
203+
(
204+
"from polyfactory.factories.pydantic_factory import ModelFactory\nimport some_module\n"
205+
"class MyFactory(ModelFactory[some_module.SomeClass]):\n @classmethod\n def create(cls): pass",
206+
["COP012"],
207+
),
190208
# No violation: cached_property imported directly should exempt function from COP009 (but triggers COP002 for import style)
191209
(
192210
"from functools import cached_property\nclass ExampleClass:\n @cached_property\n def calculator(self): pass",

0 commit comments

Comments
 (0)