Skip to content

Commit 3314a39

Browse files
committed
Added dependencies replacemenets.
Signed-off-by: Pavel Kirilin <win10@list.ru>
1 parent 2d1469f commit 3314a39

3 files changed

Lines changed: 138 additions & 13 deletions

File tree

taskiq_dependencies/ctx.py

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,14 @@ def __init__(
2020
self,
2121
graph: "DependencyGraph",
2222
initial_cache: Optional[Dict[Any, Any]] = None,
23+
replaced_deps: Optional[Dict[Any, Any]] = None,
2324
exception_propagation: bool = True,
2425
) -> None:
2526
self.graph = graph
2627
self.opened_dependencies: List[Any] = []
2728
self.sub_contexts: "List[Any]" = []
2829
self.initial_cache = initial_cache or {}
30+
self.replaced_funcs = replaced_deps or {}
2931
self.propagate_excs = exception_propagation
3032

3133
def traverse_deps( # noqa: C901, WPS210
@@ -56,7 +58,7 @@ def traverse_deps( # noqa: C901, WPS210
5658
# later.
5759
if not dep.use_cache:
5860
continue
59-
# If somehow we have dependency with unknwon function.
61+
# If somehow we have dependency with unknown function.
6062
if dep.dependency is None:
6163
continue
6264
# If dependency is already calculated.
@@ -89,7 +91,13 @@ def traverse_deps( # noqa: C901, WPS210
8991
continue
9092
if subdep.kwargs:
9193
resolved_kwargs.update(subdep.kwargs)
92-
kwargs[subdep.param_name] = yield subdep.dependency(
94+
# We try to grab possible replacement for
95+
# function if any. Otherwise, original subdependency is returned.
96+
target_dependency = self.replaced_funcs.get(
97+
subdep.dependency,
98+
subdep.dependency,
99+
)
100+
kwargs[subdep.param_name] = yield target_dependency(
93101
**resolved_kwargs,
94102
)
95103

@@ -103,7 +111,13 @@ def traverse_deps( # noqa: C901, WPS210
103111
):
104112
user_kwargs = dep.kwargs
105113
user_kwargs.update(kwargs)
106-
cache[dep.dependency] = yield dep.dependency(**user_kwargs)
114+
# From dict of replaced functions,
115+
# we grab possible replacement or original function.
116+
target_dependency = self.replaced_funcs.get(
117+
dep.dependency,
118+
dep.dependency,
119+
)
120+
cache[dep.dependency] = yield target_dependency(**user_kwargs)
107121
return kwargs
108122

109123

@@ -169,7 +183,12 @@ def resolver(self, executed_func: Any, initial_cache: Dict[Any, Any]) -> Any:
169183
:return: dict with resolved kwargs.
170184
"""
171185
if getattr(executed_func, "dep_graph", False):
172-
ctx = SyncResolveContext(executed_func, initial_cache)
186+
ctx = SyncResolveContext(
187+
graph=executed_func,
188+
initial_cache=initial_cache,
189+
replaced_deps=self.replaced_funcs,
190+
exception_propagation=self.propagate_excs,
191+
)
173192
self.sub_contexts.append(ctx)
174193
sub_result = ctx.resolve_kwargs()
175194
elif inspect.isgenerator(executed_func):
@@ -286,7 +305,12 @@ async def resolver(self, executed_func: Any, initial_cache: Dict[Any, Any]) -> A
286305
:return: dict with resolved kwargs.
287306
"""
288307
if getattr(executed_func, "dep_graph", False):
289-
ctx = AsyncResolveContext(executed_func, initial_cache) # type: ignore
308+
ctx = AsyncResolveContext(
309+
graph=executed_func,
310+
initial_cache=initial_cache,
311+
replaced_deps=self.replaced_funcs,
312+
exception_propagation=self.propagate_excs,
313+
) # type: ignore
290314
self.sub_contexts.append(ctx)
291315
sub_result = await ctx.resolve_kwargs()
292316
elif inspect.isgenerator(executed_func):

taskiq_dependencies/graph.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ def is_empty(self) -> bool:
4141
def async_ctx(
4242
self,
4343
initial_cache: Optional[Dict[Any, Any]] = None,
44+
replaced_deps: Optional[Dict[Any, Any]] = None,
4445
exception_propagation: bool = True,
4546
) -> AsyncResolveContext:
4647
"""
@@ -51,17 +52,20 @@ def async_ctx(
5152
:param initial_cache: initial cache dict.
5253
:param exception_propagation: If true, all found errors within
5354
context will be propagated to dependencies.
55+
:param replaced_deps: dict with dependencies to replace.
5456
:return: new resolver context.
5557
"""
5658
return AsyncResolveContext(
57-
self,
58-
initial_cache,
59-
exception_propagation,
59+
graph=self,
60+
initial_cache=initial_cache,
61+
exception_propagation=exception_propagation,
62+
replaced_deps=replaced_deps,
6063
)
6164

6265
def sync_ctx(
6366
self,
6467
initial_cache: Optional[Dict[Any, Any]] = None,
68+
replaced_deps: Optional[Dict[Any, Any]] = None,
6569
exception_propagation: bool = True,
6670
) -> SyncResolveContext:
6771
"""
@@ -72,12 +76,14 @@ def sync_ctx(
7276
:param initial_cache: initial cache dict.
7377
:param exception_propagation: If true, all found errors within
7478
context will be propagated to dependencies.
79+
:param replaced_deps: dict with dependencies to replace.
7580
:return: new resolver context.
7681
"""
7782
return SyncResolveContext(
78-
self,
79-
initial_cache,
80-
exception_propagation,
83+
graph=self,
84+
initial_cache=initial_cache,
85+
exception_propagation=exception_propagation,
86+
replaced_deps=replaced_deps,
8187
)
8288

8389
def _build_graph(self) -> None: # noqa: C901, WPS210

tests/test_graph.py

Lines changed: 97 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import asyncio
2+
import re
23
import uuid
34
from typing import Any, AsyncGenerator, Generator, Generic, Tuple, TypeVar
45

@@ -36,8 +37,9 @@ def testfunc(a: int = Depends(dep1)) -> int:
3637
return a
3738

3839
with DependencyGraph(testfunc).sync_ctx({}) as sctx:
39-
with pytest.raises(RuntimeError):
40-
assert sctx.resolve_kwargs() == {"a": 1}
40+
with pytest.warns(match=re.compile(".*was never awaited.*")):
41+
with pytest.raises(RuntimeError):
42+
assert sctx.resolve_kwargs() == {"a": 1}
4143

4244
async with DependencyGraph(testfunc).async_ctx({}) as actx:
4345
assert await actx.resolve_kwargs() == {"a": 1}
@@ -611,3 +613,96 @@ def target(
611613
assert dep_obj.dependency == GenericClass[Tuple[str, int]]
612614
assert dep_obj.signature.name == "class_val"
613615
assert dep_obj.signature.annotation == GenericClass[Tuple[str, int]]
616+
617+
618+
@pytest.mark.anyio
619+
async def test_replaced_dep_simple() -> None:
620+
def replaced() -> int:
621+
return 321
622+
623+
def dep() -> int:
624+
return 123
625+
626+
def target(val: int = Depends(dep)) -> None:
627+
return None
628+
629+
graph = DependencyGraph(target=target)
630+
async with graph.async_ctx(replaced_deps={dep: replaced}) as ctx:
631+
kwargs = await ctx.resolve_kwargs()
632+
assert kwargs["val"] == 321
633+
634+
635+
@pytest.mark.anyio
636+
async def test_replaced_dep_generators() -> None:
637+
call_count = 0
638+
639+
def replaced() -> Generator[int, None, None]:
640+
nonlocal call_count
641+
yield 321
642+
call_count += 1
643+
644+
def dep() -> int:
645+
return 123
646+
647+
def target(val: int = Depends(dep)) -> None:
648+
return None
649+
650+
graph = DependencyGraph(target=target)
651+
async with graph.async_ctx(replaced_deps={dep: replaced}) as ctx:
652+
kwargs = await ctx.resolve_kwargs()
653+
assert kwargs["val"] == 321
654+
assert call_count == 1
655+
656+
657+
@pytest.mark.anyio
658+
async def test_replaced_dep_exception_propogation() -> None:
659+
exc_count = 0
660+
661+
def replaced() -> Generator[int, None, None]:
662+
nonlocal exc_count
663+
try:
664+
yield 321
665+
except ValueError:
666+
exc_count += 1
667+
668+
def dep() -> int:
669+
return 123
670+
671+
def target(val: int = Depends(dep)) -> None:
672+
raise ValueError("lol")
673+
674+
graph = DependencyGraph(target=target)
675+
with pytest.raises(ValueError):
676+
async with graph.async_ctx(
677+
replaced_deps={dep: replaced},
678+
exception_propagation=True,
679+
) as ctx:
680+
kwargs = await ctx.resolve_kwargs()
681+
assert kwargs["val"] == 321
682+
target(**kwargs)
683+
assert exc_count == 1
684+
685+
686+
@pytest.mark.anyio
687+
async def test_replaced_dep_subdependencies() -> None:
688+
def subdep() -> int:
689+
return 321
690+
691+
def replaced(ret_val: int = Depends(subdep)) -> int:
692+
return ret_val
693+
694+
def dep() -> int:
695+
return 123
696+
697+
def target(val: int = Depends(dep)) -> None:
698+
raise ValueError("lol")
699+
700+
graph = DependencyGraph(target=target)
701+
with pytest.raises(ValueError):
702+
async with graph.async_ctx(
703+
replaced_deps={dep: replaced},
704+
exception_propagation=True,
705+
) as ctx:
706+
kwargs = await ctx.resolve_kwargs()
707+
assert kwargs["val"] == 321
708+
target(**kwargs)

0 commit comments

Comments
 (0)