-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathbatch_committer.py
More file actions
198 lines (161 loc) · 7.66 KB
/
batch_committer.py
File metadata and controls
198 lines (161 loc) · 7.66 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
import asyncio
import contextlib
import dataclasses
import itertools
import logging
import threading
import typing
from faststream.kafka import TopicPartition
if typing.TYPE_CHECKING:
from aiokafka import AIOKafkaConsumer
logger = logging.getLogger(__name__)
# Seconds close() waits for the commit loop to drain before forcing cancellation.
SHUTDOWN_TIMEOUT_SEC: typing.Final = 20
class CommitterIsDeadError(Exception): ...
@dataclasses.dataclass(frozen=True, kw_only=True, slots=True)
class KafkaCommitTask:
    """A single processed message whose offset awaits a batched commit."""

    # In-flight handler for the message; the committer awaits it before committing.
    asyncio_task: asyncio.Task[typing.Any]
    # Partition the message belongs to; batches are grouped by this key.
    topic_partition: TopicPartition
    # Offset of the processed message (committer commits offset + 1).
    offset: int
    # Consumer that owns the partition; used to issue the commit.
    # Typed Any so aiokafka stays a type-checking-only import.
    consumer: typing.Any
class KafkaBatchCommitter:
    """Batches Kafka offset commits from many processed messages.

    Tasks are enqueued via ``send_task``; a background loop (started with
    ``spawn``) collects them into batches — flushed when ``commit_batch_size``
    tasks accumulate, ``commit_batch_timeout_sec`` elapses, or a flush is
    requested — and commits the highest offset per partition in a single
    ``consumer.commit`` call.

    Fix vs. previous revision: ``commit_all()`` used to kill the background
    loop (the flush event unconditionally set ``should_shutdown``), despite
    its docstring promising otherwise. The loop now stops on flush only when
    ``close()`` actually requested shutdown.
    """

    def __init__(
        self,
        commit_batch_timeout_sec: float = 10.0,
        commit_batch_size: int = 10,
    ) -> None:
        self._messages_queue: asyncio.Queue[KafkaCommitTask] = asyncio.Queue()
        self._asyncio_commit_process_task: asyncio.Task[typing.Any] | None = None
        self._flush_batch_event = asyncio.Event()
        self._commit_batch_timeout_sec = commit_batch_timeout_sec
        self._commit_batch_size = commit_batch_size
        self._shutdown_timeout = SHUTDOWN_TIMEOUT_SEC
        self._spawn_lock = threading.Lock()
        # Set only by close(): lets the flush event double as a shutdown
        # signal without commit_all() stopping the loop.
        self._close_requested = False

    def _check_is_commit_task_running(self) -> None:
        """Raise CommitterIsDeadError unless the background commit loop is alive."""
        is_commit_task_running: typing.Final[bool] = bool(
            self._asyncio_commit_process_task
            and not self._asyncio_commit_process_task.cancelled()
            and not self._asyncio_commit_process_task.done(),
        )
        if not is_commit_task_running:
            msg: typing.Final = "Committer main task is not running"
            raise CommitterIsDeadError(msg)

    def _flush_tasks_queue(self) -> list[KafkaCommitTask]:
        """Drain every currently queued task without blocking."""
        tasks_to_return: typing.Final[list[KafkaCommitTask]] = []
        while not self._messages_queue.empty():
            tasks_to_return.append(self._messages_queue.get_nowait())
        return tasks_to_return

    async def _populate_commit_batch(self) -> tuple[list[KafkaCommitTask], bool]:
        """Collect the next batch of commit tasks.

        Returns the batch plus a flag telling the caller whether the commit
        loop should stop afterwards (set only on close() or cancellation).
        """
        uncommited_tasks: typing.Final[list[KafkaCommitTask]] = []
        should_shutdown = False
        queue_get_task: asyncio.Task[typing.Any] | None = None
        # Create timeout and flush-wait tasks once; reused across queue-get iterations.
        timeout_task: asyncio.Task[None] = asyncio.create_task(asyncio.sleep(self._commit_batch_timeout_sec))
        flush_wait_task: asyncio.Task[bool] = asyncio.create_task(self._flush_batch_event.wait())
        try:
            while len(uncommited_tasks) < self._commit_batch_size:
                queue_get_task = asyncio.create_task(self._messages_queue.get())
                done, _ = await asyncio.wait(
                    [queue_get_task, flush_wait_task, timeout_task],
                    return_when=asyncio.FIRST_COMPLETED,
                )
                if queue_get_task in done:
                    uncommited_tasks.append(queue_get_task.result())
                else:
                    # NOTE(review): cancelling can race with a just-completed get()
                    # and drop an item in rare interleavings — known asyncio pitfall,
                    # left as-is; confirm whether at-least-once semantics cover it.
                    queue_get_task.cancel()
                    # A flush was requested — drain whatever is left and stop batching.
                    if flush_wait_task in done:
                        uncommited_tasks.extend(self._flush_tasks_queue())
                        self._flush_batch_event.clear()
                        # Only close() ends the loop; commit_all() merely flushes.
                        should_shutdown = self._close_requested
                        break
                    if timeout_task in done:
                        logger.debug("Timeout exceeded, batch contains %s elements", len(uncommited_tasks))
                        break
            logger.debug("Batch condition reached with %s elements", len(uncommited_tasks))
        except asyncio.CancelledError:
            # Graceful cancellation: hand back whatever is queued so it can
            # still be committed before the loop exits.
            should_shutdown = True
            uncommited_tasks.extend(self._flush_tasks_queue())
        # Cancel helpers unconditionally; cancelling an already-done task is a no-op.
        for task in (queue_get_task, flush_wait_task, timeout_task):
            if task:
                task.cancel()
        return uncommited_tasks, should_shutdown

    async def _call_committer(
        self, tasks_batch: list[KafkaCommitTask], partitions_to_offsets: dict[TopicPartition, int]
    ) -> bool:
        """Commit the given offsets; on failure re-queue the whole batch for retry.

        Returns True on success (or when there is nothing to commit).
        """
        if not partitions_to_offsets:
            return True
        commit_succeeded = True
        # All tasks in a batch are assumed to share one consumer instance.
        consumer: typing.Final[AIOKafkaConsumer] = tasks_batch[0].consumer
        try:
            await consumer.commit(partitions_to_offsets)
        except Exception as exc:
            commit_succeeded = False
            logger.exception("Error during commit to kafka", exc_info=exc)
            # Re-queue so a later batch retries these offsets.
            for task in tasks_batch:
                await self._messages_queue.put(task)
        return commit_succeeded

    async def _commit_tasks_batch(self, tasks_batch: list[KafkaCommitTask]) -> bool:
        """Await every task in the batch, then commit the max offset per partition."""
        # groupby requires its input sorted by the same key.
        partitions_to_tasks: typing.Final = itertools.groupby(
            sorted(tasks_batch, key=lambda x: x.topic_partition), lambda x: x.topic_partition
        )
        # Wait for every handler; failures are logged but offsets are still committed.
        results: typing.Final = await asyncio.gather(
            *[task.asyncio_task for task in tasks_batch], return_exceptions=True
        )
        for result in results:
            if isinstance(result, BaseException):
                logger.error("Task has finished with an exception", exc_info=result)
        partitions_to_offsets: typing.Final[dict[TopicPartition, int]] = {}
        partition: TopicPartition
        tasks: typing.Iterator[KafkaCommitTask]
        for partition, tasks in partitions_to_tasks:
            max_message_offset: int | None = None
            for task in tasks:
                if max_message_offset is None or task.offset > max_message_offset:
                    max_message_offset = task.offset
            if max_message_offset is not None:
                # Kafka commits the *next* offset to fetch, so committed = processed_max + 1
                partitions_to_offsets[partition] = max_message_offset + 1
        commit_succeeded: typing.Final = await self._call_committer(tasks_batch, partitions_to_offsets)
        # Balance one task_done per get(); re-queued tasks (on failure) are
        # matched again when their retry batch completes.
        for _ in tasks_batch:
            self._messages_queue.task_done()
        return commit_succeeded

    async def _run_commit_process(self) -> None:
        """Main loop: gather batches and commit them until shutdown is signalled."""
        should_shutdown = False
        while not should_shutdown:
            commit_batch, should_shutdown = await self._populate_commit_batch()
            if commit_batch:
                await self._commit_tasks_batch(commit_batch)

    async def commit_all(self) -> None:
        """Commit all without shutting down the main process."""
        self._flush_batch_event.set()
        # join() returns once every queued task has been matched by task_done().
        await self._messages_queue.join()

    async def send_task(self, new_task: KafkaCommitTask) -> None:
        """Enqueue a task for batched commit.

        Raises:
            CommitterIsDeadError: if the background commit loop is not running.
        """
        self._check_is_commit_task_running()
        await self._messages_queue.put(new_task)

    def spawn(self) -> None:
        """Start the background commit loop; safe to call from multiple threads, runs it once."""
        with self._spawn_lock:
            if not self._asyncio_commit_process_task:
                self._asyncio_commit_process_task = asyncio.create_task(self._run_commit_process())
            else:
                logger.error("Committer main task already running")

    async def close(self) -> None:
        """Close committer."""
        if not self._asyncio_commit_process_task:
            logger.error("Committer main task is not running, cannot close committer properly")
            return
        # Mark this flush as a shutdown so the commit loop exits after draining.
        self._close_requested = True
        self._flush_batch_event.set()
        try:
            await asyncio.wait_for(self._asyncio_commit_process_task, timeout=self._shutdown_timeout)
        except TimeoutError:
            logger.exception("Committer main task shutdown timed out, forcing cancellation")
            self._asyncio_commit_process_task.cancel()
            with contextlib.suppress(asyncio.CancelledError):
                await self._asyncio_commit_process_task
        except Exception as exc:
            logger.exception("Committer task failed during shutdown", exc_info=exc)
            raise

    @property
    def is_healthy(self) -> bool:
        """True while the background commit loop is still running."""
        return self._asyncio_commit_process_task is not None and not self._asyncio_commit_process_task.done()