Skip to content

Commit 4dd8a4e

Browse files
authored
Merge pull request #12 from taskiq-python/improve_jetstream_broker
2 parents d7ac8c9 + a45d8e3 commit 4dd8a4e

8 files changed

Lines changed: 501 additions & 190 deletions

File tree

.github/workflows/test.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
name: Testing taskiq-redis
1+
name: Testing taskiq-nats
22

33
on: pull_request
44

README.md

Lines changed: 58 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ pip install taskiq taskiq-nats
1515

1616
Here's a minimal setup example with a broker and one task.
1717

18+
### Default NATS broker.
1819
```python
1920
import asyncio
2021
from taskiq_nats import NatsBroker, JetStreamBroker
@@ -27,15 +28,47 @@ broker = NatsBroker(
2728
queue="random_queue_name",
2829
)
2930

30-
# Or alternatively you can use a JetStream broker:
31-
broker = JetStreamBroker(
32-
[
31+
32+
@broker.task
33+
async def my_lovely_task():
34+
print("I love taskiq")
35+
36+
37+
async def main():
38+
await broker.startup()
39+
40+
await my_lovely_task.kiq()
41+
42+
await broker.shutdown()
43+
44+
45+
if __name__ == "__main__":
46+
asyncio.run(main())
47+
48+
```
49+
### NATS broker based on JetStream
50+
```python
51+
import asyncio
52+
from taskiq_nats import (
53+
PushBasedJetStreamBroker,
54+
PullBasedJetStreamBroker
55+
)
56+
57+
broker = PushBasedJetStreamBroker(
58+
servers=[
3359
"nats://nats1:4222",
3460
"nats://nats2:4222",
3561
],
36-
queue="random_queue_name",
37-
subject="my-subj",
38-
stream_name="my-stream"
62+
queue="awesome_queue_name",
63+
)
64+
65+
# Or you can use pull based variant
66+
broker = PullBasedJetStreamBroker(
67+
servers=[
68+
"nats://nats1:4222",
69+
"nats://nats2:4222",
70+
],
71+
durable="awesome_durable_consumer_name",
3972
)
4073

4174

@@ -54,7 +87,6 @@ async def main():
5487

5588
if __name__ == "__main__":
5689
asyncio.run(main())
57-
5890
```
5991

6092
## NatsBroker configuration
@@ -68,3 +100,22 @@ Here's the constructor parameters:
68100
* `result_backend` - custom result backend.
69101
* `task_id_generator` - custom function to generate task ids.
70102
* Every other keyword argument will be sent to `nats.connect` function.
103+
104+
## JetStreamBroker configuration
105+
### Common
106+
* `servers` - a single string or a list of strings with nats nodes addresses.
107+
* `subject` - name of the subect that will be used to exchange tasks betwee workers and clients.
108+
* `stream_name` - name of the stream where subjects will be located.
109+
* `queue` - a single string or a list of strings with nats nodes addresses.
110+
* `result_backend` - custom result backend.
111+
* `task_id_generator` - custom function to generate task ids.
112+
* `stream_config` - a config for stream.
113+
* `consumer_config` - a config for consumer.
114+
115+
### PushBasedJetStreamBroker
116+
* `queue` - name of the queue. It's used to share messages between different consumers.
117+
118+
### PullBasedJetStreamBroker
119+
* `durable` - durable name of the consumer. It's used to share messages between different consumers.
120+
* `pull_consume_batch` - maximum number of message that can be fetched each time.
121+
* `pull_consume_timeout` - timeout for messages fetch. If there is no messages, we start fetching messages again.

poetry.lock

Lines changed: 182 additions & 101 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

