forked from taskiq-python/taskiq-redis
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathredis_broker.py
More file actions
319 lines (278 loc) · 12.3 KB
/
redis_broker.py
File metadata and controls
319 lines (278 loc) · 12.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
import uuid
from collections.abc import AsyncGenerator, Awaitable, Callable
from logging import getLogger
from typing import (
    TYPE_CHECKING,
    Any,
    TypeAlias,
    TypeVar,
)

from redis.asyncio import BlockingConnectionPool, Connection, Redis, ResponseError
from redis.exceptions import ConnectionError as RedisConnectionError
from taskiq import AckableMessage
from taskiq.abc.broker import AsyncBroker
from taskiq.abc.result_backend import AsyncResultBackend
from taskiq.message import BrokerMessage
# Generic parameter for the result backend's stored return type.
_T = TypeVar("_T")

# Module-level logger shared by every broker defined in this file.
logger = getLogger("taskiq.redis_broker")

# BlockingConnectionPool is generic only in redis-py's type stubs, so the
# subscripted form is valid solely under TYPE_CHECKING; at runtime the alias
# points at the plain (unsubscripted) class.
if TYPE_CHECKING:
    _BlockingConnectionPool: TypeAlias = BlockingConnectionPool[Connection]  # type: ignore
else:
    _BlockingConnectionPool: TypeAlias = BlockingConnectionPool
class BaseRedisBroker(AsyncBroker):
    """Base broker that works with Redis."""

    def __init__(
        self,
        url: str,
        task_id_generator: Callable[[], str] | None = None,
        result_backend: AsyncResultBackend[_T] | None = None,
        queue_name: str = "taskiq",
        max_connection_pool_size: int | None = None,
        **connection_kwargs: Any,
    ) -> None:
        """
        Construct a new Redis-backed broker.

        :param url: url to redis.
        :param task_id_generator: custom task_id generator.
        :param result_backend: custom result backend.
        :param queue_name: name for a list in redis.
        :param max_connection_pool_size: maximum number of connections in pool.
            Each worker opens its own connection. Therefore this value has to be
            at least number of workers + 1.
        :param connection_kwargs: additional arguments for redis BlockingConnectionPool.
        """
        super().__init__(
            result_backend=result_backend,
            task_id_generator=task_id_generator,
        )
        self.queue_name = queue_name
        # A blocking pool makes callers wait for a free connection instead of
        # failing when the pool is exhausted.
        pool: _BlockingConnectionPool = BlockingConnectionPool.from_url(
            url=url,
            max_connections=max_connection_pool_size,
            **connection_kwargs,
        )
        self.connection_pool = pool

    async def shutdown(self) -> None:
        """Shut down the base broker, then release all pooled connections."""
        await super().shutdown()
        await self.connection_pool.disconnect()
class PubSubBroker(BaseRedisBroker):
    """Broker that works with Redis and broadcasts tasks to all workers."""

    async def kick(self, message: BrokerMessage) -> None:
        """
        Publish message over PUBSUB channel.

        :param message: message to send.
        """
        # A task-level "queue_name" label takes precedence over the default channel.
        target_channel = message.labels.get("queue_name") or self.queue_name
        async with Redis(connection_pool=self.connection_pool) as redis_conn:
            await redis_conn.publish(target_channel, message.message)

    async def listen(self) -> AsyncGenerator[bytes, None]:
        """
        Listen redis queue for new messages.

        This function listens to the pubsub channel
        and yields all messages with proper types.

        :yields: broker messages.
        """
        async with Redis(connection_pool=self.connection_pool) as redis_conn:
            pubsub = redis_conn.pubsub()
            await pubsub.subscribe(self.queue_name)
            async for incoming in pubsub.listen():
                if not incoming:
                    continue
                if incoming["type"] == "message":
                    # Only actual payloads are yielded; subscribe confirmations
                    # and other control frames are logged and skipped.
                    yield incoming["data"]
                else:
                    logger.debug("Received non-message from redis: %s", incoming)
