|
2 | 2 | from typing import Any, AsyncGenerator, Callable, List, Optional, TypeVar, Union |
3 | 3 |
|
4 | 4 | from nats.aio.client import Client |
| 5 | +from nats.js import JetStreamContext |
| 6 | +from nats.js.api import StreamConfig |
5 | 7 | from taskiq import AsyncBroker, AsyncResultBackend, BrokerMessage |
6 | 8 |
|
7 | 9 | _T = TypeVar("_T") # noqa: WPS111 (Too short) |
@@ -76,3 +78,80 @@ async def shutdown(self) -> None: |
76 | 78 | """Close connections to NATS.""" |
77 | 79 | await self.client.close() |
78 | 80 | await super().shutdown() |
| 81 | + |
| 82 | + |
| 83 | +class JetStreamBroker(AsyncBroker): # noqa: WPS230 |
| 84 | + """ |
| 85 | + JetStream broker for taskiq. |
| 86 | +
|
| 87 | + This broker creates a JetStream context |
| 88 | + and uses it to send and receive messages. |
| 89 | +
|
| 90 | + This is useful for systems where you need to |
| 91 | + be sure that messages are delivered to the workers. |
| 92 | + """ |
| 93 | + |
| 94 | + def __init__( # noqa: WPS211 (too many args) |
| 95 | + self, |
| 96 | + servers: Union[str, List[str]], |
| 97 | + subject: str = "tasiq_tasks", |
| 98 | + stream_name: str = "taskiq_jstream", |
| 99 | + queue: Optional[str] = None, |
| 100 | + result_backend: "Optional[AsyncResultBackend[_T]]" = None, |
| 101 | + task_id_generator: Optional[Callable[[], str]] = None, |
| 102 | + stream_config: Optional[StreamConfig] = None, |
| 103 | + **connection_kwargs: Any, |
| 104 | + ) -> None: |
| 105 | + super().__init__(result_backend, task_id_generator) |
| 106 | + self.servers = servers |
| 107 | + self.client: Client = Client() |
| 108 | + self.connection_kwargs = connection_kwargs |
| 109 | + self.queue = queue |
| 110 | + self.subject = subject |
| 111 | + self.stream_name = stream_name |
| 112 | + self.js: JetStreamContext |
| 113 | + self.stream_config = stream_config or StreamConfig() |
| 114 | + |
| 115 | + async def startup(self) -> None: |
| 116 | + """ |
| 117 | + Startup event handler. |
| 118 | +
|
| 119 | + It simply connects to NATS cluster, and |
| 120 | + setup JetStream. |
| 121 | + """ |
| 122 | + await super().startup() |
| 123 | + await self.client.connect(self.servers, **self.connection_kwargs) |
| 124 | + self.js = self.client.jetstream() |
| 125 | + if self.stream_config.name is None: |
| 126 | + self.stream_config.name = self.stream_name |
| 127 | + if not self.stream_config.subjects: |
| 128 | + self.stream_config.subjects = [self.subject] |
| 129 | + await self.js.add_stream(config=self.stream_config) |
| 130 | + |
| 131 | + async def kick(self, message: BrokerMessage) -> None: |
| 132 | + """ |
| 133 | + Send a message using NATS. |
| 134 | +
|
| 135 | + :param message: message to send. |
| 136 | + """ |
| 137 | + await self.js.publish( |
| 138 | + self.subject, |
| 139 | + payload=message.message, |
| 140 | + headers=message.labels, |
| 141 | + ) |
| 142 | + |
| 143 | + async def listen(self) -> AsyncGenerator[bytes, None]: |
| 144 | + """ |
| 145 | + Start listen to new messages. |
| 146 | +
|
| 147 | + :yield: incoming messages. |
| 148 | + """ |
| 149 | + subscribe = await self.js.subscribe(self.subject, queue=self.queue or "") |
| 150 | + async for message in subscribe.messages: |
| 151 | + yield message.data |
| 152 | + await message.ack() |
| 153 | + |
| 154 | + async def shutdown(self) -> None: |
| 155 | + """Close connections to NATS.""" |
| 156 | + await self.client.close() |
| 157 | + await super().shutdown() |
0 commit comments