Skip to content

Commit cd2c383

Browse files
authored
Detect inherited classes (typing.final rule), omit one_ prefix rule when iterating over range() (#2)
* Enhance final class check and for-loop validation logic * Update test cases for variable usage and dataclass validations * Rename function to clarify local subclass checking * Update test cases to properly validate final decorator usage in inheritance hierarchies * Refactor variable names and simplify range check logic
1 parent c7a4bcc commit cd2c383

3 files changed

Lines changed: 106 additions & 7 deletions

File tree

src/community_of_python_flake8_plugin/checks/final_class.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,24 @@ def is_model_factory_class(class_node: ast.ClassDef) -> bool:
4040
return check_inherits_from_bases(class_node, {"ModelFactory", "SQLAlchemyFactory"})
4141

4242

43+
def has_local_subclasses(syntax_tree: ast.AST, class_node: ast.ClassDef) -> bool:
44+
"""Check if there are classes in the same file that inherit from this class."""
45+
for one_node in ast.walk(syntax_tree):
46+
if isinstance(one_node, ast.ClassDef) and one_node != class_node:
47+
for one_base in one_node.bases:
48+
# Check for direct class reference: class Child(Parent):
49+
if isinstance(one_base, ast.Name) and one_base.id == class_node.name:
50+
return True
51+
# Check for attributed class reference: class Child(module.Parent):
52+
if isinstance(one_base, ast.Attribute) and one_base.attr == class_node.name:
53+
return True
54+
return False
55+
56+
4357
@typing.final
4458
class FinalClassCheck(ast.NodeVisitor):
45-
def __init__(self, syntax_tree: ast.AST) -> None: # noqa: ARG002
59+
def __init__(self, syntax_tree: ast.AST) -> None:
60+
self.syntax_tree = syntax_tree
4661
self.violations: list[Violation] = []
4762

4863
def visit_ClassDef(self, ast_node: ast.ClassDef) -> None:
@@ -54,6 +69,10 @@ def _check_final_decorator(self, ast_node: ast.ClassDef) -> None:
5469
if is_protocol_class(ast_node) or ast_node.name.startswith("Test") or is_model_factory_class(ast_node):
5570
return
5671

72+
# If there are classes in this file that inherit from this class, don't require the decorator
73+
if has_local_subclasses(self.syntax_tree, ast_node):
74+
return
75+
5776
if not contains_final_decorator(ast_node):
5877
self.violations.append(
5978
Violation(

src/community_of_python_flake8_plugin/checks/for_loop_one_prefix.py

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,14 @@ def visit_ListComp(self, ast_node: ast.ListComp) -> None:
2222
# Validate targets in generators (the 'v' in 'for v in lst')
2323
for one_comprehension in ast_node.generators:
2424
if not self._is_partial_unpacking(ast_node.elt, one_comprehension.target):
25-
self._validate_comprehension_target(one_comprehension.target)
25+
self._validate_comprehension_target(one_comprehension.target, one_comprehension.iter)
2626
self.generic_visit(ast_node)
2727

2828
def visit_SetComp(self, ast_node: ast.SetComp) -> None:
2929
# Validate targets in generators (the 'v' in 'for v in lst')
3030
for one_comprehension in ast_node.generators:
3131
if not self._is_partial_unpacking(ast_node.elt, one_comprehension.target):
32-
self._validate_comprehension_target(one_comprehension.target)
32+
self._validate_comprehension_target(one_comprehension.target, one_comprehension.iter)
3333
self.generic_visit(ast_node)
3434

3535
def visit_DictComp(self, ast_node: ast.DictComp) -> None:
@@ -38,22 +38,22 @@ def visit_DictComp(self, ast_node: ast.DictComp) -> None:
3838
# key and value are both used
3939
for one_comprehension in ast_node.generators:
4040
if not self._is_partial_unpacking_expr_count(2, one_comprehension.target):
41-
self._validate_comprehension_target(one_comprehension.target)
41+
self._validate_comprehension_target(one_comprehension.target, one_comprehension.iter)
4242
self.generic_visit(ast_node)
4343

4444
def visit_For(self, ast_node: ast.For) -> None:
4545
# Validate target variables in regular for-loops
4646
# Apply same unpacking logic as comprehensions
4747
# For-loops don't have an expression that references vars
4848
if not self._is_partial_unpacking_expr_count(1, ast_node.target):
49-
self._validate_comprehension_target(ast_node.target)
49+
self._validate_comprehension_target(ast_node.target, ast_node.iter)
5050
self.generic_visit(ast_node)
5151

5252
def visit_GeneratorExp(self, ast_node: ast.GeneratorExp) -> None:
5353
# Validate targets in generators (the 'v' in 'for v in lst')
5454
for one_comprehension in ast_node.generators:
5555
if not self._is_partial_unpacking(ast_node.elt, one_comprehension.target):
56-
self._validate_comprehension_target(one_comprehension.target)
56+
self._validate_comprehension_target(one_comprehension.target, one_comprehension.iter)
5757
self.generic_visit(ast_node)
5858

5959
def _is_partial_unpacking(self, expression: ast.expr, target_node: ast.expr) -> bool:
@@ -65,6 +65,22 @@ def _is_partial_unpacking_expr_count(self, expression_count: int, target_node: a
6565
target_count: typing.Final = self._count_unpacked_vars(target_node)
6666
return target_count > expression_count and target_count > 1
6767

68+
def _is_literal_range(self, iter_node: ast.expr) -> bool:
69+
"""Check if the iteration is over a literal range() call."""
70+
# Check for direct range() call
71+
if isinstance(iter_node, ast.Call) and isinstance(iter_node.func, ast.Name) and iter_node.func.id == "range":
72+
# Check if all arguments are literals (no variables)
73+
for one_arg in iter_node.args:
74+
if not isinstance(one_arg, (ast.Constant, ast.UnaryOp)):
75+
# If any argument is not a literal, this is not a literal range
76+
# Note: UnaryOp is included to handle negative numbers like -1
77+
return False
78+
# For UnaryOp (like -1), check if operand is a literal
79+
if isinstance(one_arg, ast.UnaryOp) and not isinstance(one_arg.operand, ast.Constant):
80+
return False
81+
return True
82+
return False
83+
6884
def _count_referenced_vars(self, expression: ast.expr) -> int:
6985
"""Count how many variables are referenced in the expression."""
7086
if isinstance(expression, ast.Name):
@@ -83,8 +99,12 @@ def _count_unpacked_vars(self, target_node: ast.expr) -> int:
8399
return len([one_element for one_element in target_node.elts if isinstance(one_element, ast.Name)])
84100
return 0
85101

86-
def _validate_comprehension_target(self, target_node: ast.expr) -> None:
102+
def _validate_comprehension_target(self, target_node: ast.expr, iter_node: ast.expr | None = None) -> None:
87103
"""Validate that comprehension target follows the one_ prefix rule."""
104+
# Skip validation if iterating over literal range()
105+
if iter_node is not None and self._is_literal_range(iter_node):
106+
return
107+
88108
# Skip ignored targets (underscore, unpacking)
89109
if _is_ignored_target(target_node):
90110
return

tests/test_plugin.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -422,6 +422,44 @@ def test_variable_usage_validations(input_source: str, expected_output: list[str
422422
"import typing\nclass MyProtocol(typing.Protocol, object):\n def fetch_value(self) -> int: ...\n",
423423
[],
424424
),
425+
# No violation: Child classes don't require @typing.final since they inherit from other classes
426+
(
427+
"class ParentClass:\n pass\n\nclass ChildClass(ParentClass):\n pass",
428+
["COP012"], # Only ParentClass should require final decorator, ChildClass inherits so it's exempt
429+
),
430+
# No violation: Multiple levels of inheritance
431+
(
432+
"class GrandParentClass:\n pass\n\nclass ParentClass(GrandParentClass):\n pass\n\n"
433+
"class ChildClass(ParentClass):\n pass",
434+
["COP012"], # Only GrandParentClass requires final decorator, ParentClass and
435+
# ChildClass inherit so they're exempt
436+
),
437+
# No violation: Child classes with module notation don't require @typing.final
438+
(
439+
"class ParentClass:\n pass\n\nclass ChildClass(module.ParentClass):\n pass",
440+
["COP012"], # Only ParentClass should require final decorator, ChildClass inherits so it's exempt
441+
),
442+
# No violation: Child class properly inherits, parent doesn't need final decorator
443+
(
444+
"import typing\n\nclass ParentClass:\n pass\n\n@typing.final\nclass ChildClass(ParentClass):\n pass",
445+
[], # No violations - ChildClass is properly marked final
446+
),
447+
# No violation: Complex inheritance hierarchy with proper final decorators
448+
(
449+
"import typing\n\n"
450+
"class BaseClass:\n pass\n\n"
451+
"class MiddleClass(BaseClass):\n pass\n\n"
452+
"@typing.final\nclass DerivedClass(MiddleClass):\n pass",
453+
[], # No violations - derived classes are properly marked final
454+
),
455+
# No violation: Multiple inheritance with proper final decorators
456+
(
457+
"import typing\n\n"
458+
"class FirstParent:\n pass\n\n"
459+
"class SecondParent:\n pass\n\n"
460+
"@typing.final\nclass ChildClass(FirstParent, SecondParent):\n pass",
461+
[], # No violations - ChildClass is properly marked final
462+
),
425463
],
426464
)
427465
def test_class_validations(input_source: str, expected_output: list[str]) -> None:
@@ -605,6 +643,28 @@ def test_dataclass_validations(input_source: str, expected_output: list[str]) ->
605643
("for x, y in pairs: pass", []),
606644
# No violation: Regular for-loop with one_ prefix
607645
("for one_x in some_list: pass", []),
646+
# No violation: Regular for-loop over literal range() without one_ prefix
647+
("for cur_number in range(10): pass", []),
648+
# No violation: Regular for-loop over literal range() with start and stop
649+
("for cur_number in range(5, 10): pass", []),
650+
# No violation: Regular for-loop over literal range() with start, stop, and step
651+
("for cur_number in range(0, 10, 2): pass", []),
652+
# No violation: Regular for-loop over literal range() with negative values
653+
("for cur_number in range(-5, 5): pass", []),
654+
# COP015: Regular for-loop over non-literal range() should still require one_ prefix
655+
("for cur_number in range(some_variable): pass", ["COP015"]),
656+
# COP015: Regular for-loop over non-literal range() with multiple variables should still require one_ prefix
657+
("for cur_number in range(start, stop): pass", ["COP015"]),
658+
# No violation: List comprehension over literal range() without one_ prefix
659+
("my_result = [cur_number for cur_number in range(10)]", []),
660+
# COP015: List comprehension over non-literal range() should still require one_ prefix
661+
("my_result = [cur_number for cur_number in range(variable)]", ["COP015"]),
662+
# No violation: Set comprehension over literal range() without one_ prefix
663+
("my_result = {cur_number for cur_number in range(10)}", []),
664+
# No violation: Dict comprehension over literal range() without one_ prefix
665+
("my_result = {cur_number: cur_number for cur_number in range(10)}", []),
666+
# No violation: Generator expression over literal range() without one_ prefix
667+
("my_result = (cur_number for cur_number in range(10))", []),
608668
],
609669
)
610670
def test_module_vs_class_level_assignments(input_source: str, expected_output: list[str]) -> None:

0 commit comments

Comments
 (0)