Skip to content

Commit 39bd913

Browse files
committed
Update
1 parent b53366e commit 39bd913

11 files changed

Lines changed: 84 additions & 123 deletions

File tree

src/community_of_python_flake8_plugin/checks/dataclass_config.py

Lines changed: 19 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import typing
44

55
from community_of_python_flake8_plugin.constants import FINAL_CLASS_EXCLUDED_BASES
6+
from community_of_python_flake8_plugin.utils import check_inherits_from_bases
67
from community_of_python_flake8_plugin.violation_codes import ViolationCodes
78
from community_of_python_flake8_plugin.violations import Violation
89

@@ -23,33 +24,23 @@ def has_required_dataclass_params(decorator: ast.expr) -> bool:
2324
if not isinstance(decorator, ast.Call):
2425
return False
2526

26-
# Check if all required parameters are present
27+
keywords: typing.Final = {kw.arg: kw.value for kw in decorator.keywords if isinstance(kw.value, ast.Constant)}
28+
kw_only_param: typing.Final = keywords.get("kw_only")
29+
slots_param: typing.Final = keywords.get("slots")
30+
frozen_param: typing.Final = keywords.get("frozen")
2731
return (
28-
any(
29-
keyword.arg == "kw_only" and isinstance(keyword.value, ast.Constant) and keyword.value.value is True
30-
for keyword in decorator.keywords
31-
)
32-
and any(
33-
keyword.arg == "slots" and isinstance(keyword.value, ast.Constant) and keyword.value.value is True
34-
for keyword in decorator.keywords
35-
)
36-
and any(
37-
keyword.arg == "frozen" and isinstance(keyword.value, ast.Constant) and keyword.value.value is True
38-
for keyword in decorator.keywords
39-
)
32+
kw_only_param is not None
33+
and isinstance(kw_only_param, ast.Constant)
34+
and kw_only_param.value is True
35+
and slots_param is not None
36+
and isinstance(slots_param, ast.Constant)
37+
and slots_param.value is True
38+
and frozen_param is not None
39+
and isinstance(frozen_param, ast.Constant)
40+
and frozen_param.value is True
4041
)
4142

4243

