|
1 | 1 | import inspect |
| 2 | +import sys |
2 | 3 | from collections import defaultdict, deque |
3 | 4 | from graphlib import TopologicalSorter |
4 | | -from typing import ( |
5 | | - Any, |
6 | | - Callable, |
7 | | - Dict, |
8 | | - ForwardRef, |
9 | | - List, |
10 | | - Optional, |
11 | | - TypeVar, |
12 | | - get_type_hints, |
13 | | -) |
| 5 | +from typing import Any, Callable, Dict, List, Optional, TypeVar, get_type_hints |
14 | 6 |
|
15 | 7 | from taskiq_dependencies.ctx import AsyncResolveContext, SyncResolveContext |
16 | 8 | from taskiq_dependencies.dependency import Dependency |
@@ -115,6 +107,12 @@ def _build_graph(self) -> None: # noqa: C901, WPS210 |
115 | 107 | :raises ValueError: if something happened. |
116 | 108 | """ |
117 | 109 | dep_deque = deque([Dependency(self.target, use_cache=True)]) |
| 110 | + # This is for `from __future__ import annotations` support. |
| 111 | + # We need to use `eval_str` argument, because |
| 112 | + # signature of the function is a string, not an object. |
| 113 | + signature_kwargs: Dict[str, Any] = {} |
| 114 | + if sys.version_info >= (3, 10): |
| 115 | + signature_kwargs["eval_str"] = True |
118 | 116 |
|
119 | 117 | while dep_deque: |
120 | 118 | dep = dep_deque.popleft() |
@@ -164,38 +162,31 @@ def _build_graph(self) -> None: # noqa: C901, WPS210 |
164 | 162 | # If this is a class, we need to get signature of |
165 | 163 | # an __init__ method. |
166 | 164 | hints = get_type_hints(origin.__init__) # noqa: WPS609 |
167 | | - sign = inspect.signature(origin.__init__) # noqa: WPS609 |
| 165 | + sign = inspect.signature( |
| 166 | + origin.__init__, # noqa: WPS609 |
| 167 | + **signature_kwargs, |
| 168 | + ) |
168 | 169 | elif inspect.isfunction(dep.dependency): |
169 | 170 | # If this is function or an instance of a class, we get it's type hints. |
170 | 171 | hints = get_type_hints(dep.dependency) |
171 | | - sign = inspect.signature(origin) # type: ignore |
| 172 | + sign = inspect.signature(origin, **signature_kwargs) # type: ignore |
172 | 173 | else: |
173 | 174 | hints = get_type_hints( |
174 | 175 | dep.dependency.__call__, # type: ignore # noqa: WPS609 |
175 | 176 | ) |
176 | | - sign = inspect.signature(origin) # type: ignore |
| 177 | + sign = inspect.signature(origin, **signature_kwargs) # type: ignore |
177 | 178 |
|
178 | 179 | # Now we need to iterate over parameters, to |
179 | 180 | # find all parameters, that have TaskiqDepends as it's |
180 | 181 | # default vaule. |
181 | 182 | for param_name, param in sign.parameters.items(): |
182 | 183 | default_value = param.default |
183 | | - annotation = param.annotation |
184 | | - if isinstance(param.annotation, str): |
185 | | - globalns = getattr(origin, "__globals__", {}) |
186 | | - annotation = ForwardRef( # type: ignore # noqa: WPS437 |
187 | | - param.annotation, |
188 | | - )._evaluate( |
189 | | - globalns, |
190 | | - None, |
191 | | - set(), |
192 | | - ) |
193 | | - if hasattr(annotation, "__metadata__"): # noqa: WPS421 |
| 184 | + if hasattr(param.annotation, "__metadata__"): # noqa: WPS421 |
194 | 185 | # We go backwards, |
195 | 186 | # because you may want to override your annotation |
196 | 187 | # and the overriden value will appear to be after |
197 | 188 | # the original `Depends` annotation. |
198 | | - for meta in reversed(annotation.__metadata__): |
| 189 | + for meta in reversed(param.annotation.__metadata__): |
199 | 190 | if isinstance(meta, Dependency): |
200 | 191 | default_value = meta |
201 | 192 | break |
|
0 commit comments