Skip to content

Commit 055e8b7

Browse files
authored
Merge pull request #9 from modern-python/fix-concurrent-processing-bugs
fix offset safety, rebalance handling, and commit error handling
2 parents d0683ef + 63a55c2 commit 055e8b7

9 files changed

Lines changed: 413 additions & 256 deletions

File tree

faststream_concurrent_aiokafka/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,11 @@
44
initialize_concurrent_processing,
55
stop_concurrent_processing,
66
)
7+
from faststream_concurrent_aiokafka.rebalance import ConsumerRebalanceListener
78

89

910
__all__ = [
11+
"ConsumerRebalanceListener",
1012
"KafkaConcurrentProcessingMiddleware",
1113
"initialize_concurrent_processing",
1214
"is_kafka_handler_healthy",

faststream_concurrent_aiokafka/batch_committer.py

Lines changed: 33 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
import asyncio
22
import contextlib
33
import dataclasses
4-
import itertools
54
import logging
65
import typing
76

7+
from aiokafka.errors import CommitFailedError, KafkaError
88
from faststream.kafka import TopicPartition
99

1010

@@ -38,6 +38,7 @@ def __init__(
3838
self._messages_queue: asyncio.Queue[KafkaCommitTask] = asyncio.Queue()
3939
self._commit_task: asyncio.Task[typing.Any] | None = None
4040
self._flush_batch_event = asyncio.Event()
41+
self._stop_requested: bool = False
4142

4243
self._commit_batch_timeout_sec = commit_batch_timeout_sec
4344
self._commit_batch_size = commit_batch_size
@@ -77,11 +78,11 @@ async def _populate_commit_batch(self) -> tuple[list[KafkaCommitTask], bool]:
7778
else:
7879
queue_get_task.cancel()
7980

80-
# commit_all was called — flush remaining queue items and stop
81+
# flush event — drain remaining queue items; stop only if close() was called
8182
if flush_wait_task in done:
8283
uncommited_tasks.extend(self._flush_tasks_queue())
8384
self._flush_batch_event.clear()
84-
should_shutdown = True
85+
should_shutdown = self._stop_requested
8586
break
8687

8788
if timeout_task in done:
@@ -104,25 +105,35 @@ async def _call_committer(
104105
) -> bool:
105106
if not partitions_to_offsets:
106107
return True
107-
commit_succeeded = True
108108
consumer: typing.Final[AIOKafkaConsumer] = tasks_batch[0].consumer
109109
try:
110110
await consumer.commit(partitions_to_offsets)
111-
except Exception as exc:
112-
commit_succeeded = False
113-
logger.exception("Error during commit to kafka", exc_info=exc)
111+
except CommitFailedError:
112+
# Partition reassignment in progress — safe to ignore, offsets will be re-committed
113+
logger.exception("Cannot commit due to rebalancing, ignoring batch")
114+
return False
115+
except KafkaError:
116+
# Transient error — re-queue batch for retry on next cycle
117+
logger.exception("Error during commit to kafka, re-queuing batch")
114118
for task in tasks_batch:
115119
await self._messages_queue.put(task)
116-
return commit_succeeded
120+
return False
121+
else:
122+
return True
117123

118124
@staticmethod
119125
def _map_offsets_per_partition(consumer_tasks: list[KafkaCommitTask]) -> dict[TopicPartition, int]:
120-
partitions_to_tasks = itertools.groupby(
121-
sorted(consumer_tasks, key=lambda x: x.topic_partition), lambda x: x.topic_partition
122-
)
126+
by_partition: dict[TopicPartition, list[KafkaCommitTask]] = {}
127+
for task in consumer_tasks:
128+
by_partition.setdefault(task.topic_partition, []).append(task)
129+
123130
partitions_to_offsets: dict[TopicPartition, int] = {}
124-
for partition, partition_tasks in partitions_to_tasks:
125-
max_offset = max((task.offset for task in partition_tasks), default=None)
131+
for partition, tasks in by_partition.items():
132+
max_offset: int | None = None
133+
for task in sorted(tasks, key=lambda x: x.offset):
134+
if task.asyncio_task.cancelled():
135+
break # stop committing at first cancelled task — message was not processed
136+
max_offset = task.offset
126137
if max_offset is not None:
127138
# Kafka commits the *next* offset to fetch, so committed = processed_max + 1
128139
partitions_to_offsets[partition] = max_offset + 1
@@ -133,7 +144,7 @@ async def _commit_tasks_batch(self, tasks_batch: list[KafkaCommitTask]) -> bool:
133144
*[task.asyncio_task for task in tasks_batch], return_exceptions=True
134145
)
135146
for result in results:
136-
if isinstance(result, BaseException):
147+
if isinstance(result, BaseException) and not isinstance(result, asyncio.CancelledError):
137148
logger.error("Task has finished with an exception", exc_info=result)
138149

139150
# Group by consumer instance — each AIOKafkaConsumer can only commit its own partitions
@@ -159,15 +170,17 @@ async def _run_commit_process(self) -> None:
159170
await self._commit_tasks_batch(commit_batch)
160171

161172
async def commit_all(self) -> None:
162-
"""Flush and commit all pending tasks, then stop the committer loop."""
173+
"""Flush and commit all pending tasks without stopping the committer loop.
174+
175+
Safe to call during Kafka rebalance (on_partitions_revoked). The committer
176+
continues running after this returns.
177+
"""
163178
self._flush_batch_event.set()
164179
await self._messages_queue.join()
165180

166181
async def send_task(self, new_task: KafkaCommitTask) -> None:
167182
self._check_is_commit_task_running()
168-
await self._messages_queue.put(
169-
new_task,
170-
)
183+
await self._messages_queue.put(new_task)
171184

172185
def spawn(self) -> None:
173186
if not self._commit_task:
@@ -176,11 +189,12 @@ def spawn(self) -> None:
176189
logger.error("Committer main task already running")
177190

178191
async def close(self) -> None:
179-
"""Close committer."""
192+
"""Flush all pending tasks and shut down the committer."""
180193
if not self._commit_task:
181194
logger.error("Committer main task is not running, cannot close committer properly")
182195
return
183196

197+
self._stop_requested = True
184198
self._flush_batch_event.set()
185199
try:
186200
await asyncio.wait_for(self._commit_task, timeout=self._shutdown_timeout)

faststream_concurrent_aiokafka/processing.py

Lines changed: 10 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import asyncio
2-
import contextlib
32
import functools
43
import logging
54
import signal
@@ -10,14 +9,14 @@
109

1110
from faststream_concurrent_aiokafka import batch_committer
1211
from faststream_concurrent_aiokafka.batch_committer import KafkaBatchCommitter
12+
from faststream_concurrent_aiokafka.rebalance import ConsumerRebalanceListener
1313

1414

1515
logger = logging.getLogger(__name__)
1616

1717

1818
SIGNALS: typing.Final = (signal.SIGTERM, signal.SIGINT, signal.SIGQUIT)
1919
GRACEFUL_TIMEOUT_SEC: typing.Final[int] = 10
20-
DEFAULT_OBSERVER_INTERVAL_SEC: typing.Final[float] = 5.0
2120
DEFAULT_CONCURRENCY_LIMIT: typing.Final = 10
2221

2322

@@ -26,20 +25,14 @@ def __init__(
2625
self,
2726
committer: KafkaBatchCommitter,
2827
concurrency_limit: int = DEFAULT_CONCURRENCY_LIMIT,
29-
observer_interval: float = DEFAULT_OBSERVER_INTERVAL_SEC,
3028
) -> None:
3129
if concurrency_limit < 1:
3230
msg = f"concurrency_limit must be >= 1, got {concurrency_limit}"
3331
raise ValueError(msg)
3432

3533
self._limiter = asyncio.Semaphore(concurrency_limit)
3634
self._current_tasks: set[asyncio.Task[typing.Any]] = set()
37-
38-
self._observer_task: asyncio.Task[typing.Any] | None = None
39-
self._shutdown_event: asyncio.Event = asyncio.Event()
40-
self._observer_interval: float = observer_interval
4135
self._is_running: bool = False
42-
4336
self._committer: KafkaBatchCommitter = committer
4437
self._stop_task: asyncio.Task[typing.Any] | None = None
4538

@@ -83,25 +76,6 @@ async def handle_task(
8376
await self.stop()
8477
raise
8578

86-
async def _check_tasks_health(self) -> None:
87-
done_tasks: typing.Final = {t for t in self._current_tasks if t.done()}
88-
self._current_tasks -= done_tasks
89-
if done_tasks:
90-
logger.info(f"Kafka middleware. Found completed but not discarded tasks, amount: {len(done_tasks)}")
91-
92-
async def observer(self) -> None:
93-
"""Background observer task that monitors system health."""
94-
logger.info("Kafka middleware. Observer task started")
95-
96-
while not self._shutdown_event.is_set():
97-
try:
98-
await asyncio.wait_for(
99-
self._shutdown_event.wait(),
100-
timeout=self._observer_interval,
101-
)
102-
except TimeoutError:
103-
await self._check_tasks_health()
104-
10579
def _setup_signal_handlers(self) -> None:
10680
loop: typing.Final = asyncio.get_running_loop()
10781
for sig in SIGNALS:
@@ -121,10 +95,8 @@ async def start(self) -> None:
12195

12296
logger.info("Kafka middleware. Start middleware handler")
12397
self._is_running = True
124-
self._shutdown_event.clear()
12598

12699
self._committer.spawn()
127-
self._observer_task = asyncio.create_task(self.observer())
128100
self._setup_signal_handlers()
129101
logger.info("Kafka middleware is ready to process messages.")
130102

@@ -133,13 +105,8 @@ async def stop(self) -> None:
133105
return
134106
logger.info("Kafka middleware. Shutting down middleware handler")
135107
self._is_running = False
136-
self._shutdown_event.set()
137108

138109
await self._committer.close()
139-
if self._observer_task and not self._observer_task.done():
140-
self._observer_task.cancel()
141-
with contextlib.suppress(asyncio.CancelledError):
142-
await self._observer_task
143110
await self.wait_for_subtasks()
144111

145112
try:
@@ -157,19 +124,20 @@ async def force_cancel_all(self) -> None:
157124
tasks = list(self._current_tasks)
158125
for task in tasks:
159126
task.cancel()
160-
if self._observer_task and not self._observer_task.done():
161-
self._observer_task.cancel()
162127
await asyncio.gather(*tasks, return_exceptions=True)
163128
self._current_tasks.clear()
164129

130+
def create_rebalance_listener(self) -> ConsumerRebalanceListener:
131+
"""Return a ConsumerRebalanceListener that flushes pending commits on partition revocation.
132+
133+
Pass the returned listener to ``@broker.subscriber(..., listener=listener)`` so that
134+
in-flight offsets are committed before Kafka hands the partition to another consumer.
135+
"""
136+
return ConsumerRebalanceListener(self._committer)
137+
165138
@property
166139
def is_healthy(self) -> bool:
167-
return (
168-
self._is_running
169-
and self._observer_task is not None
170-
and not self._observer_task.done()
171-
and self._committer.is_healthy
172-
)
140+
return self._is_running and self._committer.is_healthy
173141

174142
@property
175143
def is_running(self) -> bool:
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
from aiokafka import ConsumerRebalanceListener as BaseConsumerRebalanceListener
2+
3+
from faststream_concurrent_aiokafka.batch_committer import KafkaBatchCommitter
4+
5+
6+
class ConsumerRebalanceListener(BaseConsumerRebalanceListener): # type: ignore[misc]
7+
"""Commits all pending offsets when Kafka revokes partitions during rebalance.
8+
9+
Without this listener, in-flight message tasks whose offsets have not yet been
10+
batch-committed will be redelivered to another consumer after a rebalance, causing
11+
duplicate processing.
12+
13+
Usage::
14+
15+
@asynccontextmanager
16+
async def lifespan(context: ContextRepo) -> AsyncIterator[None]:
17+
handler = await initialize_concurrent_processing(context, ...)
18+
listener = handler.create_rebalance_listener()
19+
20+
@broker.subscriber("my-topic", listener=listener)
21+
async def handle(msg: str) -> None:
22+
...
23+
24+
Yield:
25+
await stop_concurrent_processing(context)
26+
27+
"""
28+
29+
def __init__(self, committer: KafkaBatchCommitter) -> None:
30+
self._committer = committer
31+
32+
async def on_partitions_assigned(self, _assigned: object) -> None: # ty: ignore[invalid-method-override]
33+
pass
34+
35+
async def on_partitions_revoked(self, _revoked: object) -> None: # ty: ignore[invalid-method-override]
36+
await self._committer.commit_all()

tests/mocks.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Shared mock classes used across multiple test modules."""
22

3+
import asyncio
34
import typing
45
from unittest.mock import AsyncMock, Mock
56

@@ -11,13 +12,24 @@ def __init__(self, group_id: str = "test-group") -> None:
1112

1213

1314
class MockAsyncioTask:
14-
def __init__(self, result: str | None = None, exception: Exception | None = None, done: bool = True) -> None:
15+
def __init__(
16+
self,
17+
result: str | None = None,
18+
exception: Exception | None = None,
19+
done: bool = True,
20+
cancelled: bool = False,
21+
) -> None:
1522
self._result: str | None = result
1623
self._exception: Exception | None = exception
1724
self._done: bool = done
18-
self._cancelled: bool = False
25+
self._cancelled: bool = cancelled
26+
27+
def cancelled(self) -> bool:
28+
return self._cancelled
1929

2030
def __await__(self) -> typing.Generator[typing.Any, None, str | None]:
31+
if self._cancelled:
32+
raise asyncio.CancelledError
2133
if self._exception:
2234
raise self._exception
2335
if False: # pragma: no cover

0 commit comments

Comments
 (0)