Skip to content

Commit 97ad55b

Browse files
committed
Support Yields section and Generator functions
And refactor the DocstringAnnotations class in the process.
1 parent 04b17b0 commit 97ad55b

3 files changed

Lines changed: 235 additions & 66 deletions

File tree

src/docstub/_docstrings.py

Lines changed: 138 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,12 @@
44
import traceback
55
from dataclasses import dataclass, field
66
from functools import cached_property
7-
from itertools import chain
87
from pathlib import Path
98

109
import click
1110
import lark
1211
import lark.visitors
13-
from numpydoc.docscrape import NumpyDocString
12+
import numpydoc.docscrape as npds
1413

1514
from ._analysis import KnownImport, TypesDatabase
1615
from ._utils import ContextFormatter, DocstubError, accumulate_qualname, escape_qualname
@@ -56,47 +55,68 @@ def __str__(self) -> str:
5655
return self.value
5756

5857
@classmethod
59-
def as_return_tuple(cls, return_types):
58+
def many_as_tuple(cls, types):
6059
"""Concatenate multiple annotations and wrap in tuple if more than one.
6160
6261
Useful to combine multiple returned types for a function into a single
6362
annotation.
6463
6564
Parameters
6665
----------
67-
return_types : Iterable[Annotation]
66+
types : Iterable[Annotation]
6867
The types to combine.
6968
7069
Returns
7170
-------
7271
concatenated : Annotation
7372
The concatenated types.
7473
"""
75-
values, imports = cls._aggregate_annotations(*return_types)
74+
values, imports = cls._aggregate_annotations(*types)
7675
value = ", ".join(values)
7776
if len(values) > 1:
7877
value = f"tuple[{value}]"
7978
concatenated = cls(value=value, imports=imports)
8079
return concatenated
8180

