@@ -61,26 +61,58 @@ def visit_Constant(self, node: ast.Constant) -> None: # noqa: D102
6161 else :
6262 self .generic_visit (node )
6363
64- def rewrite_quotes_for_node (self , node : Union [ast .Str , ast .Constant ]):
64+ def visit_definition (self , node : Union [ast .ClassDef , ast .FunctionDef , ast .AsyncFunctionDef ]) -> None :
65+ """
66+ Mark the docstring of the function or class to identify it later.
67+
68+ :param node:
69+ """
70+
71+ if node .body and isinstance (node .body [0 ], ast .Expr ):
72+ doc_node = node .body [0 ].value
73+ doc_node .is_docstring = True # type: ignore
74+
75+ self .generic_visit (node )
76+
77+ def visit_ClassDef (self , node : ast .ClassDef ) -> None : # noqa: D102
78+ self .visit_definition (node )
79+
80+ def visit_FunctionDef (self , node : ast .FunctionDef ) -> None : # noqa: D102
81+ self .visit_definition (node )
82+
83+ def visit_ASyncFunctionDef (self , node : ast .AsyncFunctionDef ) -> None : # noqa: D102
84+ self .visit_definition (node )
85+
86+ def rewrite_quotes_for_node (self , node : Union [ast .Str , ast .Constant ]) -> None :
87+ """
88+ Mark the area for rewriting quotes in the given node.
89+
90+ :param node:
91+ """
92+
6593 text_range = self .tokens .get_text_range (node )
6694
6795 if text_range == (0 , 0 ):
6896 return
6997
7098 string = self .source [text_range [0 ]:text_range [1 ]]
7199
72- if string in {'""' , "''" }:
73- self .record_replacement (text_range , "''" )
74- elif not re .match ("^[\" ']" , string ):
75- return
76- elif len (node .s ) == 1 :
77- self .record_replacement (text_range , repr (node .s ))
78- elif '\n ' in string :
79- return
80- elif '\n ' in node .s or "\\ n" in node .s :
100+ if getattr (node , "is_docstring" , False ):
101+ # TODO: format docstring with triple quotes and correct indentation
81102 return
82103 else :
83- self .record_replacement (text_range , double_repr_string (node .s ))
104+ if string in {'""' , "''" }:
105+ self .record_replacement (text_range , "''" )
106+ elif not re .match ("^[\" ']" , string ):
107+ return
108+ elif len (node .s ) == 1 :
109+ self .record_replacement (text_range , repr (node .s ))
110+ elif '\n ' in string :
111+ return
112+ elif '\n ' in node .s or "\\ n" in node .s :
113+ return
114+ else :
115+ self .record_replacement (text_range , double_repr_string (node .s ))
84116
85117
86118def dynamic_quotes (source : str ) -> str :
0 commit comments