43-
def is_inherited_from_whitelisted_class(class_node: ast.ClassDef) -> bool:
44-
"""Check if class inherits from whitelisted base classes."""
45-
for base_class in class_node.bases:
46-
if isinstance(base_class, ast.Name) and base_class.id in FINAL_CLASS_EXCLUDED_BASES:
47-
return True
48-
if isinstance(base_class, ast.Attribute) and base_class.attr in FINAL_CLASS_EXCLUDED_BASES:
49-
return True
50-
return False
51-
52-
5344
def is_pydantic_model(class_node: ast.ClassDef) -> bool:
5445
"""Check if class inherits from Pydantic BaseModel or RootModel."""
5546
for base_class in class_node.bases:
@@ -78,7 +69,7 @@ def __init__(self, syntax_tree: ast.AST) -> None: # noqa: ARG002
7869
def visit_ClassDef(self, ast_node: ast.ClassDef) -> None:
7970
# Skip whitelisted classes and classes that inherit from Exception or other special classes
8071
if (
81-
is_inherited_from_whitelisted_class(ast_node)
72+
check_inherits_from_bases(ast_node, FINAL_CLASS_EXCLUDED_BASES)
8273
or is_pydantic_model(ast_node)
8374
or is_model_factory(ast_node)
8475
or self._inherits_from_exception(ast_node)
@@ -104,8 +95,8 @@ def visit_ClassDef(self, ast_node: ast.ClassDef) -> None:
10495
def _inherits_from_exception(self, ast_node: ast.ClassDef) -> bool:
10596
"""Check if class inherits from Exception or its subclasses."""
10697
for base in ast_node.bases:
107-
if (isinstance(base, ast.Name) and ("Error" in base.id or "Exception" in base.id)) or (
108-
isinstance(base, ast.Attribute) and ("Error" in base.attr or "Exception" in base.attr)
109-
):
98+
if isinstance(base, ast.Name) and ("Error" in base.id or "Exception" in base.id):
11099
return True
111-
return len(ast_node.bases) > 0 # Skip all classes that inherit from anything
100+
if isinstance(base, ast.Attribute) and ("Error" in base.attr or "Exception" in base.attr):
101+
return True
102+
return False

src/community_of_python_flake8_plugin/checks/function_verb.py

Lines changed: 4 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import typing
44

55
from community_of_python_flake8_plugin.constants import FINAL_CLASS_EXCLUDED_BASES, VERB_PREFIXES
6-
from community_of_python_flake8_plugin.utils import find_parent_class_definition
6+
from community_of_python_flake8_plugin.utils import check_inherits_from_bases, find_parent_class_definition
77
from community_of_python_flake8_plugin.violation_codes import ViolationCodes
88
from community_of_python_flake8_plugin.violations import Violation
99

@@ -50,32 +50,26 @@ def check_is_fixture_decorator(decorator: ast.expr) -> bool:
5050
return False
5151

5252

53-
def retrieve_parent_class(syntax_tree: ast.AST, ast_node: ast.AST) -> ast.ClassDef | None:
54-
return find_parent_class_definition(syntax_tree, ast_node)
55-
56-
5753
@typing.final
5854
class FunctionVerbCheck(ast.NodeVisitor):
5955
def __init__(self, syntax_tree: ast.AST) -> None:
6056
self.violations: list[Violation] = []
6157
self.syntax_tree: typing.Final[ast.AST] = syntax_tree
6258

6359
def visit_FunctionDef(self, ast_node: ast.FunctionDef) -> None:
64-
parent_class: typing.Final = retrieve_parent_class(self.syntax_tree, ast_node) # noqa: COP011
65-
self.validate_function_name(ast_node, parent_class)
60+
self.validate_function_name(ast_node, find_parent_class_definition(self.syntax_tree, ast_node))
6661
self.generic_visit(ast_node)
6762

6863
def visit_AsyncFunctionDef(self, ast_node: ast.AsyncFunctionDef) -> None:
69-
parent_class: typing.Final = retrieve_parent_class(self.syntax_tree, ast_node) # noqa: COP011
70-
self.validate_function_name(ast_node, parent_class)
64+
self.validate_function_name(ast_node, find_parent_class_definition(self.syntax_tree, ast_node))
7165
self.generic_visit(ast_node)
7266

7367
def validate_function_name(
7468
self, ast_node: ast.FunctionDef | ast.AsyncFunctionDef, parent_class: ast.ClassDef | None
7569
) -> None:
7670
if (
7771
check_is_ignored_name(ast_node.name)
78-
or (parent_class and self.check_inherits_from_whitelisted_class(parent_class))
72+
or (parent_class and check_inherits_from_bases(parent_class, FINAL_CLASS_EXCLUDED_BASES))
7973
or check_is_property(ast_node)
8074
or check_is_pytest_fixture(ast_node)
8175
or check_is_verb_name(ast_node.name)
@@ -89,11 +83,3 @@ def validate_function_name(
8983
violation_code=ViolationCodes.FUNCTION_VERB,
9084
)
9185
)
92-
93-
def check_inherits_from_whitelisted_class(self, ast_node: ast.ClassDef) -> bool:
94-
for base_class in ast_node.bases:
95-
if isinstance(base_class, ast.Name) and base_class.id in FINAL_CLASS_EXCLUDED_BASES:
96-
return True
97-
if isinstance(base_class, ast.Attribute) and base_class.attr in FINAL_CLASS_EXCLUDED_BASES:
98-
return True
99-
return False

src/community_of_python_flake8_plugin/checks/module_import_many_names.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -53,14 +53,11 @@ def validate_import_size(self, ast_node: ast.ImportFrom) -> None:
5353
if module_name is not None and module_name.endswith(".settings"):
5454
return
5555

56-
module_import_exists: typing.Final = any( # noqa: COP011
57-
isinstance(identifier, ast.alias)
58-
and module_name is not None
59-
and check_module_path_exists(f"{module_name}.{identifier.name}")
60-
for identifier in ast_node.names
61-
)
62-
63-
if not module_import_exists:
56+
if not any(
57+
check_module_path_exists(f"{module_name}.{alias.name}")
58+
for alias in ast_node.names
59+
if isinstance(alias, ast.alias) and module_name is not None
60+
):
6461
self.violations.append(
6562
Violation(
6663
line_number=ast_node.lineno,

src/community_of_python_flake8_plugin/checks/name_length.py

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import typing
44

55
from community_of_python_flake8_plugin.constants import FINAL_CLASS_EXCLUDED_BASES, MIN_NAME_LENGTH
6-
from community_of_python_flake8_plugin.utils import find_parent_class_definition
6+
from community_of_python_flake8_plugin.utils import check_inherits_from_bases, find_parent_class_definition
77
from community_of_python_flake8_plugin.violation_codes import ViolationCodes
88
from community_of_python_flake8_plugin.violations import Violation
99

@@ -44,15 +44,6 @@ def check_is_fixture_decorator(decorator: ast.expr) -> bool:
4444
return False
4545

4646

47-
def check_inherits_from_whitelisted_class(class_node: ast.ClassDef) -> bool:
48-
for base_class in class_node.bases:
49-
if isinstance(base_class, ast.Name) and base_class.id in FINAL_CLASS_EXCLUDED_BASES:
50-
return True
51-
if isinstance(base_class, ast.Attribute) and base_class.attr in FINAL_CLASS_EXCLUDED_BASES:
52-
return True
53-
return False
54-
55-
5647
@typing.final
5748
class COP004NameLengthCheck(ast.NodeVisitor):
5849
def __init__(self, tree: ast.AST) -> None: # noqa: COP006
@@ -95,7 +86,7 @@ def validate_name_length(self, identifier: str, ast_node: ast.stmt, parent_class
9586
if (
9687
parent_class
9788
and isinstance(ast_node, (ast.AnnAssign, ast.Assign))
98-
and check_inherits_from_whitelisted_class(parent_class)
89+
and check_inherits_from_bases(parent_class, FINAL_CLASS_EXCLUDED_BASES)
9990
):
10091
return
10192

@@ -121,7 +112,7 @@ def validate_function_name(
121112
return
122113
if check_is_ignored_name(ast_node.name):
123114
return
124-
if parent_class and check_inherits_from_whitelisted_class(parent_class):
115+
if parent_class and check_inherits_from_bases(parent_class, FINAL_CLASS_EXCLUDED_BASES):
125116
return
126117
if check_is_pytest_fixture(ast_node):
127118
return

src/community_of_python_flake8_plugin/checks/scalar_annotation.py

Lines changed: 8 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import typing
44

55
from community_of_python_flake8_plugin.constants import SCALAR_ANNOTATIONS
6-
from community_of_python_flake8_plugin.utils import find_parent_class_definition
6+
from community_of_python_flake8_plugin.utils import find_parent_class_definition, find_parent_node
77
from community_of_python_flake8_plugin.violation_codes import ViolationCodes
88
from community_of_python_flake8_plugin.violations import Violation
99

@@ -36,29 +36,18 @@ def check_is_scalar_annotation(annotation_node: ast.AST) -> bool:
3636
return False
3737

3838

39-
def find_parent_function(syntax_tree: ast.AST, target_node: ast.AST) -> ast.FunctionDef | ast.AsyncFunctionDef | None:
40-
for potential_parent in ast.walk(syntax_tree):
41-
if isinstance(potential_parent, (ast.FunctionDef, ast.AsyncFunctionDef)):
42-
for child_node in ast.walk(potential_parent): # noqa: COP011
43-
if child_node is target_node:
44-
return potential_parent
45-
return None
46-
47-
4839
@typing.final
4940
class ScalarAnnotationCheck(ast.NodeVisitor):
50-
def __init__(self, tree: ast.AST) -> None: # noqa: COP006
41+
def __init__(self, syntax_tree: ast.AST) -> None: # noqa: COP006
5142
self.violations: list[Violation] = []
52-
self.syntax_tree: typing.Final[ast.AST] = tree
43+
self.syntax_tree: typing.Final[ast.AST] = syntax_tree
5344

5445
def visit_AnnAssign(self, ast_node: ast.AnnAssign) -> None:
55-
if isinstance(ast_node.target, ast.Name):
56-
parent_class: typing.Final = find_parent_class_definition(self.syntax_tree, ast_node) # noqa: COP011
57-
parent_function: typing.Final = find_parent_function(self.syntax_tree, ast_node) # noqa: COP011
58-
in_class_body: typing.Final = parent_class is not None and parent_function is None
59-
60-
if not in_class_body:
61-
self.validate_scalar_annotation(ast_node)
46+
if isinstance(ast_node.target, ast.Name) and (
47+
find_parent_class_definition(self.syntax_tree, ast_node) is None
48+
or find_parent_node(self.syntax_tree, ast_node, (ast.FunctionDef, ast.AsyncFunctionDef)) is not None
49+
):
50+
self.validate_scalar_annotation(ast_node)
6251
self.generic_visit(ast_node)
6352

6453
def validate_scalar_annotation(self, ast_node: ast.AnnAssign) -> None:

src/community_of_python_flake8_plugin/checks/temp_var.py

Lines changed: 11 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -7,21 +7,8 @@
77
from community_of_python_flake8_plugin.violations import Violation
88

99

10-
def collect_variable_usage(function_node: ast.AST) -> dict[str, list[ast.Name]]:
10+
def collect_variable_usage_and_stores(function_node: ast.AST) -> tuple[dict[str, list[ast.Name]], set[str]]:
1111
variable_usage: typing.Final[defaultdict[str, list[ast.Name]]] = defaultdict(list)
12-
13-
@typing.final
14-
class VariableCollector(ast.NodeVisitor):
15-
def visit_Name(self, name_node: ast.Name) -> None:
16-
# Collect all name references (both loads and stores)
17-
variable_usage[name_node.id].append(name_node)
18-
self.generic_visit(name_node)
19-
20-
VariableCollector().visit(function_node)
21-
return dict(variable_usage)
22-
23-
24-
def collect_assignment_stores(function_node: ast.AST) -> set[str]:
2512
assigned_variable_names: typing.Final[set[str]] = set()
2613

2714
def extract_names(expression: ast.expr) -> typing.Iterable[str]:
@@ -32,7 +19,11 @@ def extract_names(expression: ast.expr) -> typing.Iterable[str]:
3219
yield from extract_names(elt)
3320

3421
@typing.final
35-
class AssignmentStoreCollector(ast.NodeVisitor):
22+
class UsageCollector(ast.NodeVisitor):
23+
def visit_Name(self, name_node: ast.Name) -> None:
24+
variable_usage[name_node.id].append(name_node)
25+
self.generic_visit(name_node)
26+
3627
def visit_Assign(self, assign_node: ast.Assign) -> None:
3728
for target in assign_node.targets:
3829
assigned_variable_names.update(extract_names(target))
@@ -46,8 +37,8 @@ def visit_AnnAssign(self, ann_assign_node: ast.AnnAssign) -> None:
4637
assigned_variable_names.update(extract_names(ann_assign_node.target))
4738
self.generic_visit(ann_assign_node)
4839

49-
AssignmentStoreCollector().visit(function_node)
50-
return assigned_variable_names
40+
UsageCollector().visit(function_node)
41+
return dict(variable_usage), assigned_variable_names
5142

5243

5344
@typing.final
@@ -64,22 +55,19 @@ def visit_AsyncFunctionDef(self, ast_node: ast.AsyncFunctionDef) -> None:
6455
self.generic_visit(ast_node)
6556

6657
def _check_temporary_variables(self, ast_node: ast.FunctionDef | ast.AsyncFunctionDef) -> None:
67-
# Flag only the first temporary variable to match test expectations
58+
usage_and_stores: typing.Final = collect_variable_usage_and_stores(ast_node)
6859
found_temporary_variable = False
6960

70-
for variable_name, usages in collect_variable_usage(ast_node).items():
71-
# Skip special variables
61+
for variable_name, usages in usage_and_stores[0].items():
7262
if variable_name.startswith("_") or variable_name in {"self", "cls"}:
7363
continue
7464

75-
# Flag variables that are assigned once (in assignment) and read once
7665
if (
7766
len([usage for usage in usages if isinstance(usage.ctx, ast.Store)]) == 1
7867
and len([usage for usage in usages if isinstance(usage.ctx, ast.Load)]) == 1
79-
and variable_name in collect_assignment_stores(ast_node)
68+
and variable_name in usage_and_stores[1]
8069
and not found_temporary_variable
8170
):
82-
# Find the first usage (likely the assignment) to report the violation
8371
first_usage = usages[0]
8472
self.violations.append(
8573
Violation(

src/community_of_python_flake8_plugin/plugin.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,17 @@ def __init__(self, tree: ast.AST) -> None: # noqa: COP006
3131
self.ast_syntax_tree: typing.Final[ast.AST] = tree
3232

3333
def run(self) -> Iterable[tuple[int, int, str, type[object]]]: # noqa: COP007
34-
checks_collection: typing.Final[list[PluginCheckProtocol]] = []
34+
for check_instance in self._collect_checks():
35+
for violation in check_instance.violations:
36+
yield (
37+
violation.line_number,
38+
violation.column_number,
39+
f"{violation.violation_code.code} {violation.violation_code.description}",
40+
type(self),
41+
)
3542

43+
def _collect_checks(self) -> list[PluginCheckProtocol]:
44+
checks_collection: typing.Final = []
3645
for _, module_name, _ in pkgutil.iter_modules(checks_module.__path__):
3746
imported_module = importlib.import_module(f"{checks_module.__name__}.{module_name}")
3847

@@ -42,12 +51,4 @@ def run(self) -> Iterable[tuple[int, int, str, type[object]]]: # noqa: COP007
4251
check_instance = attribute(self.ast_syntax_tree)
4352
check_instance.visit(self.ast_syntax_tree)
4453
checks_collection.append(check_instance)
45-
46-
for check_instance in checks_collection:
47-
for violation in check_instance.violations:
48-
yield (
49-
violation.line_number,
50-
violation.column_number,
51-
f"{violation.violation_code.code} {violation.violation_code.description}",
52-
type(self),
53-
)
54+
return checks_collection

src/community_of_python_flake8_plugin/utils.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,21 @@ def find_parent_class_definition(syntax_tree: ast.AST, target_node: ast.AST) ->
99
if child_node is target_node:
1010
return potential_parent
1111
return None
12+
13+
14+
def check_inherits_from_bases(class_definition: ast.ClassDef, base_classes: set[str]) -> bool:
15+
for base_class in class_definition.bases:
16+
if isinstance(base_class, ast.Name) and base_class.id in base_classes:
17+
return True
18+
if isinstance(base_class, ast.Attribute) and base_class.attr in base_classes:
19+
return True
20+
return False
21+
22+
23+
def find_parent_node(syntax_tree: ast.AST, target_node: ast.AST, node_types: tuple[type, ...]) -> ast.AST | None:
24+
for potential_parent in ast.walk(syntax_tree):
25+
if isinstance(potential_parent, node_types):
26+
for child_node in ast.walk(potential_parent):
27+
if child_node is target_node:
28+
return potential_parent
29+
return None

src/community_of_python_flake8_plugin/violation_codes.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44

55

66
@typing.final
7-
@dataclasses.dataclass(frozen=True, kw_only=True, slots=True)
8-
class ViolationCodeItem:
7+
@dataclasses.dataclass(kw_only=True, slots=True, frozen=True)
8+
class ViolationCodeItem: # noqa: COP014
99
code: str # noqa: COP004
1010
description: str
1111

src/community_of_python_flake8_plugin/violations.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@
88

99

1010
@typing.final
11-
@dataclasses.dataclass(frozen=True, kw_only=True, slots=True)
12-
class Violation:
11+
@dataclasses.dataclass(kw_only=True, slots=True, frozen=True)
12+
class Violation: # noqa: COP014
1313
line_number: int
1414
column_number: int
1515
violation_code: ViolationCodeItem

0 commit comments

Comments
 (0)