|
1 | 1 | import inspect |
2 | 2 | from collections import defaultdict, deque |
3 | 3 | from graphlib import TopologicalSorter |
4 | | -from typing import Any, Callable, Dict, List, Optional, get_type_hints |
| 4 | +from typing import Any, Callable, Dict, List, Optional, TypeVar, get_type_hints |
5 | 5 |
|
6 | 6 | from taskiq_dependencies.ctx import AsyncResolveContext, SyncResolveContext |
7 | 7 | from taskiq_dependencies.dependency import Dependency |
@@ -103,18 +103,55 @@ def _build_graph(self) -> None: # noqa: C901, WPS210 |
103 | 103 | if dep.dependency is None: |
104 | 104 | continue |
105 | 105 | # Get signature and type hints. |
106 | | - sign = inspect.signature(dep.dependency) |
107 | | - if inspect.isclass(dep.dependency): |
| 106 | + origin = getattr(dep.dependency, "__origin__", None) |
| 107 | + if origin is None: |
| 108 | + origin = dep.dependency |
| 109 | + |
| 110 | + # If we found the typevar. |
| 111 | + # It means, that somebody depend on generic type. |
| 112 | + if isinstance(origin, TypeVar): |
| 113 | + if dep.parent is None: |
| 114 | + raise ValueError(f"Cannot resolve generic {dep.dependency}") |
| 115 | + parent_cls = dep.parent.dependency |
| 116 | + parent_cls_origin = getattr(parent_cls, "__origin__", None) |
| 117 | + # If we cannot find origin, than means, that we cannot resolve |
| 118 | + # generic parameters. So exiting. |
| 119 | + if parent_cls_origin is None: |
| 120 | + raise ValueError( |
| 121 | + f"Unknown generic argument {origin}. " |
| 122 | + + f"Please provide a type in param `{dep.parent.param_name}`" |
| 123 | + + f" of `{dep.parent.dependency}`", |
| 124 | + ) |
| 125 | + # We zip together names of parameters and the subsctituted values |
| 126 | + # In parameters we would see TypeVars in args |
| 127 | + # we would find actual classes. |
| 128 | + generics = zip( |
| 129 | + parent_cls_origin.__parameters__, |
| 130 | + parent_cls.__args__, # type: ignore |
| 131 | + ) |
| 132 | + for tvar, type_param in generics: |
| 133 | + # If we found the typevar we're currently try to resolve, |
| 134 | + # we need to find origin of the substituted class. |
| 135 | + if tvar == origin: |
| 136 | + dep.dependency = type_param |
| 137 | + origin = getattr(type_param, "__origin__", None) |
| 138 | + if origin is None: |
| 139 | + origin = type_param |
| 140 | + |
| 141 | + if inspect.isclass(origin): |
108 | 142 | # If this is a class, we need to get signature of |
109 | 143 | # an __init__ method. |
110 | | - hints = get_type_hints(dep.dependency.__init__) # noqa: WPS609 |
| 144 | + hints = get_type_hints(origin.__init__) # noqa: WPS609 |
| 145 | + sign = inspect.signature(origin.__init__) # noqa: WPS609 |
111 | 146 | elif inspect.isfunction(dep.dependency): |
112 | 147 | # If this is function or an instance of a class, we get it's type hints. |
113 | 148 | hints = get_type_hints(dep.dependency) |
| 149 | + sign = inspect.signature(origin) # type: ignore |
114 | 150 | else: |
115 | 151 | hints = get_type_hints( |
116 | 152 | dep.dependency.__call__, # type: ignore # noqa: WPS609 |
117 | 153 | ) |
| 154 | + sign = inspect.signature(origin) # type: ignore |
118 | 155 |
|
119 | 156 | # Now we need to iterate over parameters, to |
120 | 157 | # find all parameters, that have TaskiqDepends as it's |
@@ -172,6 +209,7 @@ def _build_graph(self) -> None: # noqa: C901, WPS210 |
172 | 209 | use_cache=default_value.use_cache, |
173 | 210 | kwargs=default_value.kwargs, |
174 | 211 | signature=param, |
| 212 | + parent=dep, |
175 | 213 | ) |
176 | 214 | # Also we set the parameter name, |
177 | 215 | # it will help us in future when |
|
0 commit comments