@@ -20,10 +20,13 @@ class BaseResolveContext:
2020 def __init__ (
2121 self ,
2222 graph : "DependencyGraph" ,
23+ main_graph : "DependencyGraph" ,
2324 initial_cache : Optional [Dict [Any , Any ]] = None ,
2425 exception_propagation : bool = True ,
2526 ) -> None :
2627 self .graph = graph
28+ # Main graph that contains all the subgraphs.
29+ self .main_graph = main_graph
2730 self .opened_dependencies : List [Any ] = []
2831 self .sub_contexts : "List[Any]" = []
2932 self .initial_cache = initial_cache or {}
@@ -91,7 +94,7 @@ def traverse_deps( # noqa: C901
9194 if subdep .dependency == ParamInfo :
9295 kwargs [subdep .param_name ] = ParamInfo (
9396 dep .param_name ,
94- self .graph ,
97+ self .main_graph ,
9598 dep .signature ,
9699 )
97100 continue
@@ -201,7 +204,7 @@ def resolver(self, executed_func: Any, initial_cache: Dict[Any, Any]) -> Any:
201204 :return: dict with resolved kwargs.
202205 """
203206 if getattr (executed_func , "dep_graph" , False ):
204- ctx = SyncResolveContext (executed_func , initial_cache )
207+ ctx = SyncResolveContext (executed_func , self . main_graph , initial_cache )
205208 self .sub_contexts .append (ctx )
206209 sub_result = ctx .resolve_kwargs ()
207210 elif inspect .isgenerator (executed_func ):
@@ -329,7 +332,7 @@ async def resolver(
329332 :return: dict with resolved kwargs.
330333 """
331334 if getattr (executed_func , "dep_graph" , False ):
332- ctx = AsyncResolveContext (executed_func , initial_cache ) # type: ignore
335+ ctx = AsyncResolveContext (executed_func , self . main_graph , initial_cache ) # type: ignore
333336 self .sub_contexts .append (ctx )
334337 sub_result = await ctx .resolve_kwargs ()
335338 elif inspect .isgenerator (executed_func ):
0 commit comments