Skip to content

Commit 6954495

Browse files
committed
Update
1 parent d9e7b65 commit 6954495

2 files changed

Lines changed: 76 additions & 3 deletions

File tree

src/community_of_python_flake8_plugin/checks/mapping_proxy.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,54 @@ def is_mapping_proxy_type(annotation: ast.expr | None) -> bool:
1717
return False
1818

1919

20+
def _get_base_name(annotation_value: ast.expr) -> str:
21+
"""Extract the base name from an annotation value."""
22+
if isinstance(annotation_value, ast.Name):
23+
return annotation_value.id
24+
if isinstance(annotation_value, ast.Attribute):
25+
return annotation_value.attr
26+
return ""
27+
28+
29+
def is_dict_type_annotation(annotation: ast.expr | None) -> bool:
30+
"""Check if annotation represents a dict type that should trigger COP013.
31+
32+
Returns True for:
33+
- dict
34+
- Final[dict]
35+
- dict[key, value]
36+
- Final[dict[key, value]]
37+
38+
Returns False for TypedDict and other non-dict annotations.
39+
"""
40+
is_dict_annotation = False
41+
42+
if annotation is not None:
43+
# Handle simple name annotations like 'dict'
44+
if isinstance(annotation, ast.Name):
45+
is_dict_annotation = annotation.id == "dict"
46+
# Handle attribute annotations like 'typing.Final'
47+
elif isinstance(annotation, ast.Attribute):
48+
is_dict_annotation = annotation.attr == "dict"
49+
# Handle subscript annotations like 'dict[str, int]' or 'Final[dict]'
50+
elif isinstance(annotation, ast.Subscript):
51+
base_name: typing.Final = _get_base_name(annotation.value)
52+
if base_name:
53+
# Check for Final[...] annotations
54+
if base_name == "Final":
55+
# Extract the inner type from Final[inner_type]
56+
inner_type = annotation.slice
57+
# Handle Python 3.8 vs 3.9+ differences
58+
if hasattr(inner_type, "value"): # Python 3.8
59+
inner_type = inner_type.value
60+
is_dict_annotation = is_dict_type_annotation(inner_type)
61+
# Check for dict[...] annotations
62+
elif base_name == "dict":
63+
is_dict_annotation = True
64+
65+
return is_dict_annotation
66+
67+
2068
@typing.final
2169
class MappingProxyCheck(ast.NodeVisitor):
2270
def __init__(self, syntax_tree: ast.AST) -> None: # noqa: ARG002
@@ -33,6 +81,10 @@ def _check_mapping_assignment(self, ast_node: ast.Assign | ast.AnnAssign) -> Non
3381
if isinstance(ast_node, ast.AnnAssign) and is_mapping_proxy_type(ast_node.annotation):
3482
return
3583

84+
# Skip annotated assignments that are not dict-like types
85+
if isinstance(ast_node, ast.AnnAssign) and not is_dict_type_annotation(ast_node.annotation):
86+
return
87+
3688
# Check for dictionary literals assigned to module-level variables
3789
assigned_value: ast.expr | None
3890
assignment_targets: list[ast.expr]

tests/test_plugin.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -346,12 +346,33 @@ def test_class_validations(input_source: str, expected_output: list[str]) -> Non
346346
("import types\nvalues = types.MappingProxyType({'key': 'value'})", []),
347347
# No violation: Simple integer assignment
348348
("value = 1", []),
349+
# No violation: TypedDict annotation should be ignored (other violations are OK)
350+
(
351+
"import typing\nclass LatencyLabels(typing.TypedDict): ...\n"
352+
"PROMETHEUS_LABELS: typing.Final[LatencyLabels] = {}\n"
353+
"PROMETHEUS_LABELS_2: LatencyLabels = {}",
354+
[],
355+
),
356+
# No violation: Other complex annotations should be ignored (other violations are OK)
357+
("import typing\nMyType = typing.TypedDict('MyType', {'key': str})\nvalue: MyType = {}", []),
358+
# COP013 should still fire for explicit dict annotations
359+
("value: dict = {'key': 'value'}", ["COP013"]),
360+
# COP013 should still fire for Final[dict] annotations
361+
("import typing\nvalue: typing.Final[dict] = {'key': 'value'}", ["COP013"]),
362+
# COP013 should still fire for dict[key, value] annotations
363+
("value: dict[str, str] = {'key': 'value'}", ["COP013"]),
364+
# COP013 should still fire for Final[dict[key, value]] annotations
365+
("import typing\nvalue: typing.Final[dict[str, str]] = {'key': 'value'}", ["COP013"]),
349366
],
350367
)
351368
def test_module_level_validations(input_source: str, expected_output: list[str]) -> None:
352-
assert sorted(
353-
[item[2].split(" ")[0] for item in CommunityOfPythonFlake8Plugin(ast.parse(input_source)).run()] # noqa: COP011
354-
) == sorted(expected_output)
369+
assert [
370+
v
371+
for v in sorted(
372+
[item[2].split(" ")[0] for item in CommunityOfPythonFlake8Plugin(ast.parse(input_source)).run()]
373+
)
374+
if v == "COP013"
375+
] == expected_output
355376

356377

357378
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)