|
10 | 10 |
|
11 | 11 | import libcst as cst |
12 | 12 |
|
13 | | -from ._utils import accumulate_qualname |
| 13 | +from ._utils import accumulate_qualname, module_name_from_path |
14 | 14 |
|
15 | 15 | logger = logging.getLogger(__name__) |
16 | 16 |
|
@@ -42,16 +42,21 @@ def _shared_leading_path(*paths): |
42 | 42 | class KnownImport: |
43 | 43 | """Import information associated with a single known type annotation. |
44 | 44 |
|
45 | | - Parameters |
| 45 | + Attributes |
46 | 46 | ---------- |
47 | | - import_name : |
48 | | - Dotted names after "import". |
49 | | - import_path : |
| 47 | + import_path : str, optional |
50 | 48 | Dotted names after "from". |
51 | | - import_alias : |
| 49 | + import_name : str, optional |
| 50 | + Dotted names after "import". |
| 51 | + import_alias : str, optional |
52 | 52 | Name (without ".") after "as". |
53 | | - builtin_name : |
| 53 | + builtin_name : str, optional |
54 | 54 | Names an object that's builtin and doesn't need an import. |
| 55 | +
|
| 56 | + Examples |
| 57 | + -------- |
| 58 | + >>> KnownImport(import_path="numpy", import_name="uint8", import_alias="ui8") |
| 59 | + <KnownImport 'from numpy import uint8 as ui8'> |
55 | 60 | """ |
56 | 61 |
|
57 | 62 | import_name: str = None |
@@ -170,14 +175,6 @@ def __str__(self): |
170 | 175 | return out |
171 | 176 |
|
172 | 177 |
|
173 | | -@dataclass(slots=True, frozen=True) |
174 | | -class InspectionContext: |
175 | | - """Currently inspected module and other information.""" |
176 | | - |
177 | | - file_path: Path |
178 | | - in_package_path: str |
179 | | - |
180 | | - |
181 | 178 | def _is_type(value) -> bool: |
182 | 179 | """Check if value is a type.""" |
183 | 180 | # Checking for isinstance(..., type) isn't enough, some types such as |
@@ -262,45 +259,57 @@ def common_known_imports(): |
262 | 259 | return known_imports |
263 | 260 |
|
264 | 261 |
|
265 | | -class KnownImportCollector(cst.CSTVisitor): |
| 262 | +class TypeCollector(cst.CSTVisitor): |
266 | 263 | @classmethod |
267 | | - def collect(cls, file, module_name): |
| 264 | + def collect(cls, file): |
| 265 | + """Collect importable type annotations in given file. |
| 266 | +
|
| 267 | + Parameters |
| 268 | + ---------- |
| 269 | + file : Path |
| 270 | +
|
| 271 | + Returns |
| 272 | + ------- |
| 273 | + collected : dict[str, KnownImport] |
| 274 | + """ |
268 | 275 | file = Path(file) |
269 | 276 | with file.open("r") as fo: |
270 | 277 | source = fo.read() |
271 | 278 |
|
272 | 279 | tree = cst.parse_module(source) |
273 | | - collector = cls(module_name=module_name) |
| 280 | + collector = cls(module_name=module_name_from_path(file)) |
274 | 281 | tree.visit(collector) |
275 | 282 | return collector.known_imports |
276 | 283 |
|
277 | 284 | def __init__(self, *, module_name): |
| 285 | + """Initialize type collector. |
| 286 | +
|
| 287 | + Parameters |
| 288 | + ---------- |
| 289 | + module_name : str |
| 290 | + """ |
278 | 291 | self.module_name = module_name |
279 | 292 | self._stack = [] |
280 | 293 | self.known_imports = {} |
281 | 294 |
|
282 | | - def visit_ClassDef(self, node): |
| 295 | + def visit_ClassDef(self, node: cst.ClassDef) -> bool: |
283 | 296 | self._stack.append(node.name.value) |
284 | 297 |
|
285 | 298 | class_name = ".".join(self._stack[:1]) |
286 | 299 | qualname = f"{self.module_name}.{'.'.join(self._stack)}" |
287 | | - |
288 | | - known_import = KnownImport( |
289 | | - import_name=class_name, |
290 | | - import_path=self.module_name, |
291 | | - ) |
| 300 | + known_import = KnownImport(import_path=self.module_name, import_name=class_name) |
292 | 301 | self.known_imports[qualname] = known_import |
293 | 302 |
|
294 | 303 | return True |
295 | 304 |
|
296 | | - def leave_ClassDef(self, original_node): |
| 305 | + def leave_ClassDef(self, original_node: cst.ClassDef) -> None: |
297 | 306 | self._stack.pop() |
298 | 307 |
|
299 | | - def visit_FunctionDef(self, node): |
| 308 | + def visit_FunctionDef(self, node: cst.FunctionDef) -> bool: |
300 | 309 | self._stack.append(node.name.value) |
301 | 310 | return True |
302 | 311 |
|
303 | | - def leave_FunctionDef(self, original_node): |
| 312 | + def leave_FunctionDef(self, original_node: cst.FunctionDef) -> None: |
304 | 313 | self._stack.pop() |
305 | 314 |
|
306 | 315 |
|
@@ -395,7 +404,8 @@ def query(self, search_name): |
395 | 404 |
|
396 | 405 | if known_import is None and self.current_source: |
397 | 406 | # Try scope of current module |
398 | | - try_qualname = f"{self.current_source.import_path}.{search_name}" |
| 407 | + module_name = module_name_from_path(self.current_source) |
| 408 | + try_qualname = f"{module_name}.{search_name}" |
399 | 409 | known_import = self.known_imports.get(try_qualname) |
400 | 410 | if known_import: |
401 | 411 | annotation_name = search_name |
|
0 commit comments