Skip to content

Commit 1291fda

Browse files
committed
Add support for context managers
1 parent 32dad9c commit 1291fda

6 files changed

Lines changed: 242 additions & 8 deletions

File tree

.flake8

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,8 @@ ignore =
9292
WPS424,
9393
; Found a too complex `f` string
9494
WPS237,
95+
; Found `no cover` comments overuse: 7 > 5
96+
WPS403,
9597

9698
per-file-ignores =
9799
; all tests

taskiq_dependencies/ctx.py

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from logging import getLogger
66
from typing import TYPE_CHECKING, Any, DefaultDict, Dict, Generator, List, Optional
77

8-
from taskiq_dependencies.utils import ParamInfo
8+
from taskiq_dependencies.utils import ParamInfo, isasynccontextmanager, iscontextmanager
99

1010
if TYPE_CHECKING:
1111
from taskiq_dependencies.graph import DependencyGraph # pragma: no cover
@@ -59,7 +59,7 @@ def traverse_deps( # noqa: C901, WPS210
5959
# later.
6060
if not dep.use_cache:
6161
continue
62-
# If somehow we have dependency with unknwon function.
62+
# If somehow we have dependency with unknown function.
6363
if dep.dependency is None:
6464
continue
6565
# If dependency is already calculated.
@@ -180,6 +180,8 @@ def close(self, *args: Any) -> None: # noqa: C901
180180
continue
181181
for _ in dep: # noqa: WPS328
182182
pass # noqa: WPS420
183+
elif iscontextmanager(dep):
184+
dep.__exit__(*args) # noqa: WPS609
183185

184186
def resolver(self, executed_func: Any, initial_cache: Dict[Any, Any]) -> Any:
185187
"""
@@ -194,7 +196,7 @@ def resolver(self, executed_func: Any, initial_cache: Dict[Any, Any]) -> Any:
194196
195197
:return: dict with resolved kwargs.
196198
"""
197-
if getattr(executed_func, "dep_graph", False):
199+
if getattr(executed_func, "dep_graph", False): # noqa: WPS223
198200
ctx = SyncResolveContext(executed_func, initial_cache)
199201
self.sub_contexts.append(ctx)
200202
sub_result = ctx.resolve_kwargs()
@@ -206,7 +208,10 @@ def resolver(self, executed_func: Any, initial_cache: Dict[Any, Any]) -> Any:
206208
"Coroutines cannot be used in sync context. "
207209
"Please use async context instead.",
208210
)
209-
elif inspect.isasyncgen(executed_func):
211+
elif iscontextmanager(executed_func):
212+
sub_result = executed_func.__enter__() # noqa: WPS609
213+
self.opened_dependencies.append(executed_func)
214+
elif inspect.isasyncgen(executed_func) or isasynccontextmanager(executed_func):
210215
raise RuntimeError(
211216
"Coroutines cannot be used in sync context. "
212217
"Please use async context instead.",
@@ -299,8 +304,16 @@ async def close(self, *args: Any) -> None: # noqa: C901
299304
continue
300305
async for _ in dep: # noqa: WPS328
301306
pass # noqa: WPS420
307+
elif iscontextmanager(dep):
308+
dep.__exit__(*args) # noqa: WPS609
309+
elif isasynccontextmanager(dep):
310+
await dep.__aexit__(*args) # noqa: WPS609
302311

303-
async def resolver(self, executed_func: Any, initial_cache: Dict[Any, Any]) -> Any:
312+
async def resolver( # noqa: C901
313+
self,
314+
executed_func: Any,
315+
initial_cache: Dict[Any, Any],
316+
) -> Any:
304317
"""
305318
Async resolver.
306319
@@ -311,7 +324,7 @@ async def resolver(self, executed_func: Any, initial_cache: Dict[Any, Any]) -> A
311324
:param initial_cache: cache to build a context if graph was passed.
312325
:return: dict with resolved kwargs.
313326
"""
314-
if getattr(executed_func, "dep_graph", False):
327+
if getattr(executed_func, "dep_graph", False): # noqa: WPS223
315328
ctx = AsyncResolveContext(executed_func, initial_cache) # type: ignore
316329
self.sub_contexts.append(ctx)
317330
sub_result = await ctx.resolve_kwargs()
@@ -323,6 +336,12 @@ async def resolver(self, executed_func: Any, initial_cache: Dict[Any, Any]) -> A
323336
elif inspect.isasyncgen(executed_func):
324337
sub_result = await executed_func.__anext__() # noqa: WPS609
325338
self.opened_dependencies.append(executed_func)
339+
elif iscontextmanager(executed_func):
340+
sub_result = executed_func.__enter__() # noqa: WPS609
341+
self.opened_dependencies.append(executed_func)
342+
elif isasynccontextmanager(executed_func):
343+
sub_result = await executed_func.__aenter__() # noqa: WPS609
344+
self.opened_dependencies.append(executed_func)
326345
else:
327346
sub_result = executed_func
328347
return sub_result

taskiq_dependencies/dependency.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,10 @@
22
import uuid
33
from typing import ( # noqa: WPS235
44
Any,
5+
AsyncContextManager,
56
AsyncGenerator,
67
Callable,
8+
ContextManager,
79
Coroutine,
810
Dict,
911
Generator,
@@ -17,6 +19,26 @@
1719
_T = TypeVar("_T") # noqa: WPS111
1820

1921

22+
@overload
23+
def Depends( # noqa: WPS234
24+
dependency: Optional[Callable[..., ContextManager[_T]]] = None,
25+
*,
26+
use_cache: bool = True,
27+
kwargs: Optional[Dict[str, Any]] = None,
28+
) -> _T: # pragma: no cover
29+
...
30+
31+
32+
@overload
33+
def Depends( # noqa: WPS234
34+
dependency: Optional[Callable[..., AsyncContextManager[_T]]] = None,
35+
*,
36+
use_cache: bool = True,
37+
kwargs: Optional[Dict[str, Any]] = None,
38+
) -> _T: # pragma: no cover
39+
...
40+
41+
2042
@overload
2143
def Depends( # noqa: WPS234
2244
dependency: Optional[Callable[..., AsyncGenerator[_T, None]]] = None,

taskiq_dependencies/utils.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import inspect
2-
from typing import Optional
2+
from typing import Any, AsyncContextManager, ContextManager, Optional, TypeGuard
33

44

55
class ParamInfo:
@@ -23,3 +23,31 @@ def __init__(
2323

2424
def __repr__(self) -> str:
2525
return f"ParamInfo<name={self.name}>"
26+
27+
28+
def iscontextmanager(obj: Any) -> TypeGuard[ContextManager[Any]]:
29+
"""
30+
Return true if the object is a sync context manager.
31+
32+
:param obj: object to check.
33+
:return: bool that indicates whether the object is a context manager or not.
34+
"""
35+
if not hasattr(obj, "__enter__"): # noqa: WPS421
36+
return False
37+
elif not hasattr(obj, "__exit__"): # noqa: WPS421
38+
return False
39+
return True
40+
41+
42+
def isasynccontextmanager(obj: Any) -> TypeGuard[AsyncContextManager[Any]]:
43+
"""
44+
Return true if the object is a async context manager.
45+
46+
:param obj: object to check.
47+
:return: bool that indicates whether the object is a async context manager or not.
48+
"""
49+
if not hasattr(obj, "__aenter__"): # noqa: WPS421
50+
return False
51+
elif not hasattr(obj, "__aexit__"): # noqa: WPS421
52+
return False
53+
return True

tests/test_annotated.py

Lines changed: 93 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
import sys
2+
from contextlib import asynccontextmanager, contextmanager
23

34
import pytest
45

56
if sys.version_info < (3, 10):
67
pytest.skip("Annotated is available only for python 3.10+", allow_module_level=True)
78

8-
from typing import Annotated, AsyncGenerator, Generic, Tuple, TypeVar
9+
from typing import Annotated, AsyncGenerator, Generator, Generic, Tuple, TypeVar
910

1011
from taskiq_dependencies import DependencyGraph, Depends
1112

@@ -53,6 +54,38 @@ def test_func(a: Annotated[MainClass[MyClass], Depends()]) -> MyClass:
5354
assert isinstance(value, MyClass)
5455

5556

57+
@pytest.mark.anyio
58+
async def test_annotated_gen() -> None:
59+
opened = False
60+
closed = False
61+
62+
def my_gen() -> Generator[int, None, None]:
63+
nonlocal opened, closed
64+
opened = True
65+
66+
yield 1
67+
68+
closed = True
69+
70+
def test_func(dep: Annotated[int, Depends(my_gen)]) -> int:
71+
return dep
72+
73+
with DependencyGraph(target=test_func).sync_ctx() as sctx:
74+
value = test_func(**sctx.resolve_kwargs())
75+
assert value == 1
76+
77+
assert opened and closed
78+
79+
opened = False
80+
closed = False
81+
82+
async with DependencyGraph(target=test_func).async_ctx() as actx:
83+
value = test_func(**(await actx.resolve_kwargs()))
84+
assert value == 1
85+
86+
assert opened and closed
87+
88+
5689
@pytest.mark.anyio
5790
async def test_annotated_asyncgen() -> None:
5891
opened = False
@@ -76,6 +109,65 @@ def test_func(dep: Annotated[int, Depends(my_gen)]) -> int:
76109
assert opened and closed
77110

78111

112+
@pytest.mark.anyio
113+
async def test_annotated_manager() -> None:
114+
opened = False
115+
closed = False
116+
117+
@contextmanager
118+
def my_gen() -> Generator[int, None, None]:
119+
nonlocal opened, closed
120+
opened = True
121+
122+
try:
123+
yield 1
124+
finally:
125+
closed = True
126+
127+
def test_func(dep: Annotated[int, Depends(my_gen)]) -> int:
128+
return dep
129+
130+
with DependencyGraph(target=test_func).sync_ctx() as sctx:
131+
value = test_func(**sctx.resolve_kwargs())
132+
assert value == 1
133+
134+
assert opened and closed
135+
136+
opened = False
137+
closed = False
138+
139+
async with DependencyGraph(target=test_func).async_ctx() as actx:
140+
value = test_func(**(await actx.resolve_kwargs()))
141+
assert value == 1
142+
143+
assert opened and closed
144+
145+
146+
@pytest.mark.anyio
147+
async def test_annotated_asyncmanager() -> None:
148+
opened = False
149+
closed = False
150+
151+
@asynccontextmanager
152+
async def my_gen() -> AsyncGenerator[int, None]:
153+
nonlocal opened, closed
154+
opened = True
155+
156+
try:
157+
yield 1
158+
finally:
159+
closed = True
160+
161+
def test_func(dep: Annotated[int, Depends(my_gen)]) -> int:
162+
return dep
163+
164+
async with DependencyGraph(target=test_func).async_ctx() as g:
165+
value = test_func(**(await g.resolve_kwargs()))
166+
assert value == 1
167+
168+
assert opened and closed
169+
170+
79171
def test_multiple() -> None:
80172
class TestClass:
81173
pass

tests/test_graph.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import asyncio
22
import re
33
import uuid
4+
from contextlib import asynccontextmanager, contextmanager
45
from typing import Any, AsyncGenerator, Generator, Generic, Tuple, TypeVar
56

67
import pytest
@@ -111,6 +112,76 @@ def testfunc(a: int = Depends(dep1)) -> int:
111112
assert closes == 1
112113

113114

115+
@pytest.mark.anyio
116+
async def test_dependency_contextmanager_successful() -> None:
117+
"""Tests that contextmanagers work as expected."""
118+
starts = 0
119+
closes = 0
120+
121+
@contextmanager
122+
def dep1() -> Generator[int, None, None]:
123+
nonlocal starts # noqa: WPS420
124+
nonlocal closes # noqa: WPS420
125+
126+
starts += 1
127+
128+
try:
129+
yield 1
130+
finally:
131+
closes += 1
132+
133+
def testfunc(a: int = Depends(dep1)) -> int:
134+
return a
135+
136+
with DependencyGraph(testfunc).sync_ctx({}) as sctx:
137+
assert sctx.resolve_kwargs() == {"a": 1}
138+
assert starts == 1
139+
assert closes == 0
140+
starts = 0
141+
assert closes == 1
142+
closes = 0
143+
144+
async with DependencyGraph(testfunc).async_ctx({}) as actx:
145+
assert await actx.resolve_kwargs() == {"a": 1}
146+
assert starts == 1
147+
assert closes == 0
148+
assert closes == 1
149+
150+
151+
@pytest.mark.anyio
152+
async def test_dependency_async_manager_successful() -> None:
153+
"""This test checks that async contextmanagers work."""
154+
starts = 0
155+
closes = 0
156+
157+
@asynccontextmanager
158+
async def dep1() -> AsyncGenerator[int, None]:
159+
nonlocal starts # noqa: WPS420
160+
nonlocal closes # noqa: WPS420
161+
162+
await asyncio.sleep(0.001)
163+
starts += 1
164+
165+
try:
166+
yield 1
167+
finally:
168+
await asyncio.sleep(0.001)
169+
closes += 1
170+
171+
def testfunc(a: int = Depends(dep1)) -> int:
172+
return a
173+
174+
with DependencyGraph(testfunc).sync_ctx({}) as sctx:
175+
with pytest.raises(RuntimeError):
176+
assert sctx.resolve_kwargs() == {"a": 1}
177+
178+
async with DependencyGraph(testfunc).async_ctx({}) as actx:
179+
assert await actx.resolve_kwargs() == {"a": 1}
180+
assert starts == 1
181+
assert closes == 0
182+
assert closes == 1
183+
184+
114185
@pytest.mark.anyio
115186
async def test_dependency_subdeps() -> None:
116187
"""Tests how subdependencies work."""

0 commit comments

Comments
 (0)