1+ import typing
2+ from abc import ABC , abstractmethod
13from logging import getLogger
2- from typing import Any , AsyncGenerator , Callable , List , Optional , TypeVar , Union
34
45from nats .aio .client import Client
6+ from nats .aio .msg import Msg as NatsMessage
7+ from nats .errors import TimeoutError as NatsTimeoutError
58from nats .js import JetStreamContext
6- from nats .js .api import StreamConfig
9+ from nats .js .api import ConsumerConfig , StreamConfig
710from taskiq import AckableMessage , AsyncBroker , AsyncResultBackend , BrokerMessage
811
9- _T = TypeVar ("_T" ) # noqa: WPS111 (Too short)
12+ _T = typing .TypeVar ("_T" ) # noqa: WPS111 (Too short)
13+
14+
15+ JetStreamConsumerType = typing .TypeVar (
16+ "JetStreamConsumerType" ,
17+ )
1018
1119
1220logger = getLogger ("taskiq_nats" )
@@ -29,12 +37,12 @@ class NatsBroker(AsyncBroker):
2937
3038 def __init__ ( # noqa: WPS211 (too many args)
3139 self ,
32- servers : Union [str , List [str ]],
40+ servers : typing . Union [str , typing . List [str ]],
3341 subject : str = "taskiq_tasks" ,
34- queue : Optional [str ] = None ,
35- result_backend : "Optional[AsyncResultBackend[_T]]" = None ,
36- task_id_generator : Optional [Callable [[], str ]] = None ,
37- ** connection_kwargs : Any ,
42+ queue : typing . Optional [str ] = None ,
43+ result_backend : "typing. Optional[AsyncResultBackend[_T]]" = None ,
44+ task_id_generator : typing . Optional [typing . Callable [[], str ]] = None ,
45+ ** connection_kwargs : typing . Any ,
3846 ) -> None :
3947 super ().__init__ (result_backend , task_id_generator )
4048 self .servers = servers
@@ -64,7 +72,7 @@ async def kick(self, message: BrokerMessage) -> None:
6472 headers = message .labels ,
6573 )
6674
67- async def listen (self ) -> AsyncGenerator [bytes , None ]:
75+ async def listen (self ) -> typing . AsyncGenerator [bytes , None ]:
6876 """
6977 Start listen to new messages.
7078
@@ -80,37 +88,55 @@ async def shutdown(self) -> None:
8088 await super ().shutdown ()
8189
8290
83- class JetStreamBroker (AsyncBroker ): # noqa: WPS230
84- """
85- JetStream broker for taskiq.
91+ class BaseJetStreamBroker ( # noqa: WPS230 (too many attrs)
92+ AsyncBroker ,
93+ ABC ,
94+ typing .Generic [JetStreamConsumerType ],
95+ ):
96+ """Base JetStream broker for taskiq.
8697
87- This broker creates a JetStream context
88- and uses it to send and receive messages.
98+ It has two subclasses - PullBasedJetStreamBroker
99+ and PushBasedJetStreamBroker.
100+
101+ These brokers create a JetStream context
102+ and use it to send and receive messages.
89103
90104 This is useful for systems where you need to
91105 be sure that messages are delivered to the workers.
92106 """
93107
94108 def __init__ ( # noqa: WPS211 (too many args)
95109 self ,
96- servers : Union [str , List [str ]],
97- subject : str = "tasiq_tasks" ,
98- stream_name : str = "taskiq_jstream" ,
99- queue : Optional [str ] = None ,
100- result_backend : "Optional[AsyncResultBackend[_T]]" = None ,
101- task_id_generator : Optional [Callable [[], str ]] = None ,
102- stream_config : Optional [StreamConfig ] = None ,
103- ** connection_kwargs : Any ,
110+ servers : typing .Union [str , typing .List [str ]],
111+ subject : str = "taskiq_tasks" ,
112+ stream_name : str = "taskiq_jetstream" ,
113+ queue : typing .Optional [str ] = None ,
114+ durable : str = "taskiq_durable" ,
115+ stream_config : typing .Optional [StreamConfig ] = None ,
116+ consumer_config : typing .Optional [ConsumerConfig ] = None ,
117+ pull_consume_batch : int = 1 ,
118+ pull_consume_timeout : typing .Optional [float ] = None ,
119+ ** connection_kwargs : typing .Any ,
104120 ) -> None :
105- super ().__init__ (result_backend , task_id_generator )
106- self .servers = servers
107- self .client : Client = Client ()
108- self .connection_kwargs = connection_kwargs
109- self .queue = queue
110- self .subject = subject
111- self .stream_name = stream_name
121+ super ().__init__ ()
122+ self .servers : typing .Final = servers
123+ self .client : typing .Final = Client ()
124+ self .connection_kwargs : typing .Final = connection_kwargs
125+ self .subject : typing .Final = subject
126+ self .stream_name : typing .Final = stream_name
112127 self .js : JetStreamContext
113128 self .stream_config = stream_config or StreamConfig ()
129+ self .consumer_config = consumer_config
130+
131+ # Only for push based consumer
132+ self .queue : typing .Final = queue
133+ self .default_consumer_name : typing .Final = "taskiq_consumer"
134+ # Only for pull based consumer
135+ self .durable : typing .Final = durable
136+ self .pull_consume_batch : typing .Final = pull_consume_batch
137+ self .pull_consume_timeout : typing .Final = pull_consume_timeout
138+
139+ self .consumer : JetStreamConsumerType
114140
115141 async def startup (self ) -> None :
116142 """
@@ -127,6 +153,12 @@ async def startup(self) -> None:
127153 if not self .stream_config .subjects :
128154 self .stream_config .subjects = [self .subject ]
129155 await self .js .add_stream (config = self .stream_config )
156+ await self ._startup_consumer ()
157+
158+ async def shutdown (self ) -> None :
159+ """Close connections to NATS."""
160+ await self .client .close ()
161+ await super ().shutdown ()
130162
131163 async def kick (self , message : BrokerMessage ) -> None :
132164 """
@@ -140,20 +172,89 @@ async def kick(self, message: BrokerMessage) -> None:
140172 headers = message .labels ,
141173 )
142174
143- async def listen (self ) -> AsyncGenerator [AckableMessage , None ]:
175+ @abstractmethod
176+ async def _startup_consumer (self ) -> None :
177+ """Create consumer."""
178+
179+
180+ class PushBasedJetStreamBroker (
181+ BaseJetStreamBroker [JetStreamContext .PushSubscription ],
182+ ):
183+ """JetStream broker for push based message consumption.
184+
185+ It's named `push` based because nats server push messages to
186+ the consumer, not consumer requests them.
187+ """
188+
189+ async def listen (self ) -> typing .AsyncGenerator [AckableMessage , None ]:
144190 """
145191 Start listen to new messages.
146192
147193 :yield: incoming messages.
148194 """
149- subscribe = await self .js .subscribe (self .subject , queue = self .queue or "" )
150- async for message in subscribe .messages :
195+ async for message in self .consumer .messages :
151196 yield AckableMessage (
152197 data = message .data ,
153198 ack = message .ack ,
154199 )
155200
156- async def shutdown (self ) -> None :
157- """Close connections to NATS."""
158- await self .client .close ()
159- await super ().shutdown ()
201+ async def _startup_consumer (self ) -> None :
202+ if not self .consumer_config :
203+ self .consumer_config = ConsumerConfig (
204+ name = self .default_consumer_name ,
205+ durable_name = self .default_consumer_name ,
206+ )
207+
208+ self .consumer = await self .js .subscribe (
209+ subject = self .subject ,
210+ queue = self .queue or "" ,
211+ config = self .consumer_config ,
212+ )
213+
214+
215+ class PullBasedJetStreamBroker (
216+ BaseJetStreamBroker [JetStreamContext .PullSubscription ],
217+ ):
218+ """JetStream broker for pull based message consumption.
219+
220+ It's named `pull` based because consumer requests messages,
221+ not NATS server sends them.
222+ """
223+
224+ async def listen (self ) -> typing .AsyncGenerator [AckableMessage , None ]:
225+ """
226+ Start listen to new messages.
227+
228+ :yield: incoming messages.
229+ """
230+ while True : # noqa: WPS327
231+ try :
232+ nats_messages : typing .List [NatsMessage ] = await self .consumer .fetch (
233+ batch = self .pull_consume_batch ,
234+ timeout = self .pull_consume_timeout ,
235+ )
236+ for nats_message in nats_messages :
237+ yield AckableMessage (
238+ data = nats_message .data ,
239+ ack = nats_message .ack ,
240+ )
241+ except NatsTimeoutError :
242+ continue
243+
244+ async def _startup_consumer (self ) -> None :
245+ if not self .consumer_config :
246+ self .consumer_config = ConsumerConfig (
247+ durable_name = self .durable ,
248+ )
249+
250+ # We must use this method to create pull based consumer
251+ # because consumer config won't change without it.
252+ await self .js .add_consumer (
253+ stream = self .stream_config .name or self .stream_name ,
254+ config = self .consumer_config ,
255+ )
256+ self .consumer = await self .js .pull_subscribe (
257+ subject = self .subject ,
258+ durable = self .durable ,
259+ config = self .consumer_config ,
260+ )
0 commit comments