-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathfinal_class.py
More file actions
83 lines (68 loc) · 3.64 KB
/
final_class.py
File metadata and controls
83 lines (68 loc) · 3.64 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
from __future__ import annotations
import ast
import typing
from community_of_python_flake8_plugin.utils import check_inherits_from_bases
from community_of_python_flake8_plugin.violation_codes import ViolationCodes
from community_of_python_flake8_plugin.violations import Violation
def _decorator_target_name(decorator: ast.expr) -> str:
    """Return the trailing identifier a decorator expression refers to.

    ``@name`` yields ``"name"``, ``@pkg.mod.name`` yields ``"name"``, and a
    called decorator (``@name(...)``) is first unwrapped to its callee.
    Any other expression shape yields the empty string.
    """
    callee = decorator.func if isinstance(decorator, ast.Call) else decorator
    if isinstance(callee, ast.Name):
        return callee.id
    if isinstance(callee, ast.Attribute):
        return callee.attr
    return ""


def contains_final_decorator(class_node: ast.ClassDef) -> bool:
    """Return True when ``class_node`` is decorated with something named ``final``.

    Matches ``@final``, ``@<anything>.final`` and the called variants
    ``@final(...)`` / ``@<anything>.final(...)``. Matching is by identifier
    only; where the name was imported from is not checked.
    """
    return any(
        _decorator_target_name(decorator) == "final"
        for decorator in class_node.decorator_list
    )
def is_protocol_class(class_node: ast.ClassDef) -> bool:
    """Return True if one of ``class_node``'s bases is named ``Protocol``.

    Accepted base spellings: ``Protocol``, ``<module>.Protocol`` and their
    subscripted generic forms ``Protocol[T]`` / ``<module>.Protocol[T]``.
    Only the identifier is compared; the originating module is not checked.
    """
    for base_expr in class_node.bases:
        # A generic base such as ``Protocol[T]`` is a Subscript; look at what is subscripted.
        referenced = base_expr.value if isinstance(base_expr, ast.Subscript) else base_expr
        if isinstance(referenced, ast.Name):
            if referenced.id == "Protocol":
                return True
        elif isinstance(referenced, ast.Attribute):
            if referenced.attr == "Protocol":
                return True
    return False
def is_model_factory_class(class_node: ast.ClassDef) -> bool:
    """Return True if ``class_node`` inherits from ``ModelFactory`` or ``SQLAlchemyFactory``.

    The base-matching rules are those of ``check_inherits_from_bases``.
    """
    factory_base_names = {"ModelFactory", "SQLAlchemyFactory"}
    return check_inherits_from_bases(class_node, factory_base_names)
def has_local_subclasses(syntax_tree: ast.AST, class_node: ast.ClassDef) -> bool:
    """Return True if another class in ``syntax_tree`` lists ``class_node`` as a base.

    A base matches when its trailing identifier equals ``class_node.name``.
    Recognised spellings:

    - ``class Child(Parent)``
    - ``class Child(module.Parent)``
    - the generic forms ``class Child(Parent[T])`` / ``class Child(module.Parent[T])``

    Matching is by bare name, so an unrelated class that happens to share the
    name (e.g. nested in another scope) also counts as a subclass.

    Args:
        syntax_tree: The module tree to search; every ``ClassDef`` in it is inspected.
        class_node: The candidate parent class.

    Returns:
        True if at least one other class derives from ``class_node`` by name.
    """
    parent_name = class_node.name
    for other_node in ast.walk(syntax_tree):
        if not isinstance(other_node, ast.ClassDef) or other_node is class_node:
            continue
        for base_expr in other_node.bases:
            # Generic parametrisation (``Parent[int]``) wraps the real base in a
            # Subscript. Unwrap it so generic parents are not missed; missing them
            # would wrongly demand @final on a class that is subclassed here.
            if isinstance(base_expr, ast.Subscript):
                base_expr = base_expr.value
            if isinstance(base_expr, ast.Name) and base_expr.id == parent_name:
                return True
            if isinstance(base_expr, ast.Attribute) and base_expr.attr == parent_name:
                return True
    return False
@typing.final
class FinalClassCheck(ast.NodeVisitor):
    """AST visitor that reports classes lacking a ``final`` decorator.

    Each offending class produces one ``Violation`` with code
    ``ViolationCodes.FINAL_CLASS``, appended to ``self.violations``.
    """

    def __init__(self, syntax_tree: ast.AST) -> None:
        # The whole module tree is retained so each class can be checked for
        # subclasses defined elsewhere in the same file.
        self.syntax_tree = syntax_tree
        self.violations: list[Violation] = []

    def visit_ClassDef(self, ast_node: ast.ClassDef) -> None:
        """Check this class, then keep walking into its body (nested classes included)."""
        self._check_final_decorator(ast_node)
        self.generic_visit(ast_node)

    def _check_final_decorator(self, ast_node: ast.ClassDef) -> None:
        """Record a FINAL_CLASS violation if ``ast_node`` should be final but is not."""
        # Exempt: Protocol classes, classes whose name starts with "Test",
        # ModelFactory/SQLAlchemyFactory subclasses, and classes that something
        # in this file inherits from (making them final would break that subclass).
        exempt = (
            is_protocol_class(ast_node)
            or ast_node.name.startswith("Test")
            or is_model_factory_class(ast_node)
            or has_local_subclasses(self.syntax_tree, ast_node)
        )
        if exempt or contains_final_decorator(ast_node):
            return
        missing_final = Violation(
            line_number=ast_node.lineno,
            column_number=ast_node.col_offset,
            violation_code=ViolationCodes.FINAL_CLASS,
        )
        self.violations.append(missing_final)