class ListQueueBroker(BaseRedisBroker):
    """Broker that works with Redis and distributes tasks between workers."""

    async def kick(self, message: BrokerMessage) -> None:
        """
        Put a message in a list.

        This method appends a message to the list of all messages.

        :param message: message to append.
        """
        # A task-level "queue_name" label takes precedence over the default queue.
        queue_name = message.labels.get("queue_name") or self.queue_name
        async with Redis(connection_pool=self.connection_pool) as redis_conn:
            await redis_conn.lpush(queue_name, message.message)  # type: ignore

    async def listen(self) -> AsyncGenerator[bytes, None]:
        """
        Listen redis queue for new messages.

        This function listens to the queue
        and yields new messages if they have BrokerMessage type.

        :yields: broker messages.
        """
        # brpop returns a (queue_name, data) pair; only the payload is yielded.
        redis_brpop_data_position = 1
        while True:
            try:
                async with Redis(connection_pool=self.connection_pool) as redis_conn:
                    yield (await redis_conn.brpop(self.queue_name))[  # type: ignore
                        redis_brpop_data_position
                    ]
            except (ConnectionError, RedisConnectionError) as exc:
                # Bug fix: redis-py raises redis.exceptions.ConnectionError,
                # which subclasses RedisError rather than the builtin
                # ConnectionError, so the original bare `except ConnectionError`
                # never caught driver-level disconnects and a dropped connection
                # terminated the listener. Catch both and keep retrying.
                logger.warning("Redis connection error: %s", exc)
                continue
