Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions faststream_concurrent_aiokafka/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@
initialize_concurrent_processing,
stop_concurrent_processing,
)
from faststream_concurrent_aiokafka.rebalance import ConsumerRebalanceListener


__all__ = [
"ConsumerRebalanceListener",
"KafkaConcurrentProcessingMiddleware",
"initialize_concurrent_processing",
"is_kafka_handler_healthy",
Expand Down
52 changes: 33 additions & 19 deletions faststream_concurrent_aiokafka/batch_committer.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import asyncio
import contextlib
import dataclasses
import itertools
import logging
import typing

from aiokafka.errors import CommitFailedError, KafkaError
from faststream.kafka import TopicPartition


Expand Down Expand Up @@ -38,6 +38,7 @@ def __init__(
self._messages_queue: asyncio.Queue[KafkaCommitTask] = asyncio.Queue()
self._commit_task: asyncio.Task[typing.Any] | None = None
self._flush_batch_event = asyncio.Event()
self._stop_requested: bool = False

self._commit_batch_timeout_sec = commit_batch_timeout_sec
self._commit_batch_size = commit_batch_size
Expand Down Expand Up @@ -77,11 +78,11 @@ async def _populate_commit_batch(self) -> tuple[list[KafkaCommitTask], bool]:
else:
queue_get_task.cancel()

# commit_all was called — flush remaining queue items and stop
# flush event — drain remaining queue items; stop only if close() was called
if flush_wait_task in done:
uncommited_tasks.extend(self._flush_tasks_queue())
self._flush_batch_event.clear()
should_shutdown = True
should_shutdown = self._stop_requested
break

if timeout_task in done:
Expand All @@ -104,25 +105,35 @@ async def _call_committer(
) -> bool:
if not partitions_to_offsets:
return True
commit_succeeded = True
consumer: typing.Final[AIOKafkaConsumer] = tasks_batch[0].consumer
try:
await consumer.commit(partitions_to_offsets)
except Exception as exc:
commit_succeeded = False
logger.exception("Error during commit to kafka", exc_info=exc)
except CommitFailedError:
# Partition reassignment in progress — safe to ignore, offsets will be re-committed
logger.exception("Cannot commit due to rebalancing, ignoring batch")
return False
except KafkaError:
# Transient error — re-queue batch for retry on next cycle
logger.exception("Error during commit to kafka, re-queuing batch")
for task in tasks_batch:
await self._messages_queue.put(task)
return commit_succeeded
return False
else:
return True

@staticmethod
def _map_offsets_per_partition(consumer_tasks: list[KafkaCommitTask]) -> dict[TopicPartition, int]:
partitions_to_tasks = itertools.groupby(
sorted(consumer_tasks, key=lambda x: x.topic_partition), lambda x: x.topic_partition
)
by_partition: dict[TopicPartition, list[KafkaCommitTask]] = {}
for task in consumer_tasks:
by_partition.setdefault(task.topic_partition, []).append(task)

partitions_to_offsets: dict[TopicPartition, int] = {}
for partition, partition_tasks in partitions_to_tasks:
max_offset = max((task.offset for task in partition_tasks), default=None)
for partition, tasks in by_partition.items():
max_offset: int | None = None
for task in sorted(tasks, key=lambda x: x.offset):
if task.asyncio_task.cancelled():
break # stop committing at first cancelled task — message was not processed
max_offset = task.offset
if max_offset is not None:
# Kafka commits the *next* offset to fetch, so committed = processed_max + 1
partitions_to_offsets[partition] = max_offset + 1
Expand All @@ -133,7 +144,7 @@ async def _commit_tasks_batch(self, tasks_batch: list[KafkaCommitTask]) -> bool:
*[task.asyncio_task for task in tasks_batch], return_exceptions=True
)
for result in results:
if isinstance(result, BaseException):
if isinstance(result, BaseException) and not isinstance(result, asyncio.CancelledError):
logger.error("Task has finished with an exception", exc_info=result)

# Group by consumer instance — each AIOKafkaConsumer can only commit its own partitions
Expand All @@ -159,15 +170,17 @@ async def _run_commit_process(self) -> None:
await self._commit_tasks_batch(commit_batch)

async def commit_all(self) -> None:
"""Flush and commit all pending tasks, then stop the committer loop."""
"""Flush and commit all pending tasks without stopping the committer loop.

Safe to call during Kafka rebalance (on_partitions_revoked). The committer
continues running after this returns.
"""
self._flush_batch_event.set()
await self._messages_queue.join()

async def send_task(self, new_task: KafkaCommitTask) -> None:
self._check_is_commit_task_running()
await self._messages_queue.put(
new_task,
)
await self._messages_queue.put(new_task)

def spawn(self) -> None:
if not self._commit_task:
Expand All @@ -176,11 +189,12 @@ def spawn(self) -> None:
logger.error("Committer main task already running")

async def close(self) -> None:
"""Close committer."""
"""Flush all pending tasks and shut down the committer."""
if not self._commit_task:
logger.error("Committer main task is not running, cannot close committer properly")
return

self._stop_requested = True
self._flush_batch_event.set()
try:
await asyncio.wait_for(self._commit_task, timeout=self._shutdown_timeout)
Expand Down
52 changes: 10 additions & 42 deletions faststream_concurrent_aiokafka/processing.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import asyncio
import contextlib
import functools
import logging
import signal
Expand All @@ -10,14 +9,14 @@

from faststream_concurrent_aiokafka import batch_committer
from faststream_concurrent_aiokafka.batch_committer import KafkaBatchCommitter
from faststream_concurrent_aiokafka.rebalance import ConsumerRebalanceListener


logger = logging.getLogger(__name__)


SIGNALS: typing.Final = (signal.SIGTERM, signal.SIGINT, signal.SIGQUIT)
GRACEFUL_TIMEOUT_SEC: typing.Final[int] = 10
DEFAULT_OBSERVER_INTERVAL_SEC: typing.Final[float] = 5.0
DEFAULT_CONCURRENCY_LIMIT: typing.Final = 10


Expand All @@ -26,20 +25,14 @@ def __init__(
self,
committer: KafkaBatchCommitter,
concurrency_limit: int = DEFAULT_CONCURRENCY_LIMIT,
observer_interval: float = DEFAULT_OBSERVER_INTERVAL_SEC,
) -> None:
if concurrency_limit < 1:
msg = f"concurrency_limit must be >= 1, got {concurrency_limit}"
raise ValueError(msg)

self._limiter = asyncio.Semaphore(concurrency_limit)
self._current_tasks: set[asyncio.Task[typing.Any]] = set()

self._observer_task: asyncio.Task[typing.Any] | None = None
self._shutdown_event: asyncio.Event = asyncio.Event()
self._observer_interval: float = observer_interval
self._is_running: bool = False

self._committer: KafkaBatchCommitter = committer
self._stop_task: asyncio.Task[typing.Any] | None = None

Expand Down Expand Up @@ -83,25 +76,6 @@ async def handle_task(
await self.stop()
raise

async def _check_tasks_health(self) -> None:
done_tasks: typing.Final = {t for t in self._current_tasks if t.done()}
self._current_tasks -= done_tasks
if done_tasks:
logger.info(f"Kafka middleware. Found completed but not discarded tasks, amount: {len(done_tasks)}")

async def observer(self) -> None:
"""Background observer task that monitors system health."""
logger.info("Kafka middleware. Observer task started")

while not self._shutdown_event.is_set():
try:
await asyncio.wait_for(
self._shutdown_event.wait(),
timeout=self._observer_interval,
)
except TimeoutError:
await self._check_tasks_health()

def _setup_signal_handlers(self) -> None:
loop: typing.Final = asyncio.get_running_loop()
for sig in SIGNALS:
Expand All @@ -121,10 +95,8 @@ async def start(self) -> None:

logger.info("Kafka middleware. Start middleware handler")
self._is_running = True
self._shutdown_event.clear()

self._committer.spawn()
self._observer_task = asyncio.create_task(self.observer())
self._setup_signal_handlers()
logger.info("Kafka middleware is ready to process messages.")

Expand All @@ -133,13 +105,8 @@ async def stop(self) -> None:
return
logger.info("Kafka middleware. Shutting down middleware handler")
self._is_running = False
self._shutdown_event.set()

await self._committer.close()
if self._observer_task and not self._observer_task.done():
self._observer_task.cancel()
with contextlib.suppress(asyncio.CancelledError):
await self._observer_task
await self.wait_for_subtasks()

try:
Expand All @@ -157,19 +124,20 @@ async def force_cancel_all(self) -> None:
tasks = list(self._current_tasks)
for task in tasks:
task.cancel()
if self._observer_task and not self._observer_task.done():
self._observer_task.cancel()
await asyncio.gather(*tasks, return_exceptions=True)
self._current_tasks.clear()

def create_rebalance_listener(self) -> ConsumerRebalanceListener:
"""Return a ConsumerRebalanceListener that flushes pending commits on partition revocation.

Pass the returned listener to ``@broker.subscriber(..., listener=listener)`` so that
in-flight offsets are committed before Kafka hands the partition to another consumer.
"""
return ConsumerRebalanceListener(self._committer)

@property
def is_healthy(self) -> bool:
return (
self._is_running
and self._observer_task is not None
and not self._observer_task.done()
and self._committer.is_healthy
)
return self._is_running and self._committer.is_healthy

@property
def is_running(self) -> bool:
Expand Down
36 changes: 36 additions & 0 deletions faststream_concurrent_aiokafka/rebalance.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from aiokafka import ConsumerRebalanceListener as BaseConsumerRebalanceListener

from faststream_concurrent_aiokafka.batch_committer import KafkaBatchCommitter


class ConsumerRebalanceListener(BaseConsumerRebalanceListener): # type: ignore[misc]
    """Commits all pending offsets when Kafka revokes partitions during rebalance.

    Without this listener, in-flight message tasks whose offsets have not yet been
    batch-committed will be redelivered to another consumer after a rebalance, causing
    duplicate processing.

    Usage::

        @asynccontextmanager
        async def lifespan(context: ContextRepo) -> AsyncIterator[None]:
            handler = await initialize_concurrent_processing(context, ...)
            listener = handler.create_rebalance_listener()

            @broker.subscriber("my-topic", listener=listener)
            async def handle(msg: str) -> None:
                ...

            yield
            await stop_concurrent_processing(context)

    """

    def __init__(self, committer: KafkaBatchCommitter) -> None:
        # Committer shared with the processing middleware; it is flushed (not
        # stopped) whenever Kafka revokes partitions from this consumer.
        self._committer = committer

    async def on_partitions_assigned(self, _assigned: object) -> None: # ty: ignore[invalid-method-override]
        # Nothing to do on assignment — message processing resumes naturally.
        pass

    async def on_partitions_revoked(self, _revoked: object) -> None: # ty: ignore[invalid-method-override]
        # Flush every pending offset commit before Kafka hands these partitions
        # to another consumer, so already-processed messages are not redelivered.
        await self._committer.commit_all()
16 changes: 14 additions & 2 deletions tests/mocks.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Shared mock classes used across multiple test modules."""

import asyncio
import typing
from unittest.mock import AsyncMock, Mock

Expand All @@ -11,13 +12,24 @@ def __init__(self, group_id: str = "test-group") -> None:


class MockAsyncioTask:
def __init__(self, result: str | None = None, exception: Exception | None = None, done: bool = True) -> None:
def __init__(
self,
result: str | None = None,
exception: Exception | None = None,
done: bool = True,
cancelled: bool = False,
) -> None:
self._result: str | None = result
self._exception: Exception | None = exception
self._done: bool = done
self._cancelled: bool = False
self._cancelled: bool = cancelled

def cancelled(self) -> bool:
return self._cancelled

def __await__(self) -> typing.Generator[typing.Any, None, str | None]:
if self._cancelled:
raise asyncio.CancelledError
if self._exception:
raise self._exception
if False: # pragma: no cover
Expand Down
Loading
Loading