@@ -20,12 +20,14 @@ def __init__(
2020 self ,
2121 graph : "DependencyGraph" ,
2222 initial_cache : Optional [Dict [Any , Any ]] = None ,
23+ replaced_deps : Optional [Dict [Any , Any ]] = None ,
2324 exception_propagation : bool = True ,
2425 ) -> None :
2526 self .graph = graph
2627 self .opened_dependencies : List [Any ] = []
2728 self .sub_contexts : "List[Any]" = []
2829 self .initial_cache = initial_cache or {}
30+ self .replaced_funcs = replaced_deps or {}
2931 self .propagate_excs = exception_propagation
3032
3133 def traverse_deps ( # noqa: C901, WPS210
@@ -56,7 +58,7 @@ def traverse_deps( # noqa: C901, WPS210
5658 # later.
5759 if not dep .use_cache :
5860 continue
59- # If somehow we have dependency with unknwon function.
61+ # If somehow we have dependency with unknown function.
6062 if dep .dependency is None :
6163 continue
6264 # If dependency is already calculated.
@@ -89,7 +91,13 @@ def traverse_deps( # noqa: C901, WPS210
8991 continue
9092 if subdep .kwargs :
9193 resolved_kwargs .update (subdep .kwargs )
92- kwargs [subdep .param_name ] = yield subdep .dependency (
94+ # We try to grab possible replacement for
95+ # function if any. Otherwise, original subdependency is returned.
96+ target_dependency = self .replaced_funcs .get (
97+ subdep .dependency ,
98+ subdep .dependency ,
99+ )
100+ kwargs [subdep .param_name ] = yield target_dependency (
93101 ** resolved_kwargs ,
94102 )
95103
@@ -103,7 +111,13 @@ def traverse_deps( # noqa: C901, WPS210
103111 ):
104112 user_kwargs = dep .kwargs
105113 user_kwargs .update (kwargs )
106- cache [dep .dependency ] = yield dep .dependency (** user_kwargs )
114+ # From dict of replaced functions,
115+ # we grab possible replacement or original function.
116+ target_dependency = self .replaced_funcs .get (
117+ dep .dependency ,
118+ dep .dependency ,
119+ )
120+ cache [dep .dependency ] = yield target_dependency (** user_kwargs )
107121 return kwargs
108122
109123
@@ -169,7 +183,12 @@ def resolver(self, executed_func: Any, initial_cache: Dict[Any, Any]) -> Any:
169183 :return: dict with resolved kwargs.
170184 """
171185 if getattr (executed_func , "dep_graph" , False ):
172- ctx = SyncResolveContext (executed_func , initial_cache )
186+ ctx = SyncResolveContext (
187+ graph = executed_func ,
188+ initial_cache = initial_cache ,
189+ replaced_deps = self .replaced_funcs ,
190+ exception_propagation = self .propagate_excs ,
191+ )
173192 self .sub_contexts .append (ctx )
174193 sub_result = ctx .resolve_kwargs ()
175194 elif inspect .isgenerator (executed_func ):
@@ -286,7 +305,12 @@ async def resolver(self, executed_func: Any, initial_cache: Dict[Any, Any]) -> A
286305 :return: dict with resolved kwargs.
287306 """
288307 if getattr (executed_func , "dep_graph" , False ):
289- ctx = AsyncResolveContext (executed_func , initial_cache ) # type: ignore
308+ ctx = AsyncResolveContext (
309+ graph = executed_func ,
310+ initial_cache = initial_cache ,
311+ replaced_deps = self .replaced_funcs ,
312+ exception_propagation = self .propagate_excs ,
313+ ) # type: ignore
290314 self .sub_contexts .append (ctx )
291315 sub_result = await ctx .resolve_kwargs ()
292316 elif inspect .isgenerator (executed_func ):
0 commit comments