Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 13 additions & 20 deletions faststream_concurrent_aiokafka/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from faststream import BaseMiddleware, ContextRepo
from faststream.kafka.message import KafkaAckableMessage

from faststream_concurrent_aiokafka.batch_committer import KafkaBatchCommitter
from faststream_concurrent_aiokafka.processing import DEFAULT_CONCURRENCY_LIMIT, KafkaConcurrentHandler


Expand Down Expand Up @@ -32,7 +33,7 @@ async def consume_scope( # ty: ignore[invalid-method-override]
raise RuntimeError(info)

kafka_message: typing.Final = self.context.get("message")
if concurrent_processing.enable_batch_commit and not kafka_message:
if concurrent_processing.has_batch_commit and not kafka_message:
logger.error("Kafka middleware. No kafka message in the middleware, it means no consumer to commit batch.")
info = "No kafka message in the middleware"
raise RuntimeError(info)
Expand All @@ -50,39 +51,31 @@ async def handler_wrapper() -> typing.Any: # noqa: ANN401
async def initialize_concurrent_processing(
context: ContextRepo,
concurrency_limit: int = DEFAULT_CONCURRENCY_LIMIT,
enable_batch_commit: bool = False,
commit_batch_size: int = 10,
commit_batch_timeout_sec: float = 10.0,
) -> None:
) -> KafkaConcurrentHandler:
existing: KafkaConcurrentHandler | None = context.get(_PROCESSING_CONTEXT_KEY)
if existing and existing.is_running:
logger.warning("Kafka middleware. Processing is already active")
return existing

concurrent_processing: typing.Final = KafkaConcurrentHandler(
commit_batch_size=commit_batch_size,
commit_batch_timeout_sec=commit_batch_timeout_sec,
concurrency_limit=concurrency_limit,
enable_batch_commit=enable_batch_commit,
committer=KafkaBatchCommitter(commit_batch_timeout_sec, commit_batch_size),
)
if concurrent_processing.is_running:
logger.warning("Kafka middleware. Processing is already active")
return
try:
await concurrent_processing.start()
except Exception as exc:
logger.exception("Kafka middleware. Cannot start concurrent processing")
msg: typing.Final = "Kafka middleware. Cannot start concurrent processing"
raise RuntimeError(msg) from exc

await concurrent_processing.start()
context.set_global(_PROCESSING_CONTEXT_KEY, concurrent_processing)
logger.info("Kafka middleware. Concurrent processing is active")
return concurrent_processing


async def stop_concurrent_processing(
context: ContextRepo,
) -> None:
concurrent_processing: typing.Final = KafkaConcurrentHandler()
if not concurrent_processing.is_healthy:
concurrent_processing: typing.Final[KafkaConcurrentHandler | None] = context.get(_PROCESSING_CONTEXT_KEY)
if not concurrent_processing or not concurrent_processing.is_healthy:
logger.warning("Kafka middleware. Concurrent processing is not running. Cannot stop")
return

await concurrent_processing.stop()
context.set_global(_PROCESSING_CONTEXT_KEY, None)

KafkaConcurrentHandler.reset()
40 changes: 8 additions & 32 deletions faststream_concurrent_aiokafka/processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import functools
import logging
import signal
import threading
import typing

from faststream.kafka import ConsumerRecord, TopicPartition
Expand All @@ -24,27 +23,12 @@


class KafkaConcurrentHandler:
_instance: typing.ClassVar["typing.Self | None"] = None
_lock: typing.ClassVar[threading.Lock] = threading.Lock()
_initialized: bool = False

def __new__(cls, *args: typing.Any, **kwargs: typing.Any) -> typing.Self: # noqa: ARG004, ANN401
with cls._lock:
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance

def __init__(
self,
concurrency_limit: int = DEFAULT_CONCURRENCY_LIMIT,
commit_batch_timeout_sec: float = 10.0,
commit_batch_size: int = 10,
enable_batch_commit: bool = False,
committer: KafkaBatchCommitter | None = None,
observer_interval: float = DEFAULT_OBSERVER_INTERVAL_SEC,
) -> None:
if self._initialized:
return

self.limiter = asyncio.Semaphore(concurrency_limit) if concurrency_limit != 0 else None
self._current_tasks: set[asyncio.Task[typing.Any]] = set()

