55from logging import getLogger
66from typing import TYPE_CHECKING , Any , DefaultDict , Dict , Generator , List , Optional
77
8- from taskiq_dependencies .utils import ParamInfo
8+ from taskiq_dependencies .utils import ParamInfo , isasynccontextmanager , iscontextmanager
99
1010if TYPE_CHECKING :
1111 from taskiq_dependencies .graph import DependencyGraph # pragma: no cover
@@ -59,7 +59,7 @@ def traverse_deps( # noqa: C901, WPS210
5959 # later.
6060 if not dep .use_cache :
6161 continue
62- # If somehow we have dependency with unknwon function.
62+ # If somehow we have dependency with unknown function.
6363 if dep .dependency is None :
6464 continue
6565 # If dependency is already calculated.
@@ -180,6 +180,8 @@ def close(self, *args: Any) -> None: # noqa: C901
180180 continue
181181 for _ in dep : # noqa: WPS328
182182 pass # noqa: WPS420
183+ elif iscontextmanager (dep ):
184+ dep .__exit__ (* args ) # noqa: WPS609
183185
184186 def resolver (self , executed_func : Any , initial_cache : Dict [Any , Any ]) -> Any :
185187 """
@@ -194,7 +196,7 @@ def resolver(self, executed_func: Any, initial_cache: Dict[Any, Any]) -> Any:
194196
195197 :return: dict with resolved kwargs.
196198 """
197- if getattr (executed_func , "dep_graph" , False ):
199+ if getattr (executed_func , "dep_graph" , False ): # noqa: WPS223
198200 ctx = SyncResolveContext (executed_func , initial_cache )
199201 self .sub_contexts .append (ctx )
200202 sub_result = ctx .resolve_kwargs ()
@@ -206,7 +208,10 @@ def resolver(self, executed_func: Any, initial_cache: Dict[Any, Any]) -> Any:
206208 "Coroutines cannot be used in sync context. "
207209 "Please use async context instead." ,
208210 )
209- elif inspect .isasyncgen (executed_func ):
211+ elif iscontextmanager (executed_func ):
212+ sub_result = executed_func .__enter__ () # noqa: WPS609
213+ self .opened_dependencies .append (executed_func )
214+ elif inspect .isasyncgen (executed_func ) or isasynccontextmanager (executed_func ):
210215 raise RuntimeError (
211216 "Coroutines cannot be used in sync context. "
212217 "Please use async context instead." ,
@@ -299,8 +304,16 @@ async def close(self, *args: Any) -> None: # noqa: C901
299304 continue
300305 async for _ in dep : # noqa: WPS328
301306 pass # noqa: WPS420
307+ elif iscontextmanager (dep ):
308+ dep .__exit__ (* args ) # noqa: WPS609
309+ elif isasynccontextmanager (dep ):
310+ await dep .__aexit__ (* args ) # noqa: WPS609
302311
303- async def resolver (self , executed_func : Any , initial_cache : Dict [Any , Any ]) -> Any :
312+ async def resolver ( # noqa: C901
313+ self ,
314+ executed_func : Any ,
315+ initial_cache : Dict [Any , Any ],
316+ ) -> Any :
304317 """
305318 Async resolver.
306319
@@ -311,7 +324,7 @@ async def resolver(self, executed_func: Any, initial_cache: Dict[Any, Any]) -> A
311324 :param initial_cache: cache to build a context if graph was passed.
312325 :return: dict with resolved kwargs.
313326 """
314- if getattr (executed_func , "dep_graph" , False ):
327+ if getattr (executed_func , "dep_graph" , False ): # noqa: WPS223
315328 ctx = AsyncResolveContext (executed_func , initial_cache ) # type: ignore
316329 self .sub_contexts .append (ctx )
317330 sub_result = await ctx .resolve_kwargs ()
@@ -323,6 +336,12 @@ async def resolver(self, executed_func: Any, initial_cache: Dict[Any, Any]) -> A
323336 elif inspect .isasyncgen (executed_func ):
324337 sub_result = await executed_func .__anext__ () # noqa: WPS609
325338 self .opened_dependencies .append (executed_func )
339+ elif iscontextmanager (executed_func ):
340+ sub_result = executed_func .__enter__ () # noqa: WPS609
341+ self .opened_dependencies .append (executed_func )
342+ elif isasynccontextmanager (executed_func ):
343+ sub_result = await executed_func .__aenter__ () # noqa: WPS609
344+ self .opened_dependencies .append (executed_func )
326345 else :
327346 sub_result = executed_func
328347 return sub_result
0 commit comments