Skip to content

Commit 0299d29

Browse files
committed
feat: NATSKeyValueScheduleSource
1 parent ff7e4eb commit 0299d29

2 files changed

Lines changed: 115 additions & 0 deletions

File tree

taskiq_nats/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,12 @@
1111
PushBasedJetStreamBroker,
1212
)
1313
from taskiq_nats.result_backend import NATSObjectStoreResultBackend
14+
from taskiq_nats.schedule_source import NATSKeyValueScheduleSource
1415

1516
__all__ = [
1617
"NatsBroker",
1718
"PushBasedJetStreamBroker",
1819
"PullBasedJetStreamBroker",
1920
"NATSObjectStoreResultBackend",
21+
"NATSKeyValueScheduleSource",
2022
]

taskiq_nats/schedule_source.py

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
import logging
2+
from typing import Any, Final, List, Optional, Union
3+
4+
import nats
5+
from nats import NATS
6+
from nats.js import JetStreamContext
7+
from nats.js.errors import BucketNotFoundError, NoKeysError
8+
from nats.js.kv import KeyValue
9+
from taskiq import ScheduledTask, ScheduleSource
10+
from taskiq.abc.serializer import TaskiqSerializer
11+
from taskiq.compat import model_dump, model_validate
12+
from taskiq.serializers import PickleSerializer
13+
14+
log = logging.getLogger(__name__)
15+
16+
17+
class NATSKeyValueScheduleSource(ScheduleSource):
18+
"""
19+
Source of schedules for NATS Key-Value storage.
20+
21+
This class allows you to store schedules in NATS Key-Value storage.
22+
Also it supports dynamic schedules.
23+
"""
24+
25+
def __init__(
26+
self,
27+
servers: Union[str, List[str]],
28+
bucket_name: str = "taskiq_schedules",
29+
prefix: str = "schedule",
30+
serializer: Optional[TaskiqSerializer] = None,
31+
**connect_options: Any,
32+
) -> None:
33+
"""Construct new result backend.
34+
35+
:param servers: NATS servers.
36+
:param bucket_name: name of the bucket where schedules would be stored.
37+
:param prefix: prefix for nats kv storage schedule keys.
38+
:param serializer: serializer for data.
39+
:param connect_kwargs: additional arguments for nats `connect()` method.
40+
"""
41+
self.servers: Final = servers
42+
self.bucket_name: Final = bucket_name
43+
self.prefix: Final = prefix
44+
self.serializer = serializer or PickleSerializer()
45+
self.connect_options: Final = connect_options
46+
47+
self.nats_client: NATS
48+
self.nats_jetstream: JetStreamContext
49+
self.kv: KeyValue
50+
51+
async def startup(self) -> None:
52+
"""Create new connection to NATS.
53+
54+
Initialize JetStream context and new KeyValue instance.
55+
"""
56+
self.nats_client = await nats.connect(
57+
servers=self.servers,
58+
**self.connect_options,
59+
)
60+
self.nats_jetstream = self.nats_client.jetstream()
61+
62+
try:
63+
self.kv = await self.nats_jetstream.key_value(self.bucket_name)
64+
except BucketNotFoundError:
65+
self.kv = await self.nats_jetstream.create_key_value(
66+
bucket=self.bucket_name,
67+
)
68+
69+
async def shutdown(self) -> None:
70+
"""Close nats connection."""
71+
if self.nats_client.is_closed:
72+
return
73+
await self.nats_client.close()
74+
75+
async def delete_schedule(self, schedule_id: str) -> None:
76+
"""Remove schedule by id."""
77+
await self.kv.delete(f"{self.prefix}.{schedule_id}")
78+
79+
async def add_schedule(self, schedule: ScheduledTask) -> None:
80+
"""
81+
Add schedule to NATS Key-Value storage.
82+
83+
:param schedule: schedule to add.
84+
:param schedule_id: schedule id.
85+
"""
86+
await self.kv.put(
87+
f"{self.prefix}.{schedule.schedule_id}",
88+
self.serializer.dumpb(model_dump(schedule)),
89+
)
90+
91+
async def get_schedules(self) -> List[ScheduledTask]:
92+
"""
93+
Get all schedules from NATS Key-Value storage.
94+
95+
This method is used by scheduler to get all schedules.
96+
97+
:return: list of schedules.
98+
"""
99+
try:
100+
schedules = await self.kv.history(f"{self.prefix}.*")
101+
except NoKeysError:
102+
return []
103+
104+
return [
105+
model_validate(ScheduledTask, self.serializer.loadb(schedule.value))
106+
for schedule in schedules
107+
if schedule and schedule.value
108+
]
109+
110+
async def post_send(self, task: ScheduledTask) -> None:
111+
"""Delete a task after it's completed."""
112+
if task.time is not None:
113+
await self.delete_schedule(task.schedule_id)

0 commit comments

Comments
 (0)