-
Notifications
You must be signed in to change notification settings - Fork 14
Expand file tree
/
Copy pathtest_faststream_di_pass_message.py
More file actions
61 lines (44 loc) · 1.83 KB
/
test_faststream_di_pass_message.py
File metadata and controls
61 lines (44 loc) · 1.83 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import typing
from faststream import BaseMiddleware, Context, Depends
from faststream.nats import NatsBroker, TestNatsBroker
from faststream.nats.message import NatsMessage
from packaging.version import Version
from that_depends import BaseContainer, container_context, fetch_context_item, providers
from that_depends.integrations.faststream import _FASTSTREAM_VERSION
# Pick the import location of ``StreamMessage`` based on the installed
# FastStream version: from 0.6.0 on it is imported from ``faststream.message``,
# older versions import it from ``faststream.broker.message``.
if Version(_FASTSTREAM_VERSION) >= Version("0.6.0"):  # pragma: no cover
    from faststream.message import StreamMessage
else:  # pragma: no cover
    # Legacy (<0.6.0) location. The ignore codes presumably silence the type
    # checker when this module is missing in newer FastStream and when the
    # name is bound a second time across the two branches.
    from faststream.broker.message import StreamMessage  # type: ignore[import-not-found, no-redef]
class ContextMiddleware(BaseMiddleware):
    """Middleware that publishes each consumed message into the DI context.

    While a message is being handled, the raw ``StreamMessage`` is available
    to ``that_depends`` providers under the global context key ``"request"``.
    """

    async def consume_scope(
        self,
        call_next: typing.Callable[..., typing.Awaitable[typing.Any]],
        msg: StreamMessage[typing.Any],
    ) -> typing.Any:  # noqa: ANN401
        """Run the downstream consumer inside a ``container_context``.

        Args:
            call_next: The next callable in the consume chain.
            msg: The incoming message; stored as the ``"request"`` context item.

        Returns:
            Whatever ``call_next`` returns.
        """
        scoped_items = {"request": msg}
        # The context stays open for the full duration of the downstream call,
        # so providers resolved during handling can see the message.
        async with container_context(global_context=scoped_items):
            outcome = await call_next(msg)
        return outcome
# Build the NATS broker with the DI-context middleware attached. Only the
# pre-0.6.0 branch passes ``validate=False``; the ``call-arg`` ignore suggests
# the type stubs of newer FastStream no longer declare that parameter.
if Version(_FASTSTREAM_VERSION) < Version("0.6.0"):  # pragma: no cover
    broker = NatsBroker(middlewares=(ContextMiddleware,), validate=False)  # type: ignore[call-arg]
else:  # pragma: no cover
    broker = NatsBroker(middlewares=(ContextMiddleware,))

# Subject shared by the subscriber and the request issued in the test.
TEST_SUBJECT = "test"
class DIContainer(BaseContainer):
    """Container exposing the currently consumed message as a dependency."""

    # Every resolution calls ``fetch_context_item("request")``, reading the
    # value that ``ContextMiddleware`` stored in the global context.
    context_request = providers.Factory(fetch_context_item, "request")
@broker.subscriber(TEST_SUBJECT)
async def index_subscriber(
    context_request: typing.Annotated[
        NatsMessage,
        Depends(DIContainer.context_request, cast=False),
    ],
    message: typing.Annotated[
        NatsMessage,
        Context(),
    ],
) -> bool:
    """Report whether DI and FastStream hand the handler the same message.

    Args:
        context_request: Message resolved through ``DIContainer.context_request``
            (``cast=False`` presumably disables casting of the resolved value).
        message: Message injected via ``Context()``. No explicit key is given,
            so FastStream presumably uses the parameter name as the key; do
            not rename it.

    Returns:
        ``True`` only when both parameters reference the very same object.
    """
    same_instance = context_request is message
    return same_instance
async def test_read_main() -> None:
    """Send a request through ``TestNatsBroker`` and check the reply.

    The subscriber answers ``True`` only if the DI-provided message and the
    FastStream context message are the identical object.
    """
    async with TestNatsBroker(broker) as test_broker:
        reply = await test_broker.request(None, TEST_SUBJECT)
        decoded_reply = await reply.decode()
        assert decoded_reply is True