Skip to content

Commit 238a582

Browse files
committed
Added exception_propogation parameter.
Signed-off-by: Pavel Kirilin <win10@list.ru>
1 parent ce96fbb commit 238a582

3 files changed

Lines changed: 60 additions & 3 deletions

File tree

taskiq_dependencies/ctx.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,13 @@ def __init__(
2020
self,
2121
graph: "DependencyGraph",
2222
initial_cache: Optional[Dict[Any, Any]] = None,
23+
exception_propogation: bool = True,
2324
) -> None:
2425
self.graph = graph
2526
self.opened_dependencies: List[Any] = []
2627
self.sub_contexts: "List[Any]" = []
2728
self.initial_cache = initial_cache or {}
29+
self.propogate_excs = exception_propogation
2830

2931
def traverse_deps( # noqa: C901, WPS210
3032
self,
@@ -131,7 +133,7 @@ def close(self, *args: Any) -> None: # noqa: C901
131133
:param args: exception info if any.
132134
"""
133135
exception_found = False
134-
if args[1] is not None:
136+
if args[1] is not None and self.propogate_excs:
135137
exception_found = True
136138
for ctx in self.sub_contexts:
137139
ctx.close(*args)
@@ -232,7 +234,7 @@ async def close(self, *args: Any) -> None: # noqa: C901
232234
:param args: exception info if any.
233235
"""
234236
exception_found = False
235-
if args[1] is not None:
237+
if args[1] is not None and self.propogate_excs:
236238
exception_found = True
237239
for ctx in self.sub_contexts:
238240
await ctx.close(*args) # type: ignore

taskiq_dependencies/graph.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,35 +41,43 @@ def is_empty(self) -> bool:
4141
def async_ctx(
4242
self,
4343
initial_cache: Optional[Dict[Any, Any]] = None,
44+
exception_propogation: bool = True,
4445
) -> AsyncResolveContext:
4546
"""
4647
Create dependency resolver context.
4748
4849
This context is used to actually resolve dependencies.
4950
5051
:param initial_cache: initial cache dict.
52+
:param exception_propogation: If true, all found errors within
53+
context will be propogated to dependencies.
5154
:return: new resolver context.
5255
"""
5356
return AsyncResolveContext(
5457
self,
5558
initial_cache,
59+
exception_propogation,
5660
)
5761

5862
def sync_ctx(
5963
self,
6064
initial_cache: Optional[Dict[Any, Any]] = None,
65+
exception_propogation: bool = True,
6166
) -> SyncResolveContext:
6267
"""
6368
Create dependency resolver context.
6469
6570
This context is used to actually resolve dependencies.
6671
6772
:param initial_cache: initial cache dict.
73+
:param exception_propogation: If true, all found errors within
74+
context will be propogated to dependencies.
6875
:return: new resolver context.
6976
"""
7077
return SyncResolveContext(
7178
self,
7279
initial_cache,
80+
exception_propogation,
7381
)
7482

7583
def _build_graph(self) -> None: # noqa: C901, WPS210

tests/test_graph.py

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -382,7 +382,7 @@ def target(
382382

383383

384384
@pytest.mark.anyio
385-
async def test_async_exception_generators_no_propogation() -> None:
385+
async def test_async_exception_in_teardown() -> None:
386386

387387
errors_found = 0
388388

@@ -400,3 +400,50 @@ def target(_: int = Depends(my_generator)) -> None:
400400
with pytest.raises(ValueError):
401401
async with DependencyGraph(target=target).async_ctx() as g:
402402
target(**(await g.resolve_kwargs()))
403+
404+
405+
@pytest.mark.anyio
406+
async def test_async_propogation_disabled() -> None:
407+
408+
errors_found = 0
409+
410+
async def my_generator() -> AsyncGenerator[int, None]:
411+
nonlocal errors_found
412+
try:
413+
yield 1
414+
except ValueError:
415+
errors_found += 1
416+
raise Exception()
417+
418+
def target(_: int = Depends(my_generator)) -> None:
419+
raise ValueError()
420+
421+
with pytest.raises(ValueError):
422+
async with DependencyGraph(target=target).async_ctx(
423+
exception_propogation=False,
424+
) as g:
425+
target(**(await g.resolve_kwargs()))
426+
427+
assert errors_found == 0
428+
429+
430+
def test_sync_propogation_disabled() -> None:
431+
432+
errors_found = 0
433+
434+
def my_generator() -> Generator[int, None, None]:
435+
nonlocal errors_found
436+
try:
437+
yield 1
438+
except ValueError:
439+
errors_found += 1
440+
raise Exception()
441+
442+
def target(_: int = Depends(my_generator)) -> None:
443+
raise ValueError()
444+
445+
with pytest.raises(ValueError):
446+
with DependencyGraph(target=target).sync_ctx(exception_propogation=False) as g:
447+
target(**(g.resolve_kwargs()))
448+
449+
assert errors_found == 0

0 commit comments

Comments
 (0)