taskiq_nats/__init__.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,14 @@
55
uses NATS as a message queue.
66
"""
77

8-
from taskiq_nats.broker import JetStreamBroker, NatsBroker
8+
from taskiq_nats.broker import (
9+
NatsBroker,
10+
PullBasedJetStreamBroker,
11+
PushBasedJetStreamBroker,
12+
)
913

10-
__all__ = ["NatsBroker", "JetStreamBroker"]
14+
__all__ = [
15+
"NatsBroker",
16+
"PushBasedJetStreamBroker",
17+
"PullBasedJetStreamBroker",
18+
]

taskiq_nats/broker.py

Lines changed: 137 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,20 @@
1+
import typing
2+
from abc import ABC, abstractmethod
13
from logging import getLogger
2-
from typing import Any, AsyncGenerator, Callable, List, Optional, TypeVar, Union
34

45
from nats.aio.client import Client
6+
from nats.aio.msg import Msg as NatsMessage
7+
from nats.errors import TimeoutError as NatsTimeoutError
58
from nats.js import JetStreamContext
6-
from nats.js.api import StreamConfig
9+
from nats.js.api import ConsumerConfig, StreamConfig
710
from taskiq import AckableMessage, AsyncBroker, AsyncResultBackend, BrokerMessage
811

9-
_T = TypeVar("_T") # noqa: WPS111 (Too short)
12+
_T = typing.TypeVar("_T") # noqa: WPS111 (Too short)
13+
14+
15+
JetStreamConsumerType = typing.TypeVar(
16+
"JetStreamConsumerType",
17+
)
1018

1119

1220
logger = getLogger("taskiq_nats")
@@ -29,12 +37,12 @@ class NatsBroker(AsyncBroker):
2937

3038
def __init__( # noqa: WPS211 (too many args)
3139
self,
32-
servers: Union[str, List[str]],
40+
servers: typing.Union[str, typing.List[str]],
3341
subject: str = "taskiq_tasks",
34-
queue: Optional[str] = None,
35-
result_backend: "Optional[AsyncResultBackend[_T]]" = None,
36-
task_id_generator: Optional[Callable[[], str]] = None,
37-
**connection_kwargs: Any,
42+
queue: typing.Optional[str] = None,
43+
result_backend: "typing.Optional[AsyncResultBackend[_T]]" = None,
44+
task_id_generator: typing.Optional[typing.Callable[[], str]] = None,
45+
**connection_kwargs: typing.Any,
3846
) -> None:
3947
super().__init__(result_backend, task_id_generator)
4048
self.servers = servers
@@ -64,7 +72,7 @@ async def kick(self, message: BrokerMessage) -> None:
6472
headers=message.labels,
6573
)
6674

67-
async def listen(self) -> AsyncGenerator[bytes, None]:
75+
async def listen(self) -> typing.AsyncGenerator[bytes, None]:
6876
"""
6977
Start listen to new messages.
7078
@@ -80,37 +88,55 @@ async def shutdown(self) -> None:
8088
await super().shutdown()
8189

8290

83-
class JetStreamBroker(AsyncBroker): # noqa: WPS230
84-
"""
85-
JetStream broker for taskiq.
91+
class BaseJetStreamBroker( # noqa: WPS230 (too many attrs)
92+
AsyncBroker,
93+
ABC,
94+
typing.Generic[JetStreamConsumerType],
95+
):
96+
"""Base JetStream broker for taskiq.
8697
87-
This broker creates a JetStream context
88-
and uses it to send and receive messages.
98+
It has two subclasses - PullBasedJetStreamBroker
99+
and PushBasedJetStreamBroker.
100+
101+
These brokers create a JetStream context
102+
and use it to send and receive messages.
89103
90104
This is useful for systems where you need to
91105
be sure that messages are delivered to the workers.
92106
"""
93107

94108
def __init__( # noqa: WPS211 (too many args)
95109
self,
96-
servers: Union[str, List[str]],
97-
subject: str = "tasiq_tasks",
98-
stream_name: str = "taskiq_jstream",
99-
queue: Optional[str] = None,
100-
result_backend: "Optional[AsyncResultBackend[_T]]" = None,
101-
task_id_generator: Optional[Callable[[], str]] = None,
102-
stream_config: Optional[StreamConfig] = None,
103-
**connection_kwargs: Any,
110+
servers: typing.Union[str, typing.List[str]],
111+
subject: str = "taskiq_tasks",
112+
stream_name: str = "taskiq_jetstream",
113+
queue: typing.Optional[str] = None,
114+
durable: str = "taskiq_durable",
115+
stream_config: typing.Optional[StreamConfig] = None,
116+
consumer_config: typing.Optional[ConsumerConfig] = None,
117+
pull_consume_batch: int = 1,
118+
pull_consume_timeout: typing.Optional[float] = None,
119+
**connection_kwargs: typing.Any,
104120
) -> None:
105-
super().__init__(result_backend, task_id_generator)
106-
self.servers = servers
107-
self.client: Client = Client()
108-
self.connection_kwargs = connection_kwargs
109-
self.queue = queue
110-
self.subject = subject
111-
self.stream_name = stream_name
121+
super().__init__()
122+
self.servers: typing.Final = servers
123+
self.client: typing.Final = Client()
124+
self.connection_kwargs: typing.Final = connection_kwargs
125+
self.subject: typing.Final = subject
126+
self.stream_name: typing.Final = stream_name
112127
self.js: JetStreamContext
113128
self.stream_config = stream_config or StreamConfig()
129+
self.consumer_config = consumer_config
130+
131+
# Only for push based consumer
132+
self.queue: typing.Final = queue
133+
self.default_consumer_name: typing.Final = "taskiq_consumer"
134+
# Only for pull based consumer
135+
self.durable: typing.Final = durable
136+
self.pull_consume_batch: typing.Final = pull_consume_batch
137+
self.pull_consume_timeout: typing.Final = pull_consume_timeout
138+
139+
self.consumer: JetStreamConsumerType
114140

115141
async def startup(self) -> None:
116142
"""
@@ -127,6 +153,12 @@ async def startup(self) -> None:
127153
if not self.stream_config.subjects:
128154
self.stream_config.subjects = [self.subject]
129155
await self.js.add_stream(config=self.stream_config)
156+
await self._startup_consumer()
157+
158+
async def shutdown(self) -> None:
159+
"""Close connections to NATS."""
160+
await self.client.close()
161+
await super().shutdown()
130162

