|
1 | 1 | import asyncio |
2 | 2 | import inspect |
3 | 3 | from copy import copy |
| 4 | +from logging import getLogger |
4 | 5 | from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional |
5 | 6 |
|
6 | 7 | from taskiq_dependencies.utils import ParamInfo |
|
9 | 10 | from taskiq_dependencies.graph import DependencyGraph # pragma: no cover |
10 | 11 |
|
11 | 12 |
|
| 13 | +logger = getLogger("taskiq.dependencies.ctx") |
| 14 | + |
| 15 | + |
12 | 16 | class BaseResolveContext: |
13 | 17 | """Base resolver context.""" |
14 | 18 |
|
@@ -116,18 +120,34 @@ def __enter__(self) -> "SyncResolveContext": |
116 | 120 | return self |
117 | 121 |
|
118 | 122 | def __exit__(self, *args: Any) -> None: |
119 | | - self.close() |
| 123 | + self.close(*args) |
120 | 124 |
|
121 | | - def close(self) -> None: |
| 125 | + def close(self, *args: Any) -> None: # noqa: C901 |
122 | 126 | """ |
123 | 127 | Close all opened dependencies. |
124 | 128 |
|
125 | 129 | This function runs teardown of all dependencies. |
| 130 | +
|
| 131 | + :param args: exception info if any. |
126 | 132 | """ |
| 133 | + exception_found = False |
| 134 | + if args[1] is not None: |
| 135 | + exception_found = True |
127 | 136 | for ctx in self.sub_contexts: |
128 | | - ctx.close() |
| 137 | + ctx.close(*args) |
129 | 138 | for dep in reversed(self.opened_dependencies): |
130 | 139 | if inspect.isgenerator(dep): |
| 140 | + if exception_found: |
| 141 | + try: |
| 142 | + dep.throw(*args) |
| 143 | + except BaseException as exc: |
| 144 | + logger.warning( |
| 145 | + "Exception found on dependency teardown %s", |
| 146 | + exc, |
| 147 | + exc_info=True, |
| 148 | + ) |
| 149 | + continue |
| 150 | + continue |
131 | 151 | for _ in dep: # noqa: WPS328 |
132 | 152 | pass # noqa: WPS420 |
133 | 153 |
|
@@ -201,21 +221,48 @@ async def __aenter__(self) -> "AsyncResolveContext": |
201 | 221 | return self |
202 | 222 |
|
203 | 223 | async def __aexit__(self, *args: Any) -> None: |
204 | | - await self.close() |
| 224 | + await self.close(*args) |
205 | 225 |
|
206 | | - async def close(self) -> None: # noqa: C901 |
| 226 | + async def close(self, *args: Any) -> None: # noqa: C901 |
207 | 227 | """ |
208 | 228 | Close all opened dependencies. |
209 | 229 |
|
210 | 230 | This function runs teardown of all dependencies. |
| 231 | +
|
| 232 | + :param args: exception info if any. |
211 | 233 | """ |
| 234 | + exception_found = False |
| 235 | + if args[1] is not None: |
| 236 | + exception_found = True |
212 | 237 | for ctx in self.sub_contexts: |
213 | | - await ctx.close() # type: ignore |
| 238 | + await ctx.close(*args) # type: ignore |
214 | 239 | for dep in reversed(self.opened_dependencies): |
215 | 240 | if inspect.isgenerator(dep): |
| 241 | + if exception_found: |
| 242 | + try: |
| 243 | + dep.throw(*args) |
| 244 | + except BaseException as exc: |
| 245 | + logger.warning( |
| 246 | + "Exception found on dependency teardown %s", |
| 247 | + exc, |
| 248 | + exc_info=True, |
| 249 | + ) |
| 250 | + continue |
| 251 | + continue |
216 | 252 | for _ in dep: # noqa: WPS328 |
217 | 253 | pass # noqa: WPS420 |
218 | 254 | elif inspect.isasyncgen(dep): |
| 255 | + if exception_found: |
| 256 | + try: |
| 257 | + await dep.athrow(*args) |
| 258 | + except BaseException as exc: |
| 259 | + logger.warning( |
| 260 | + "Exception found on dependency teardown %s", |
| 261 | + exc, |
| 262 | + exc_info=True, |
| 263 | + ) |
| 264 | + continue |
| 265 | + continue |
219 | 266 | async for _ in dep: # noqa: WPS328 |
220 | 267 | pass # noqa: WPS420 |
221 | 268 |
|
|
0 commit comments