|
1 | 1 | import asyncio |
2 | 2 | import inspect |
3 | 3 | from copy import copy |
| 4 | +from logging import getLogger |
4 | 5 | from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional |
5 | 6 |
|
6 | 7 | from taskiq_dependencies.utils import ParamInfo |
|
9 | 10 | from taskiq_dependencies.graph import DependencyGraph # pragma: no cover |
10 | 11 |
|
11 | 12 |
|
| 13 | +logger = getLogger("taskiq.dependencies.ctx") |
| 14 | + |
| 15 | + |
12 | 16 | class BaseResolveContext: |
13 | 17 | """Base resolver context.""" |
14 | 18 |
|
15 | 19 | def __init__( |
16 | 20 | self, |
17 | 21 | graph: "DependencyGraph", |
18 | 22 | initial_cache: Optional[Dict[Any, Any]] = None, |
| 23 | + exception_propagation: bool = True, |
19 | 24 | ) -> None: |
20 | 25 | self.graph = graph |
21 | 26 | self.opened_dependencies: List[Any] = [] |
22 | 27 | self.sub_contexts: "List[Any]" = [] |
23 | 28 | self.initial_cache = initial_cache or {} |
| 29 | + self.propagate_excs = exception_propagation |
24 | 30 |
|
25 | 31 | def traverse_deps( # noqa: C901, WPS210 |
26 | 32 | self, |
@@ -116,18 +122,34 @@ def __enter__(self) -> "SyncResolveContext": |
116 | 122 | return self |
117 | 123 |
|
118 | 124 | def __exit__(self, *args: Any) -> None: |
119 | | - self.close() |
| 125 | + self.close(*args) |
120 | 126 |
|
121 | | - def close(self) -> None: |
| 127 | + def close(self, *args: Any) -> None: # noqa: C901 |
122 | 128 | """ |
123 | 129 | Close all opened dependencies. |
124 | 130 |
|
125 | 131 | This function runs teardown of all dependencies. |
| 132 | +
|
| 133 | + :param args: exception info if any. |
126 | 134 | """ |
| 135 | + exception_found = False |
| 136 | + if args[1] is not None and self.propagate_excs: |
| 137 | + exception_found = True |
127 | 138 | for ctx in self.sub_contexts: |
128 | | - ctx.close() |
| 139 | + ctx.close(*args) |
129 | 140 | for dep in reversed(self.opened_dependencies): |
130 | 141 | if inspect.isgenerator(dep): |
| 142 | + if exception_found: |
| 143 | + try: |
| 144 | + dep.throw(*args) |
| 145 | + except BaseException as exc: |
| 146 | + logger.warning( |
| 147 | + "Exception found on dependency teardown %s", |
| 148 | + exc, |
| 149 | + exc_info=True, |
| 150 | + ) |
| 151 | + continue |
| 152 | + continue |
131 | 153 | for _ in dep: # noqa: WPS328 |
132 | 154 | pass # noqa: WPS420 |
133 | 155 |
|
@@ -201,21 +223,48 @@ async def __aenter__(self) -> "AsyncResolveContext": |
201 | 223 | return self |
202 | 224 |
|
203 | 225 | async def __aexit__(self, *args: Any) -> None: |
204 | | - await self.close() |
| 226 | + await self.close(*args) |
205 | 227 |
|
206 | | - async def close(self) -> None: # noqa: C901 |
| 228 | + async def close(self, *args: Any) -> None: # noqa: C901 |
207 | 229 | """ |
208 | 230 | Close all opened dependencies. |
209 | 231 |
|
210 | 232 | This function runs teardown of all dependencies. |
| 233 | +
|
| 234 | + :param args: exception info if any. |
211 | 235 | """ |
| 236 | + exception_found = False |
| 237 | + if args[1] is not None and self.propagate_excs: |
| 238 | + exception_found = True |
212 | 239 | for ctx in self.sub_contexts: |
213 | | - await ctx.close() # type: ignore |
| 240 | + await ctx.close(*args) # type: ignore |
214 | 241 | for dep in reversed(self.opened_dependencies): |
215 | 242 | if inspect.isgenerator(dep): |
| 243 | + if exception_found: |
| 244 | + try: |
| 245 | + dep.throw(*args) |
| 246 | + except BaseException as exc: |
| 247 | + logger.warning( |
| 248 | + "Exception found on dependency teardown %s", |
| 249 | + exc, |
| 250 | + exc_info=True, |
| 251 | + ) |
| 252 | + continue |
| 253 | + continue |
216 | 254 | for _ in dep: # noqa: WPS328 |
217 | 255 | pass # noqa: WPS420 |
218 | 256 | elif inspect.isasyncgen(dep): |
| 257 | + if exception_found: |
| 258 | + try: |
| 259 | + await dep.athrow(*args) |
| 260 | + except BaseException as exc: |
| 261 | + logger.warning( |
| 262 | + "Exception found on dependency teardown %s", |
| 263 | + exc, |
| 264 | + exc_info=True, |
| 265 | + ) |
| 266 | + continue |
| 267 | + continue |
219 | 268 | async for _ in dep: # noqa: WPS328 |
220 | 269 | pass # noqa: WPS420 |
221 | 270 |
|
|
0 commit comments