131163
async def kick(self, message: BrokerMessage) -> None:
132164
"""
@@ -140,20 +172,89 @@ async def kick(self, message: BrokerMessage) -> None:
140172
headers=message.labels,
141173
)
142174

143-
async def listen(self) -> AsyncGenerator[AckableMessage, None]:
175+
@abstractmethod
176+
async def _startup_consumer(self) -> None:
177+
"""Create consumer."""
178+
179+
180+
class PushBasedJetStreamBroker(
181+
BaseJetStreamBroker[JetStreamContext.PushSubscription],
182+
):
183+
"""JetStream broker for push based message consumption.
184+
185+
It's named `push` based because nats server push messages to
186+
the consumer, not consumer requests them.
187+
"""
188+
189+
async def listen(self) -> typing.AsyncGenerator[AckableMessage, None]:
144190
"""
145191
Start listen to new messages.
146192
147193
:yield: incoming messages.
148194
"""
149-
subscribe = await self.js.subscribe(self.subject, queue=self.queue or "")
150-
async for message in subscribe.messages:
195+
async for message in self.consumer.messages:
151196
yield AckableMessage(
152197
data=message.data,
153198
ack=message.ack,
154199
)
155200

156-
async def shutdown(self) -> None:
157-
"""Close connections to NATS."""
158-
await self.client.close()
159-
await super().shutdown()
201+
async def _startup_consumer(self) -> None:
202+
if not self.consumer_config:
203+
self.consumer_config = ConsumerConfig(
204+
name=self.default_consumer_name,
205+
durable_name=self.default_consumer_name,
206+
)
207+
208+
self.consumer = await self.js.subscribe(
209+
subject=self.subject,
210+
queue=self.queue or "",
211+
config=self.consumer_config,
212+
)
213+
214+
215+
class PullBasedJetStreamBroker(
216+
BaseJetStreamBroker[JetStreamContext.PullSubscription],
217+
):
218+
"""JetStream broker for pull based message consumption.
219+
220+
It's named `pull` based because consumer requests messages,
221+
not NATS server sends them.
222+
"""
223+
224+
async def listen(self) -> typing.AsyncGenerator[AckableMessage, None]:
225+
"""
226+
Start listen to new messages.
227+
228+
:yield: incoming messages.
229+
"""
230+
while True: # noqa: WPS327
231+
try:
232+
nats_messages: typing.List[NatsMessage] = await self.consumer.fetch(
233+
batch=self.pull_consume_batch,
234+
timeout=self.pull_consume_timeout,
235+
)
236+
for nats_message in nats_messages:
237+
yield AckableMessage(
238+
data=nats_message.data,
239+
ack=nats_message.ack,
240+
)
241+
except NatsTimeoutError:
242+
continue
243+
244+
async def _startup_consumer(self) -> None:
245+
if not self.consumer_config:
246+
self.consumer_config = ConsumerConfig(
247+
durable_name=self.durable,
248+
)
249+
250+
# We must use this method to create pull based consumer
251+
# because consumer config won't change without it.
252+
await self.js.add_consumer(
253+
stream=self.stream_config.name or self.stream_name,
254+
config=self.consumer_config,
255+
)
256+
self.consumer = await self.js.pull_subscribe(
257+
subject=self.subject,
258+
durable=self.durable,
259+
config=self.consumer_config,
260+
)

tests/conftest.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
import os
22
import uuid
3-
from typing import List
3+
from typing import AsyncGenerator, Final, List
44

55
import pytest
6+
from nats import NATS
7+
from nats.js import JetStreamContext
68

79

810
@pytest.fixture(scope="session")
@@ -38,3 +40,20 @@ def nats_urls() -> List[str]:
3840
"""
3941
urls = os.environ.get("NATS_URLS") or "nats://localhost:4222"
4042
return urls.split(",")
43+
44+
45+
@pytest.fixture()
46+
async def nats_jetstream(
47+
nats_urls: List[str], # noqa: WPS442
48+
) -> AsyncGenerator[JetStreamContext, None]:
49+
"""Create and yield nats client and jetstream instances.
50+
51+
:param nats_urls: urls to nats.
52+
53+
:yields: NATS JetStream.
54+
"""
55+
nats: Final = NATS()
56+
await nats.connect(servers=nats_urls)
57+
jetstream: Final = nats.jetstream()
58+
yield jetstream
59+
await nats.close()

0 commit comments

Comments
 (0)