8281
@classmethod
83-
def as_yields_generator(cls, yield_types, receive_types=()):
84-
"""Create new iterator type from yield and receive types.
82+
def as_generator(cls, *, yield_types, receive_types=(), return_types=()):
83+
"""Create new ``Generator`` type from yield, receive and return types.
8584
8685
Parameters
8786
----------
8887
yield_types : Iterable[Annotation]
8988
The types to yield.
9089
receive_types : Iterable[Annotation], optional
9190
The types the generator receives.
91+
return_types : Iterable[Annotation], optional
92+
The types the generator function returns.
9293
9394
Returns
9495
-------
95-
iterator : Annotation
96-
The yielded and received types wrapped in a generator.
96+
generator : Annotation
97+
The provided types wrapped in a ``Generator``.
9798
"""
98-
# TODO
99-
raise NotImplementedError()
99+
yield_annotation = cls.many_as_tuple(yield_types)
100+
imports = yield_annotation.imports
101+
value = yield_annotation.value
102+
103+
if receive_types:
104+
receive_annotation = cls.many_as_tuple(receive_types)
105+
imports |= receive_annotation.imports
106+
value = f"{value}, {receive_annotation.value}"
107+
elif return_types:
108+
# Append None, so that return types are at correct position
109+
value = f"{value}, None"
110+
111+
if return_types:
112+
return_annotation = cls.many_as_tuple(return_types)
113+
imports |= return_annotation.imports
114+
value = f"{value}, {return_annotation.value}"
115+
116+
value = f"Generator[{value}]"
117+
imports |= {KnownImport(import_path="typing", import_name="Generator")}
118+
generator = cls(value=value, imports=imports)
119+
return generator
100120

101121
def as_optional(self):
102122
"""Return optional version of this annotation by appending `| None`.
@@ -110,6 +130,7 @@ def as_optional(self):
110130
>>> Annotation(value="int").as_optional()
111131
Annotation(value='int | None', imports=frozenset())
112132
"""
133+
# TODO account for `| None` or `Optional` already being included?
113134
value = f"{self.value} | None"
114135
optional = type(self)(value=value, imports=self.imports)
115136
return optional
@@ -418,6 +439,12 @@ def _match_import(self, qualname, *, meta):
418439
class DocstringAnnotations:
419440
"""Collect annotations in a given docstring.
420441
442+
Attributes
443+
----------
444+
docstring : str
445+
transformer : DoctypeTransformer
446+
ctx : ~.ContextFormatter
447+
421448
Examples
422449
--------
423450
>>> docstring = '''
@@ -450,12 +477,12 @@ def __init__(self, docstring, *, transformer, ctx=None):
450477
ctx : ~.ContextFormatter, optional
451478
"""
452479
self.docstring = docstring
453-
self.np_docstring = NumpyDocString(docstring)
480+
self.np_docstring = npds.NumpyDocString(docstring)
454481
self.transformer = transformer
455482

456483
if ctx is None:
457484
ctx = ContextFormatter(line=0)
458-
self._ctx: ContextFormatter = ctx
485+
self.ctx: ContextFormatter = ctx
459486

460487
def _doctype_to_annotation(self, doctype, ds_line=0):
461488
"""Convert a type description to a Python-ready type.
@@ -474,7 +501,7 @@ def _doctype_to_annotation(self, doctype, ds_line=0):
474501
The transformed type, ready to be inserted into a stub file, with
475502
necessary imports attached.
476503
"""
477-
ctx = self._ctx.with_line(offset=ds_line)
504+
ctx = self.ctx.with_line(offset=ds_line)
478505

479506
try:
480507
annotation, unknown_qualnames = self.transformer.doctype_to_annotation(
@@ -504,10 +531,18 @@ def _doctype_to_annotation(self, doctype, ds_line=0):
504531
return annotation
505532

506533
@cached_property
507-
def attributes(self) -> dict[str, Annotation]:
534+
def attributes(self):
535+
"""Return the attributes found in the docstring.
536+
537+
Returns
538+
-------
539+
attributes : dict[str, Annotation]
540+
A dictionary mapping attribute names to their annotations.
541+
Attributes without annotations fall back to :class:`_typeshed.Incomplete`.
542+
"""
508543
annotations = {}
509544
for attribute in self.np_docstring["Attributes"]:
510-
self._warn_missing_whitespace(attribute)
545+
self._handle_missing_whitespace(attribute)
511546
if not attribute.type:
512547
continue
513548

@@ -528,68 +563,78 @@ def attributes(self) -> dict[str, Annotation]:
528563

529564
@cached_property
530565
def parameters(self) -> dict[str, Annotation]:
531-
all_params = chain(
532-
self.np_docstring["Parameters"], self.np_docstring["Other Parameters"]
533-
)
534-
annotated_params = {}
535-
for param in all_params:
536-
self._warn_missing_whitespace(param)
537-
if not param.type:
538-
continue
539-
540-
ds_line = 0
541-
for i, line in enumerate(self.docstring.split("\n")):
542-
if param.name in line and param.type in line:
543-
ds_line = i
544-
break
566+
"""Return the parameters and "Other Parameters" found in the docstring.
545567
546-
if param.name in annotated_params:
547-
logger.warning("duplicate parameter name %r, ignoring", param.name)
548-
continue
568+
Returns
569+
-------
570+
parameters : dict[str, Annotation]
571+
A dictionary mapping parameters names to their annotations.
572+
Parameters without annotations fall back to :class:`_typeshed.Incomplete`.
573+
"""
574+
param_section = self._get_section("Parameters")
575+
other_section = self._get_section("Other Parameters")
549576

550-
annotation = self._doctype_to_annotation(param.type, ds_line=ds_line)
551-
name = param.name.strip(" *") # normalize *args & **kwargs
552-
annotated_params[name] = annotation
577+
duplicates = param_section.keys() & other_section.keys()
578+
for duplicate in duplicates:
579+
logger.warning("duplicate parameter name %r, ignoring", duplicate)
553580

554-
return annotated_params
581+
# Last takes priority
582+
paramaters = other_section | param_section
583+
# Normalize *args & **kwargs
584+
paramaters = {name.strip(" *"): value for name, value in paramaters.items()}
585+
return paramaters
555586

556587
@cached_property
557-
def returns(self) -> Annotation | None:
558-
annotated_params = {}
559-
for param in self.np_docstring["Returns"]:
560-
self._warn_missing_whitespace(param)
561-
# NumPyDoc always requires a doctype for returns,
562-
assert param.type
588+
def returns(self):
589+
"""Return the attributes found in the docstring.
563590
564-
ds_line = 0
565-
for i, line in enumerate(self.docstring.split("\n")):
566-
if param.name in line and param.type in line:
567-
ds_line = i
568-
break
569-
570-
if param.name in annotated_params:
571-
logger.warning("duplicate parameter name %r, ignoring", param.name)
572-
continue
573-
574-
annotation = self._doctype_to_annotation(param.type, ds_line=ds_line)
575-
annotated_params[param.name.strip()] = annotation
591+
Returns
592+
-------
593+
return_annotation : Annotation | None
594+
The "return" annotation of a callable. If the docstring defines a
595+
"Yield" section, this will be a :class:`typing.Generator`.
596+
"""
597+
out = self._yields or self._returns
598+
return out
576599

577-
if annotated_params:
578-
out = Annotation.as_return_tuple(annotated_params.values())
600+
@cached_property
601+
def _returns(self) -> Annotation | None:
602+
out = self._get_section("Returns")
603+
if out:
604+
out = Annotation.many_as_tuple(out.values())
579605
else:
580606
out = None
581607
return out
582608

583-
def _warn_missing_whitespace(self, param):
584-
"""Check for warning if a whitespace is missing between parameter and colon.
609+
@cached_property
610+
def _yields(self) -> Annotation | None:
611+
yields = self._get_section("Yields")
612+
if not yields:
613+
return None
614+
615+
receive_types = self._get_section("Receives")
616+
617+
out = Annotation.as_generator(
618+
yield_types=yields.values(),
619+
receive_types=receive_types.values(),
620+
return_types=(self._returns,) if self._returns else (),
621+
)
622+
return out
623+
624+
def _handle_missing_whitespace(self, param):
625+
"""Handle missing whitespace between parameter and colon.
585626
586627
In this case, NumPyDoc parses the entire thing as the parameter name and
587628
no annotation is detected. Since this typo can lead to very subtle & confusing
588-
bugs, let's warn users about it
629+
bugs, let's warn users about it and attempt to handle it.
589630
590631
Parameters
591632
----------
592633
param : numpydoc.docscrape.Parameter
634+
635+
Returns
636+
-------
637+
param : numpydoc.docscrape.Parameter
593638
"""
594639
if ":" in param.name and param.type == "":
595640
msg = (
@@ -604,5 +649,37 @@ def _warn_missing_whitespace(self, param):
604649
if param.name in line:
605650
ds_line = i
606651
break
607-
ctx = self._ctx.with_line(offset=ds_line)
652+
ctx = self.ctx.with_line(offset=ds_line)
608653
ctx.print_message(msg, details=hint)
654+
655+
new_name, new_type = param.name.split(":", maxsplit=1)
656+
param = npds.Parameter(name=new_name, type=new_type, desc=param.desc)
657+
658+
return param
659+
660+
def _get_section(self, name: str) -> dict[str, Annotation]:
661+
annotated_params = {}
662+
for param in self.np_docstring[name]:
663+
param = self._handle_missing_whitespace(param) # noqa: PLW2901
664+
665+
if param.name in annotated_params:
666+
# TODO make error
667+
logger.warning("duplicate parameter name %r, ignoring", param.name)
668+
continue
669+
670+
if param.type:
671+
ds_line = self._find_docstring_line(param.name, param.type)
672+
annotation = self._doctype_to_annotation(param.type, ds_line=ds_line)
673+
else:
674+
annotation = FallbackAnnotation
675+
annotated_params[param.name.strip()] = annotation
676+
677+
return annotated_params
678+
679+
def _find_docstring_line(self, *patterns):
680+
line_count = 0
681+
for i, line in enumerate(self.docstring.split("\n")):
682+
if all(p in line for p in patterns):
683+
line_count = i
684+
break
685+
return line_count

src/docstub/_stubs.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,8 +75,8 @@ def walk_source_and_targets(root_dir, target_dir):
7575
target_dir : Path
7676
Root directory in which a matching stub package will be created.
7777
78-
Returns
79-
-------
78+
Yields
79+
------
8080
source_path : Path
8181
Either a Python file or a stub file that takes precedence.
8282
stub_path : Path

0 commit comments

Comments
 (0)