Skip to content

Commit c4dda5b

Browse files
authored
Merge pull request #25 from cofob/feature-support-contextmanagers
2 parents cafad33 + 6b85e2b commit c4dda5b

8 files changed

Lines changed: 256 additions & 21 deletions

File tree

.flake8

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,8 @@ ignore =
9292
WPS424,
9393
; Found a too complex `f` string
9494
WPS237,
95+
; Found `no cover` comments overuse
96+
WPS403,
9597

9698
per-file-ignores =
9799
; all tests

poetry.lock

Lines changed: 7 additions & 7 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ version = "1.5.0"
44
description = "FastAPI like dependency injection implementation"
55
authors = ["Pavel Kirilin <win10@list.ru>"]
66
readme = "README.md"
7-
packages = [{include = "taskiq_dependencies"}]
7+
packages = [{ include = "taskiq_dependencies" }]
88
classifiers = [
99
"Typing :: Typed",
1010
"Programming Language :: Python",
@@ -21,7 +21,8 @@ keywords = ["taskiq", "dependencies", "injection", "async", "DI"]
2121

2222
[tool.poetry.dependencies]
2323
python = "^3.8.1"
24-
graphlib-backport = { version = "^1.0.3", python="<3.9" }
24+
graphlib-backport = { version = "^1.0.3", python = "<3.9" }
25+
typing-extensions = { version = "^4.6.3", python = "<3.10" }
2526

2627
[tool.poetry.group.dev.dependencies]
2728
pytest = "^7.1.3"

taskiq_dependencies/ctx.py

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from logging import getLogger
66
from typing import TYPE_CHECKING, Any, DefaultDict, Dict, Generator, List, Optional
77

8-
from taskiq_dependencies.utils import ParamInfo
8+
from taskiq_dependencies.utils import ParamInfo, isasynccontextmanager, iscontextmanager
99

1010
if TYPE_CHECKING:
1111
from taskiq_dependencies.graph import DependencyGraph # pragma: no cover
@@ -59,7 +59,7 @@ def traverse_deps( # noqa: C901, WPS210
5959
# later.
6060
if not dep.use_cache:
6161
continue
62-
# If somehow we have dependency with unknwon function.
62+
# If somehow we have dependency with unknown function.
6363
if dep.dependency is None:
6464
continue
6565
# If dependency is already calculated.
@@ -180,6 +180,8 @@ def close(self, *args: Any) -> None: # noqa: C901
180180
continue
181181
for _ in dep: # noqa: WPS328
182182
pass # noqa: WPS420
183+
elif iscontextmanager(dep):
184+
dep.__exit__(*args) # noqa: WPS609
183185

184186
def resolver(self, executed_func: Any, initial_cache: Dict[Any, Any]) -> Any:
185187
"""
@@ -194,7 +196,7 @@ def resolver(self, executed_func: Any, initial_cache: Dict[Any, Any]) -> Any:
194196
195197
:return: dict with resolved kwargs.
196198
"""
197-
if getattr(executed_func, "dep_graph", False):
199+
if getattr(executed_func, "dep_graph", False): # noqa: WPS223
198200
ctx = SyncResolveContext(executed_func, initial_cache)
199201
self.sub_contexts.append(ctx)
200202
sub_result = ctx.resolve_kwargs()
@@ -206,7 +208,10 @@ def resolver(self, executed_func: Any, initial_cache: Dict[Any, Any]) -> Any:
206208
"Coroutines cannot be used in sync context. "
207209
"Please use async context instead.",
208210
)
209-
elif inspect.isasyncgen(executed_func):
211+
elif iscontextmanager(executed_func):
212+
sub_result = executed_func.__enter__() # noqa: WPS609
213+
self.opened_dependencies.append(executed_func)
214+
elif inspect.isasyncgen(executed_func) or isasynccontextmanager(executed_func):
210215
raise RuntimeError(
211216
"Coroutines cannot be used in sync context. "
212217
"Please use async context instead.",
@@ -299,8 +304,16 @@ async def close(self, *args: Any) -> None: # noqa: C901
299304
continue
300305
async for _ in dep: # noqa: WPS328
301306
pass # noqa: WPS420
307+
elif iscontextmanager(dep):
308+
dep.__exit__(*args) # noqa: WPS609
309+
elif isasynccontextmanager(dep):
310+
await dep.__aexit__(*args) # noqa: WPS609
302311

303-
async def resolver(self, executed_func: Any, initial_cache: Dict[Any, Any]) -> Any:
312+
async def resolver( # noqa: C901
313+
self,
314+
executed_func: Any,
315+
initial_cache: Dict[Any, Any],
316+
) -> Any:
304317
"""
305318
Async resolver.
306319
@@ -311,7 +324,7 @@ async def resolver(self, executed_func: Any, initial_cache: Dict[Any, Any]) -> A
311324
:param initial_cache: cache to build a context if graph was passed.
312325
:return: dict with resolved kwargs.
313326
"""
314-
if getattr(executed_func, "dep_graph", False):
327+
if getattr(executed_func, "dep_graph", False): # noqa: WPS223
315328
ctx = AsyncResolveContext(executed_func, initial_cache) # type: ignore
316329
self.sub_contexts.append(ctx)
317330
sub_result = await ctx.resolve_kwargs()
@@ -323,6 +336,12 @@ async def resolver(self, executed_func: Any, initial_cache: Dict[Any, Any]) -> A
323336
elif inspect.isasyncgen(executed_func):
324337
sub_result = await executed_func.__anext__() # noqa: WPS609
325338
self.opened_dependencies.append(executed_func)
339+
elif iscontextmanager(executed_func):
340+
sub_result = executed_func.__enter__() # noqa: WPS609
341+
self.opened_dependencies.append(executed_func)
342+
elif isasynccontextmanager(executed_func):
343+
sub_result = await executed_func.__aenter__() # noqa: WPS609
344+
self.opened_dependencies.append(executed_func)
326345
else:
327346
sub_result = executed_func
328347
return sub_result

taskiq_dependencies/dependency.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,10 @@
22
import uuid
33
from typing import ( # noqa: WPS235
44
Any,
5+
AsyncContextManager,
56
AsyncGenerator,
67
Callable,
8+
ContextManager,
79
Coroutine,
810
Dict,
911
Generator,
@@ -17,6 +19,26 @@
1719
_T = TypeVar("_T") # noqa: WPS111
1820

1921

22+
@overload
23+
def Depends( # noqa: WPS234
24+
dependency: Optional[Callable[..., ContextManager[_T]]] = None,
25+
*,
26+
use_cache: bool = True,
27+
kwargs: Optional[Dict[str, Any]] = None,
28+
) -> _T: # pragma: no cover
29+
...
30+
31+
32+
@overload
33+
def Depends( # noqa: WPS234
34+
dependency: Optional[Callable[..., AsyncContextManager[_T]]] = None,
35+
*,
36+
use_cache: bool = True,
37+
kwargs: Optional[Dict[str, Any]] = None,
38+
) -> _T: # pragma: no cover
39+
...
40+
41+
2042
@overload
2143
def Depends( # noqa: WPS234
2244
dependency: Optional[Callable[..., AsyncGenerator[_T, None]]] = None,

taskiq_dependencies/utils.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,11 @@
11
import inspect
2-
from typing import Optional
2+
import sys
3+
from typing import Any, AsyncContextManager, ContextManager, Optional
4+
5+
if sys.version_info >= (3, 10):
6+
from typing import TypeGuard # noqa: WPS433
7+
else:
8+
from typing_extensions import TypeGuard # noqa: WPS433
39

410

511
class ParamInfo:
@@ -23,3 +29,31 @@ def __init__(
2329

2430
def __repr__(self) -> str:
2531
return f"ParamInfo<name={self.name}>"
32+
33+
34+
def iscontextmanager(obj: Any) -> TypeGuard[ContextManager[Any]]:
35+
"""
36+
Return true if the object is a sync context manager.
37+
38+
:param obj: object to check.
39+
:return: bool that indicates whether the object is a context manager or not.
40+
"""
41+
if not hasattr(obj, "__enter__"): # noqa: WPS421
42+
return False
43+
elif not hasattr(obj, "__exit__"): # noqa: WPS421
44+
return False
45+
return True
46+
47+
48+
def isasynccontextmanager(obj: Any) -> TypeGuard[AsyncContextManager[Any]]:
49+
"""
50+
Return true if the object is a async context manager.
51+
52+
:param obj: object to check.
53+
:return: bool that indicates whether the object is a async context manager or not.
54+
"""
55+
if not hasattr(obj, "__aenter__"): # noqa: WPS421
56+
return False
57+
elif not hasattr(obj, "__aexit__"): # noqa: WPS421
58+
return False
59+
return True

tests/test_annotated.py

Lines changed: 93 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
import sys
2+
from contextlib import asynccontextmanager, contextmanager
23

34
import pytest
45

56
if sys.version_info < (3, 10):
67
pytest.skip("Annotated is available only for python 3.10+", allow_module_level=True)
78

8-
from typing import Annotated, AsyncGenerator, Generic, Tuple, TypeVar
9+
from typing import Annotated, AsyncGenerator, Generator, Generic, Tuple, TypeVar
910

1011
from taskiq_dependencies import DependencyGraph, Depends
1112

@@ -53,6 +54,38 @@ def test_func(a: Annotated[MainClass[MyClass], Depends()]) -> MyClass:
5354
assert isinstance(value, MyClass)
5455

5556

57+
@pytest.mark.anyio
58+
async def test_annotated_gen() -> None:
59+
opened = False
60+
closed = False
61+
62+
def my_gen() -> Generator[int, None, None]:
63+
nonlocal opened, closed
64+
opened = True
65+
66+
yield 1
67+
68+
closed = True
69+
70+
def test_func(dep: Annotated[int, Depends(my_gen)]) -> int:
71+
return dep
72+
73+
with DependencyGraph(target=test_func).sync_ctx() as sctx:
74+
value = test_func(**sctx.resolve_kwargs())
75+
assert value == 1
76+
77+
assert opened and closed
78+
79+
opened = False
80+
closed = False
81+
82+
async with DependencyGraph(target=test_func).async_ctx() as actx:
83+
value = test_func(**(await actx.resolve_kwargs()))
84+
assert value == 1
85+
86+
assert opened and closed
87+
88+
5689
@pytest.mark.anyio
5790
async def test_annotated_asyncgen() -> None:
5891
opened = False
@@ -76,6 +109,65 @@ def test_func(dep: Annotated[int, Depends(my_gen)]) -> int:
76109
assert opened and closed
77110

78111

112+
@pytest.mark.anyio
113+
async def test_annotated_manager() -> None:
114+
opened = False
115+
closed = False
116+
117+
@contextmanager
118+
def my_gen() -> Generator[int, None, None]:
119+
nonlocal opened, closed
120+
opened = True
121+
122+
try:
123+
yield 1
124+
finally:
125+
closed = True
126+
127+
def test_func(dep: Annotated[int, Depends(my_gen)]) -> int:
128+
return dep
129+
130+
with DependencyGraph(target=test_func).sync_ctx() as sctx:
131+
value = test_func(**sctx.resolve_kwargs())
132+
assert value == 1
133+
134+
assert opened and closed
135+
136+
opened = False
137+
closed = False
138+
139+
async with DependencyGraph(target=test_func).async_ctx() as actx:
140+
value = test_func(**(await actx.resolve_kwargs()))
141+
assert value == 1
142+
143+
assert opened and closed
144+
145+
146+
@pytest.mark.anyio
147+
async def test_annotated_asyncmanager() -> None:
148+
opened = False
149+
closed = False
150+
151+
@asynccontextmanager
152+
async def my_gen() -> AsyncGenerator[int, None]:
153+
nonlocal opened, closed
154+
opened = True
155+
156+
try:
157+
yield 1
158+
finally:
159+
closed = True
160+
161+
def test_func(dep: Annotated[int, Depends(my_gen)]) -> int:
162+
return dep
163+
164+
async with DependencyGraph(target=test_func).async_ctx() as g:
165+
value = test_func(**(await g.resolve_kwargs()))
166+
assert value == 1
167+
168+
assert opened and closed
169+
170+
79171
def test_multiple() -> None:
80172
class TestClass:
81173
pass

0 commit comments

Comments
 (0)