Expand All @@ -53,19 +37,8 @@ def __init__(
self._observer_interval: float = observer_interval
self._is_running: bool = False

self.enable_batch_commit = enable_batch_commit
self._commit_batch_timeout_sec: float = commit_batch_timeout_sec
self._commit_batch_size: int = commit_batch_size

self._committer: KafkaBatchCommitter | None = None
self._committer: KafkaBatchCommitter | None = committer
self._stop_task: asyncio.Task[typing.Any] | None = None
self._initialized = True

@classmethod
def reset(cls) -> None:
with cls._lock:
cls._initialized = False
cls._instance = None

def _is_need_to_process_message(self, message: KafkaAckableMessage) -> bool:
headers_topic_group: typing.Final[str | None] = message.headers.get(TOPIC_GROUP_KEY)
Expand Down Expand Up @@ -101,7 +74,7 @@ async def handle_task(
task: typing.Final = asyncio.create_task(coroutine)
self._current_tasks.add(task)
task.add_done_callback(self._finish_task)
if self.enable_batch_commit and self._committer:
if self._committer:
try:
await self._committer.send_task(
batch_committer.KafkaCommitTask(
Expand Down Expand Up @@ -159,8 +132,7 @@ async def start(self) -> None:
self._is_running = True
self._shutdown_event.clear()

if self.enable_batch_commit:
self._committer = KafkaBatchCommitter(self._commit_batch_timeout_sec, self._commit_batch_size)
if self._committer:
self._committer.spawn()
self._observer_task = asyncio.create_task(self.observer())
self._setup_signal_handlers()
Expand Down Expand Up @@ -200,6 +172,10 @@ async def force_cancel_all(self) -> None:
await asyncio.sleep(0.5)
self._current_tasks.clear()

@property
def has_batch_commit(self) -> bool:
return self._committer is not None

@property
def is_healthy(self) -> bool:
status = self._is_running and self._observer_task is not None and not self._observer_task.done()
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@ dev = [
lint = [
"ruff",
"ty",
"auto-typing-final",
"eof-fixer"
]

Expand Down Expand Up @@ -78,7 +77,8 @@ isort.no-lines-before = ["standard-library", "local-folder"]

[tool.ruff.lint.extend-per-file-ignores]
"tests/*.py" = [
"S101", # allow asserts
"S101", # allow asserts
"PLR2004", # allow magic values in test assertions
]

[tool.pytest.ini_options]
Expand Down
10 changes: 0 additions & 10 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,8 @@
import os
import typing

import pytest

from faststream_concurrent_aiokafka.processing import KafkaConcurrentHandler


@pytest.fixture(scope="session")
def kafka_bootstrap_servers() -> str:
return os.environ.get("KAFKA_BOOTSTRAP_SERVERS", "localhost:9092")


@pytest.fixture(autouse=True)
def reset_singleton() -> typing.Iterator[None]:
KafkaConcurrentHandler.reset()
yield
KafkaConcurrentHandler.reset()
148 changes: 49 additions & 99 deletions tests/test_concurrent_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,43 +39,24 @@ def __init__(self, topic: str = "test-topic", partition: int = 0, offset: int =

@pytest_asyncio.fixture
async def handler() -> typing.AsyncIterator[KafkaConcurrentHandler]:
with patch(
"faststream_concurrent_aiokafka.processing.KafkaBatchCommitter",
MockKafkaBatchCommitter,
):
handler: typing.Final = KafkaConcurrentHandler(
enable_batch_commit=False,
concurrency_limit=0,
)
yield handler
if handler._is_running:
await handler.stop()
handler: typing.Final = KafkaConcurrentHandler(concurrency_limit=0)
yield handler
if handler._is_running:
await handler.stop()


@pytest_asyncio.fixture
async def handler_with_committer() -> typing.AsyncIterator[KafkaConcurrentHandler]:
with patch(
"faststream_concurrent_aiokafka.processing.KafkaBatchCommitter",
MockKafkaBatchCommitter,
):
h: typing.Final = KafkaConcurrentHandler(
enable_batch_commit=True,
commit_batch_timeout_sec=5,
commit_batch_size=10,
)
yield h
if h._is_running:
await h.stop()
h: typing.Final = KafkaConcurrentHandler(committer=MockKafkaBatchCommitter()) # ty: ignore[invalid-argument-type]
yield h
if h._is_running:
await h.stop()


@pytest_asyncio.fixture
async def handler_with_limit() -> typing.AsyncIterator[KafkaConcurrentHandler]:
with patch(
"faststream_concurrent_aiokafka.processing.KafkaBatchCommitter",
MockKafkaBatchCommitter,
):
h: typing.Final = KafkaConcurrentHandler(concurrency_limit=2)
yield h
h: typing.Final = KafkaConcurrentHandler(concurrency_limit=2)
yield h


@pytest.fixture
Expand All @@ -88,25 +69,6 @@ def sample_record() -> MockConsumerRecord:
return MockConsumerRecord()


def test_concurrent_singleton_same_instance() -> None:
h1: typing.Final = KafkaConcurrentHandler()
h2: typing.Final = KafkaConcurrentHandler()

assert h1 is h2


def test_concurrent_init_called_several_times() -> None:
expected: typing.Final = 5
first: typing.Final = KafkaConcurrentHandler(concurrency_limit=expected)
second: typing.Final = KafkaConcurrentHandler(concurrency_limit=expected * 2)

assert second is first
assert first.limiter
assert second.limiter
assert isinstance(first.limiter, asyncio.Semaphore)
assert first.limiter is second.limiter


def test_concurrent_init_without_concurrency_limit() -> None:
obj: typing.Final = KafkaConcurrentHandler(concurrency_limit=0)
assert obj.limiter is None
Expand Down Expand Up @@ -541,75 +503,63 @@ async def test_concurrent_cancels_observer(handler: KafkaConcurrentHandler, capl


async def test_concurrent_full_lifecycle() -> None:
with patch(
"faststream_concurrent_aiokafka.batch_committer.KafkaBatchCommitter",
MockKafkaBatchCommitter,
):
handler: typing.Final = KafkaConcurrentHandler(enable_batch_commit=True, concurrency_limit=2)
handler: typing.Final = KafkaConcurrentHandler(committer=MockKafkaBatchCommitter(), concurrency_limit=2) # ty: ignore[invalid-argument-type]

await handler.start()
assert handler.is_healthy
await handler.start()
assert handler.is_healthy

processed: typing.Final = []
processed: typing.Final = []

async def process_msg(msg_id: int) -> None:
await asyncio.sleep(0.01)
processed.append(msg_id)
async def process_msg(msg_id: int) -> None:
await asyncio.sleep(0.01)
processed.append(msg_id)

msg: typing.Final = MockKafkaMessage()
record: typing.Final = MockConsumerRecord()
msg: typing.Final = MockKafkaMessage()
record: typing.Final = MockConsumerRecord()

for i in range(5):
await handler.handle_task(process_msg(i), record, msg) # ty: ignore[invalid-argument-type]
for i in range(5):
await handler.handle_task(process_msg(i), record, msg) # ty: ignore[invalid-argument-type]

await handler.wait_for_subtasks()
await handler.stop()
await handler.wait_for_subtasks()
await handler.stop()

assert not handler.is_running
assert len(processed) > 0
assert not handler.is_running
assert len(processed) > 0


async def test_concurrent_message_processing() -> None:
target_value: typing.Final = 5
with patch(
"faststream_concurrent_aiokafka.batch_committer.KafkaBatchCommitter",
MockKafkaBatchCommitter,
):
handler: typing.Final = KafkaConcurrentHandler(concurrency_limit=target_value)
await handler.start()
handler: typing.Final = KafkaConcurrentHandler(concurrency_limit=target_value)
await handler.start()

start_times: typing.Final = []
end_times: typing.Final = []
start_times: typing.Final = []
end_times: typing.Final = []

async def tracked_task(idx: int) -> None:
start_times.append((idx, asyncio.get_event_loop().time()))
await asyncio.sleep(0.05)
end_times.append((idx, asyncio.get_event_loop().time()))
async def tracked_task(idx: int) -> None:
start_times.append((idx, asyncio.get_event_loop().time()))
await asyncio.sleep(0.05)
end_times.append((idx, asyncio.get_event_loop().time()))

msg: typing.Final = MockKafkaMessage()
record: typing.Final = MockConsumerRecord()
msg: typing.Final = MockKafkaMessage()
record: typing.Final = MockConsumerRecord()

for i in range(target_value):
await handler.handle_task(tracked_task(i), record, msg) # ty: ignore[invalid-argument-type]
for i in range(target_value):
await handler.handle_task(tracked_task(i), record, msg) # ty: ignore[invalid-argument-type]

await handler.wait_for_subtasks()
await handler.stop()
await handler.wait_for_subtasks()
await handler.stop()

if len(start_times) == target_value and len(end_times) == target_value:
max_start: typing.Final = max(t for _, t in start_times)
min_end: typing.Final = min(t for _, t in end_times)
assert max_start < min_end
if len(start_times) == target_value and len(end_times) == target_value:
max_start: typing.Final = max(t for _, t in start_times)
min_end: typing.Final = min(t for _, t in end_times)
assert max_start < min_end


async def test_concurrent_signal_handling_integration() -> None:
with patch(
"faststream_concurrent_aiokafka.batch_committer.KafkaBatchCommitter",
MockKafkaBatchCommitter,
):
handler: typing.Final = KafkaConcurrentHandler()
await handler.start()
handler: typing.Final = KafkaConcurrentHandler()
await handler.start()

handler._signal_handler(signal.SIGTERM)
assert handler._stop_task is not None
await handler._stop_task
assert not handler.is_running
handler._signal_handler(signal.SIGTERM)
assert handler._stop_task is not None
await handler._stop_task
assert not handler.is_running
Loading
Loading