class RedisStreamBroker(BaseRedisBroker):
    """
    Redis broker that uses streams for task distribution.

    You can read more about streams here:
    https://redis.io/docs/latest/develop/data-types/streams

    This broker supports acknowledgment of messages.
    """

    def __init__(
        self,
        url: str,
        queue_name: str = "taskiq",
        max_connection_pool_size: int | None = None,
        consumer_group_name: str = "taskiq",
        consumer_name: str | None = None,
        consumer_id: str = "$",
        mkstream: bool = True,
        xread_block: int = 2000,
        maxlen: int | None = None,
        approximate: bool = True,
        idle_timeout: int = 600000,  # 10 minutes
        unacknowledged_batch_size: int = 100,
        unacknowledged_lock_timeout: float | None = None,
        xread_count: int | None = 100,
        additional_streams: dict[str, str | int] | None = None,
        **connection_kwargs: Any,
    ) -> None:
        """
        Constructs a new broker that uses streams.

        :param url: url to redis.
        :param queue_name: name for a key with stream in redis.
        :param max_connection_pool_size: maximum number of connections in pool.
            Each worker opens its own connection. Therefore this value has to be
            at least number of workers + 1.
        :param consumer_group_name: name for a consumer group.
            Redis will keep track of acked messages for this group.
        :param consumer_name: name for a consumer. By default it is a random uuid.
        :param consumer_id: id for a consumer. ID of a message to start reading from.
            $ means start from the latest message.
        :param mkstream: create stream if it does not exist.
        :param xread_block: block time in ms for xreadgroup.
            Better to set it to a bigger value, to avoid unnecessary calls.
        :param maxlen: sets the maximum length of the stream
            trims (the old values of) the stream each time a new element is added
        :param approximate: decides whether to trim the stream immediately (False) or
            later on (True)
        :param idle_timeout: time in ms a pending message must stay idle before
            it is reclaimed from its previous consumer via xautoclaim.
        :param xread_count: number of messages to fetch from the stream at once.
        :param additional_streams: additional streams to read from.
            Each key is a stream name, value is a consumer id.
        :param unacknowledged_batch_size: number of unacknowledged messages to fetch.
        :param unacknowledged_lock_timeout: time in seconds before auto-releasing
            the lock. Useful when the worker crashes or gets killed.
            If not set, the lock can remain locked indefinitely.
        """
        super().__init__(
            url,
            task_id_generator=None,
            result_backend=None,
            queue_name=queue_name,
            max_connection_pool_size=max_connection_pool_size,
            **connection_kwargs,
        )
        self.consumer_group_name = consumer_group_name
        # A random uuid keeps concurrently running workers distinct inside
        # the same consumer group.
        self.consumer_name = consumer_name or str(uuid.uuid4())
        self.consumer_id = consumer_id
        self.mkstream = mkstream
        self.block = xread_block
        self.maxlen = maxlen
        self.approximate = approximate
        self.additional_streams = additional_streams or {}
        self.idle_timeout = idle_timeout
        self.unacknowledged_batch_size = unacknowledged_batch_size
        self.unacknowledged_lock_timeout = unacknowledged_lock_timeout
        self.count = xread_count

    async def _declare_consumer_group(self) -> None:
        """
        Declare consumer group.

        Required for proper work of the broker.
        """
        # A set deduplicates queue_name if it also appears in additional_streams.
        streams = {self.queue_name, *self.additional_streams.keys()}
        async with Redis(connection_pool=self.connection_pool) as redis_conn:
            for stream_name in streams:
                try:
                    await redis_conn.xgroup_create(
                        stream_name,
                        self.consumer_group_name,
                        id=self.consumer_id,
                        mkstream=self.mkstream,
                    )
                except ResponseError as err:
                    # Typically a BUSYGROUP reply when the group already
                    # exists; treated as benign and only logged at debug level.
                    logger.debug(err)

    async def startup(self) -> None:
        """Declare consumer group on startup."""
        await super().startup()
        await self._declare_consumer_group()

    async def kick(self, message: BrokerMessage) -> None:
        """
        Put a message in a stream.

        This method adds a message to the stream,
        trimming it to ``maxlen`` entries when configured.

        :param message: message to append.
        """
        # A task-level "queue_name" label takes precedence over the default stream.
        queue_name = message.labels.get("queue_name") or self.queue_name
        async with Redis(connection_pool=self.connection_pool) as redis_conn:
            await redis_conn.xadd(
                queue_name,
                {b"data": message.message},
                maxlen=self.maxlen,
                approximate=self.approximate,
            )

    def _ack_generator(self, id: str, queue_name: str) -> Callable[[], Awaitable[None]]:
        # Returns a closure that XACKs one specific message for this broker's
        # consumer group; attached to each yielded AckableMessage.
        async def _ack() -> None:
            async with Redis(connection_pool=self.connection_pool) as redis_conn:
                await redis_conn.xack(
                    queue_name,
                    self.consumer_group_name,
                    id,
                )

        return _ack

    async def listen(self) -> AsyncGenerator[AckableMessage, None]:
        """
        Listen to incoming messages.

        Each loop iteration does two things: read fresh messages for this
        consumer via xreadgroup, then try to reclaim messages that other
        consumers left unacknowledged for longer than ``idle_timeout``.

        :yields: ackable messages.
        """
        async with Redis(connection_pool=self.connection_pool) as redis_conn:
            while True:
                logger.debug("Starting fetching new messages")
                # Blocks up to self.block ms; ">" asks for messages never
                # delivered to any consumer in the group.
                fetched = await redis_conn.xreadgroup(
                    self.consumer_group_name,
                    self.consumer_name,
                    {
                        self.queue_name: ">",
                        **self.additional_streams,  # type: ignore[dict-item]
                    },
                    block=self.block,
                    noack=False,
                    count=self.count,
                )
                for stream, msg_list in fetched:
                    for msg_id, msg in msg_list:
                        logger.debug("Received message: %s", msg)
                        yield AckableMessage(
                            data=msg[b"data"],
                            ack=self._ack_generator(id=msg_id, queue_name=stream),
                        )
                logger.debug("Starting fetching unacknowledged messages")
                for stream in [self.queue_name, *self.additional_streams.keys()]:
                    # Per-stream lock so only one worker at a time runs
                    # xautoclaim for the group.
                    lock = redis_conn.lock(
                        f"autoclaim:{self.consumer_group_name}:{stream}",
                        timeout=self.unacknowledged_lock_timeout,
                    )
                    # NOTE(review): this locked() pre-check and the acquisition
                    # below are not atomic — another worker may take the lock in
                    # between, in which case `async with lock` waits for it.
                    # Confirm that blocking here (rather than skipping) is intended.
                    if await lock.locked():
                        continue
                    async with lock:
                        # Claims messages pending longer than idle_timeout for
                        # this consumer. NOTE(review): pending[1] is assumed to
                        # be the claimed-messages list of redis-py's xautoclaim
                        # reply (index 0 is the next cursor) — verify against
                        # the installed redis-py version.
                        pending = await redis_conn.xautoclaim(
                            name=stream,
                            groupname=self.consumer_group_name,
                            consumername=self.consumer_name,
                            min_idle_time=self.idle_timeout,
                            count=self.unacknowledged_batch_size,
                        )
                        logger.debug(
                            "Found %d pending messages in stream %s",
                            len(pending[1]),
                            stream,
                        )
                        for msg_id, msg in pending[1]:
                            logger.debug("Received message: %s", msg)
                            yield AckableMessage(
                                data=msg[b"data"],
                                ack=self._ack_generator(id=msg_id, queue_name=stream),
                            )