Skip to content

Commit ce96fbb

Browse files
committed
Added exception propogation for dependencies.
Signed-off-by: Pavel Kirilin <win10@list.ru>
1 parent 6d00859 commit ce96fbb

3 files changed

Lines changed: 145 additions & 6 deletions

File tree

.flake8

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,8 @@ ignore =
8888
N802,
8989
; Do not perform function calls in argument defaults.
9090
B008,
91+
; Found except `BaseException`
92+
WPS424,
9193

9294
; all init files
9395
__init__.py:

taskiq_dependencies/ctx.py

Lines changed: 53 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import asyncio
22
import inspect
33
from copy import copy
4+
from logging import getLogger
45
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional
56

67
from taskiq_dependencies.utils import ParamInfo
@@ -9,6 +10,9 @@
910
from taskiq_dependencies.graph import DependencyGraph # pragma: no cover
1011

1112

13+
logger = getLogger("taskiq.dependencies.ctx")
14+
15+
1216
class BaseResolveContext:
1317
"""Base resolver context."""
1418

@@ -116,18 +120,34 @@ def __enter__(self) -> "SyncResolveContext":
116120
return self
117121

118122
def __exit__(self, *args: Any) -> None:
119-
self.close()
123+
self.close(*args)
120124

121-
def close(self) -> None:
125+
def close(self, *args: Any) -> None: # noqa: C901
122126
"""
123127
Close all opened dependencies.
124128
125129
This function runs teardown of all dependencies.
130+
131+
:param args: exception info if any.
126132
"""
133+
exception_found = False
134+
if args[1] is not None:
135+
exception_found = True
127136
for ctx in self.sub_contexts:
128-
ctx.close()
137+
ctx.close(*args)
129138
for dep in reversed(self.opened_dependencies):
130139
if inspect.isgenerator(dep):
140+
if exception_found:
141+
try:
142+
dep.throw(*args)
143+
except BaseException as exc:
144+
logger.warning(
145+
"Exception found on dependency teardown %s",
146+
exc,
147+
exc_info=True,
148+
)
149+
continue
150+
continue
131151
for _ in dep: # noqa: WPS328
132152
pass # noqa: WPS420
133153

@@ -201,21 +221,48 @@ async def __aenter__(self) -> "AsyncResolveContext":
201221
return self
202222

203223
async def __aexit__(self, *args: Any) -> None:
204-
await self.close()
224+
await self.close(*args)
205225

206-
async def close(self) -> None: # noqa: C901
226+
async def close(self, *args: Any) -> None: # noqa: C901
207227
"""
208228
Close all opened dependencies.
209229
210230
This function runs teardown of all dependencies.
231+
232+
:param args: exception info if any.
211233
"""
234+
exception_found = False
235+
if args[1] is not None:
236+
exception_found = True
212237
for ctx in self.sub_contexts:
213-
await ctx.close() # type: ignore
238+
await ctx.close(*args) # type: ignore
214239
for dep in reversed(self.opened_dependencies):
215240
if inspect.isgenerator(dep):
241+
if exception_found:
242+
try:
243+
dep.throw(*args)
244+
except BaseException as exc:
245+
logger.warning(
246+
"Exception found on dependency teardown %s",
247+
exc,
248+
exc_info=True,
249+
)
250+
continue
251+
continue
216252
for _ in dep: # noqa: WPS328
217253
pass # noqa: WPS420
218254
elif inspect.isasyncgen(dep):
255+
if exception_found:
256+
try:
257+
await dep.athrow(*args)
258+
except BaseException as exc:
259+
logger.warning(
260+
"Exception found on dependency teardown %s",
261+
exc,
262+
exc_info=True,
263+
)
264+
continue
265+
continue
219266
async for _ in dep: # noqa: WPS328
220267
pass # noqa: WPS420
221268

tests/test_graph.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -310,3 +310,93 @@ def target(class_val: str = Depends(TeClass("tval"))) -> None:
310310

311311
info: str = kwargs["class_val"]
312312
assert info == "tval"
313+
314+
315+
def test_exception_generators() -> None:
316+
317+
errors_found = 0
318+
319+
def my_generator() -> Generator[int, None, None]:
320+
nonlocal errors_found
321+
try:
322+
yield 1
323+
except ValueError:
324+
errors_found += 1
325+
326+
def target(_: int = Depends(my_generator)) -> None:
327+
raise ValueError()
328+
329+
with pytest.raises(ValueError):
330+
with DependencyGraph(target=target).sync_ctx() as g:
331+
target(**g.resolve_kwargs())
332+
333+
assert errors_found == 1
334+
335+
336+
@pytest.mark.anyio
337+
async def test_async_exception_generators() -> None:
338+
339+
errors_found = 0
340+
341+
async def my_generator() -> AsyncGenerator[int, None]:
342+
nonlocal errors_found
343+
try:
344+
yield 1
345+
except ValueError:
346+
errors_found += 1
347+
348+
def target(_: int = Depends(my_generator)) -> None:
349+
raise ValueError()
350+
351+
with pytest.raises(ValueError):
352+
async with DependencyGraph(target=target).async_ctx() as g:
353+
target(**(await g.resolve_kwargs()))
354+
355+
assert errors_found == 1
356+
357+
358+
@pytest.mark.anyio
359+
async def test_async_exception_generators_multiple() -> None:
360+
361+
errors_found = 0
362+
363+
async def my_generator() -> AsyncGenerator[int, None]:
364+
nonlocal errors_found
365+
try:
366+
yield 1
367+
except ValueError:
368+
errors_found += 1
369+
370+
def target(
371+
_a: int = Depends(my_generator, use_cache=False),
372+
_b: int = Depends(my_generator, use_cache=False),
373+
_c: int = Depends(my_generator, use_cache=False),
374+
) -> None:
375+
raise ValueError()
376+
377+
with pytest.raises(ValueError):
378+
async with DependencyGraph(target=target).async_ctx() as g:
379+
target(**(await g.resolve_kwargs()))
380+
381+
assert errors_found == 3
382+
383+
384+
@pytest.mark.anyio
385+
async def test_async_exception_generators_no_propogation() -> None:
386+
387+
errors_found = 0
388+
389+
async def my_generator() -> AsyncGenerator[int, None]:
390+
nonlocal errors_found
391+
try:
392+
yield 1
393+
except ValueError:
394+
errors_found += 1
395+
raise Exception()
396+
397+
def target(_: int = Depends(my_generator)) -> None:
398+
raise ValueError()
399+
400+
with pytest.raises(ValueError):
401+
async with DependencyGraph(target=target).async_ctx() as g:
402+
target(**(await g.resolve_kwargs()))

0 commit comments

Comments
 (0)