Skip to content

Commit 28ab036

Browse files
committed
Added dependencies replacement.
Signed-off-by: Pavel Kirilin <win10@list.ru>
1 parent a4c8e11 commit 28ab036

3 files changed

Lines changed: 30 additions & 48 deletions

File tree

taskiq_dependencies/ctx.py

Lines changed: 5 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,12 @@ def __init__(
2020
self,
2121
graph: "DependencyGraph",
2222
initial_cache: Optional[Dict[Any, Any]] = None,
23-
replaced_deps: Optional[Dict[Any, Any]] = None,
2423
exception_propagation: bool = True,
2524
) -> None:
2625
self.graph = graph
2726
self.opened_dependencies: List[Any] = []
2827
self.sub_contexts: "List[Any]" = []
2928
self.initial_cache = initial_cache or {}
30-
self.replaced_deps = replaced_deps or {}
3129
self.propagate_excs = exception_propagation
3230

3331
def traverse_deps( # noqa: C901, WPS210
@@ -58,7 +56,7 @@ def traverse_deps( # noqa: C901, WPS210
5856
# later.
5957
if not dep.use_cache:
6058
continue
61-
# If somehow we have dependency with unknown function.
59+
# If somehow we have dependency with unknwon function.
6260
if dep.dependency is None:
6361
continue
6462
# If dependency is already calculated.
@@ -91,13 +89,7 @@ def traverse_deps( # noqa: C901, WPS210
9189
continue
9290
if subdep.kwargs:
9391
resolved_kwargs.update(subdep.kwargs)
94-
# We try to grab possible replacement for
95-
# function if any. Otherwise, original subdependency is returned.
96-
target_dependency = self.replaced_deps.get(
97-
subdep.dependency,
98-
subdep.dependency,
99-
)
100-
kwargs[subdep.param_name] = yield target_dependency(
92+
kwargs[subdep.param_name] = yield subdep.dependency(
10193
**resolved_kwargs,
10294
)
10395

@@ -111,13 +103,7 @@ def traverse_deps( # noqa: C901, WPS210
111103
):
112104
user_kwargs = dep.kwargs
113105
user_kwargs.update(kwargs)
114-
# From dict of replaced functions,
115-
# we grab possible replacement or original function.
116-
target_dependency = self.replaced_deps.get(
117-
dep.dependency,
118-
dep.dependency,
119-
)
120-
cache[dep.dependency] = yield target_dependency(**user_kwargs)
106+
cache[dep.dependency] = yield dep.dependency(**user_kwargs)
121107
return kwargs
122108

123109

@@ -183,12 +169,7 @@ def resolver(self, executed_func: Any, initial_cache: Dict[Any, Any]) -> Any:
183169
:return: dict with resolved kwargs.
184170
"""
185171
if getattr(executed_func, "dep_graph", False):
186-
ctx = SyncResolveContext(
187-
graph=executed_func,
188-
initial_cache=initial_cache,
189-
replaced_deps=self.replaced_deps,
190-
exception_propagation=self.propagate_excs,
191-
)
172+
ctx = SyncResolveContext(executed_func, initial_cache)
192173
self.sub_contexts.append(ctx)
193174
sub_result = ctx.resolve_kwargs()
194175
elif inspect.isgenerator(executed_func):
@@ -305,12 +286,7 @@ async def resolver(self, executed_func: Any, initial_cache: Dict[Any, Any]) -> A
305286
:return: dict with resolved kwargs.
306287
"""
307288
if getattr(executed_func, "dep_graph", False):
308-
ctx = AsyncResolveContext(
309-
graph=executed_func,
310-
initial_cache=initial_cache,
311-
replaced_deps=self.replaced_deps,
312-
exception_propagation=self.propagate_excs,
313-
) # type: ignore
289+
ctx = AsyncResolveContext(executed_func, initial_cache) # type: ignore
314290
self.sub_contexts.append(ctx)
315291
sub_result = await ctx.resolve_kwargs()
316292
elif inspect.isgenerator(executed_func):

taskiq_dependencies/graph.py

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ class DependencyGraph:
2020
def __init__(
2121
self,
2222
target: Callable[..., Any],
23+
replaced_deps: Optional[Dict[Any, Any]] = None,
2324
) -> None:
2425
self.target = target
2526
# Ordinary dependencies with cache.
@@ -28,6 +29,7 @@ def __init__(
2829
# Can be considered as sub graphs.
2930
self.subgraphs: Dict[Any, DependencyGraph] = {}
3031
self.ordered_deps: List[Dependency] = []
32+
self.replaced_deps = replaced_deps
3133
self._build_graph()
3234

3335
def is_empty(self) -> bool:
@@ -52,14 +54,16 @@ def async_ctx(
5254
:param initial_cache: initial cache dict.
5355
:param exception_propagation: If true, all found errors within
5456
context will be propagated to dependencies.
55-
:param replaced_deps: dict with dependencies to replace.
57+
:param replaced_deps: Dependencies to replace during runtime.
5658
:return: new resolver context.
5759
"""
60+
graph = self
61+
if replaced_deps:
62+
graph = DependencyGraph(self.target, replaced_deps)
5863
return AsyncResolveContext(
59-
graph=self,
60-
initial_cache=initial_cache,
61-
exception_propagation=exception_propagation,
62-
replaced_deps=replaced_deps,
64+
graph,
65+
initial_cache,
66+
exception_propagation,
6367
)
6468

6569
def sync_ctx(
@@ -76,14 +80,16 @@ def sync_ctx(
7680
:param initial_cache: initial cache dict.
7781
:param exception_propagation: If true, all found errors within
7882
context will be propagated to dependencies.
79-
:param replaced_deps: dict with dependencies to replace.
83+
:param replaced_deps: Dependencies to replace during runtime.
8084
:return: new resolver context.
8185
"""
86+
graph = self
87+
if replaced_deps:
88+
graph = DependencyGraph(self.target, replaced_deps)
8289
return SyncResolveContext(
83-
graph=self,
84-
initial_cache=initial_cache,
85-
exception_propagation=exception_propagation,
86-
replaced_deps=replaced_deps,
90+
graph,
91+
initial_cache,
92+
exception_propagation,
8793
)
8894

8995
def _build_graph(self) -> None: # noqa: C901, WPS210
@@ -108,6 +114,8 @@ def _build_graph(self) -> None: # noqa: C901, WPS210
108114
continue
109115
if dep.dependency is None:
110116
continue
117+
if self.replaced_deps and dep.dependency in self.replaced_deps:
118+
dep.dependency = self.replaced_deps[dep.dependency]
111119
# Get signature and type hints.
112120
origin = getattr(dep.dependency, "__origin__", None)
113121
if origin is None:

tests/test_graph.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -695,14 +695,12 @@ def dep() -> int:
695695
return 123
696696

697697
def target(val: int = Depends(dep)) -> None:
698-
raise ValueError("lol")
698+
"""Stub function."""
699699

700700
graph = DependencyGraph(target=target)
701-
with pytest.raises(ValueError):
702-
async with graph.async_ctx(
703-
replaced_deps={dep: replaced},
704-
exception_propagation=True,
705-
) as ctx:
706-
kwargs = await ctx.resolve_kwargs()
707-
assert kwargs["val"] == 321
708-
target(**kwargs)
701+
async with graph.async_ctx(
702+
replaced_deps={dep: replaced},
703+
exception_propagation=True,
704+
) as ctx:
705+
kwargs = await ctx.resolve_kwargs()
706+
assert kwargs["val"] == 321

0 commit comments

Comments
 (0)