Skip to content

Commit 5a1891c

Browse files
committed
small refactoring
1 parent b05c768 commit 5a1891c

9 files changed

Lines changed: 222 additions & 338 deletions

File tree

faststream_concurrent_aiokafka/middleware.py

Lines changed: 13 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from faststream import BaseMiddleware, ContextRepo
55
from faststream.kafka.message import KafkaAckableMessage
66

7+
from faststream_concurrent_aiokafka.batch_committer import KafkaBatchCommitter
78
from faststream_concurrent_aiokafka.processing import DEFAULT_CONCURRENCY_LIMIT, KafkaConcurrentHandler
89

910

@@ -32,7 +33,7 @@ async def consume_scope( # ty: ignore[invalid-method-override]
3233
raise RuntimeError(info)
3334

3435
kafka_message: typing.Final = self.context.get("message")
35-
if concurrent_processing.enable_batch_commit and not kafka_message:
36+
if concurrent_processing.has_batch_commit and not kafka_message:
3637
logger.error("Kafka middleware. No kafka message in the middleware, it means no consumer to commit batch.")
3738
info = "No kafka message in the middleware"
3839
raise RuntimeError(info)
@@ -50,39 +51,31 @@ async def handler_wrapper() -> typing.Any: # noqa: ANN401
5051
async def initialize_concurrent_processing(
5152
context: ContextRepo,
5253
concurrency_limit: int = DEFAULT_CONCURRENCY_LIMIT,
53-
enable_batch_commit: bool = False,
5454
commit_batch_size: int = 10,
5555
commit_batch_timeout_sec: float = 10.0,
56-
) -> None:
56+
) -> KafkaConcurrentHandler:
57+
existing: KafkaConcurrentHandler | None = context.get(_PROCESSING_CONTEXT_KEY)
58+
if existing and existing.is_running:
59+
logger.warning("Kafka middleware. Processing is already active")
60+
return existing
61+
5762
concurrent_processing: typing.Final = KafkaConcurrentHandler(
58-
commit_batch_size=commit_batch_size,
59-
commit_batch_timeout_sec=commit_batch_timeout_sec,
6063
concurrency_limit=concurrency_limit,
61-
enable_batch_commit=enable_batch_commit,
64+
committer=KafkaBatchCommitter(commit_batch_timeout_sec, commit_batch_size),
6265
)
63-
if concurrent_processing.is_running:
64-
logger.warning("Kafka middleware. Processing is already active")
65-
return
66-
try:
67-
await concurrent_processing.start()
68-
except Exception as exc:
69-
logger.exception("Kafka middleware. Cannot start concurrent processing")
70-
msg: typing.Final = "Kafka middleware. Cannot start concurrent processing"
71-
raise RuntimeError(msg) from exc
72-
66+
await concurrent_processing.start()
7367
context.set_global(_PROCESSING_CONTEXT_KEY, concurrent_processing)
7468
logger.info("Kafka middleware. Concurrent processing is active")
69+
return concurrent_processing
7570

7671

7772
async def stop_concurrent_processing(
7873
context: ContextRepo,
7974
) -> None:
80-
concurrent_processing: typing.Final = KafkaConcurrentHandler()
81-
if not concurrent_processing.is_healthy:
75+
concurrent_processing: typing.Final[KafkaConcurrentHandler | None] = context.get(_PROCESSING_CONTEXT_KEY)
76+
if not concurrent_processing or not concurrent_processing.is_healthy:
8277
logger.warning("Kafka middleware. Concurrent processing is not running. Cannot stop")
8378
return
8479

8580
await concurrent_processing.stop()
8681
context.set_global(_PROCESSING_CONTEXT_KEY, None)
87-
88-
KafkaConcurrentHandler.reset()

faststream_concurrent_aiokafka/processing.py

Lines changed: 8 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import functools
44
import logging
55
import signal
6-
import threading
76
import typing
87

98
from faststream.kafka import ConsumerRecord, TopicPartition
@@ -24,27 +23,12 @@
2423

2524

2625
class KafkaConcurrentHandler:
27-
_instance: typing.ClassVar["typing.Self | None"] = None
28-
_lock: typing.ClassVar[threading.Lock] = threading.Lock()
29-
_initialized: bool = False
30-
31-
def __new__(cls, *args: typing.Any, **kwargs: typing.Any) -> typing.Self: # noqa: ARG004, ANN401
32-
with cls._lock:
33-
if cls._instance is None:
34-
cls._instance = super().__new__(cls)
35-
return cls._instance
36-
3726
def __init__(
3827
self,
3928
concurrency_limit: int = DEFAULT_CONCURRENCY_LIMIT,
40-
commit_batch_timeout_sec: float = 10.0,
41-
commit_batch_size: int = 10,
42-
enable_batch_commit: bool = False,
29+
committer: KafkaBatchCommitter | None = None,
4330
observer_interval: float = DEFAULT_OBSERVER_INTERVAL_SEC,
4431
) -> None:
45-
if self._initialized:
46-
return
47-
4832
self.limiter = asyncio.Semaphore(concurrency_limit) if concurrency_limit != 0 else None
4933
self._current_tasks: set[asyncio.Task[typing.Any]] = set()
5034

@@ -53,19 +37,8 @@ def __init__(
5337
self._observer_interval: float = observer_interval
5438
self._is_running: bool = False
5539

56-
self.enable_batch_commit = enable_batch_commit
57-
self._commit_batch_timeout_sec: float = commit_batch_timeout_sec
58-
self._commit_batch_size: int = commit_batch_size
59-
60-
self._committer: KafkaBatchCommitter | None = None
40+
self._committer: KafkaBatchCommitter | None = committer
6141
self._stop_task: asyncio.Task[typing.Any] | None = None
62-
self._initialized = True
63-
64-
@classmethod
65-
def reset(cls) -> None:
66-
with cls._lock:
67-
cls._initialized = False
68-
cls._instance = None
6942

7043
def _is_need_to_process_message(self, message: KafkaAckableMessage) -> bool:
7144
headers_topic_group: typing.Final[str | None] = message.headers.get(TOPIC_GROUP_KEY)
@@ -101,7 +74,7 @@ async def handle_task(
10174
task: typing.Final = asyncio.create_task(coroutine)
10275
self._current_tasks.add(task)
10376
task.add_done_callback(self._finish_task)
104-
if self.enable_batch_commit and self._committer:
77+
if self._committer:
10578
try:
10679
await self._committer.send_task(
10780
batch_committer.KafkaCommitTask(
@@ -159,8 +132,7 @@ async def start(self) -> None:
159132
self._is_running = True
160133
self._shutdown_event.clear()
161134

162-
if self.enable_batch_commit:
163-
self._committer = KafkaBatchCommitter(self._commit_batch_timeout_sec, self._commit_batch_size)
135+
if self._committer:
164136
self._committer.spawn()
165137
self._observer_task = asyncio.create_task(self.observer())
166138
self._setup_signal_handlers()
@@ -200,6 +172,10 @@ async def force_cancel_all(self) -> None:
200172
await asyncio.sleep(0.5)
201173
self._current_tasks.clear()
202174

175+
@property
176+
def has_batch_commit(self) -> bool:
177+
return self._committer is not None
178+
203179
@property
204180
def is_healthy(self) -> bool:
205181
status = self._is_running and self._observer_task is not None and not self._observer_task.done()

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,6 @@ dev = [
4242
lint = [
4343
"ruff",
4444
"ty",
45-
"auto-typing-final",
4645
"eof-fixer"
4746
]
4847

@@ -78,7 +77,8 @@ isort.no-lines-before = ["standard-library", "local-folder"]
7877

7978
[tool.ruff.lint.extend-per-file-ignores]
8079
"tests/*.py" = [
81-
"S101", # allow asserts
80+
"S101", # allow asserts
81+
"PLR2004", # allow magic values in test assertions
8282
]
8383

8484
[tool.pytest.ini_options]

tests/conftest.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,8 @@
11
import os
2-
import typing
32

43
import pytest
54

6-
from faststream_concurrent_aiokafka.processing import KafkaConcurrentHandler
7-
85

96
@pytest.fixture(scope="session")
107
def kafka_bootstrap_servers() -> str:
118
return os.environ.get("KAFKA_BOOTSTRAP_SERVERS", "localhost:9092")
12-
13-
14-
@pytest.fixture(autouse=True)
15-
def reset_singleton() -> typing.Iterator[None]:
16-
KafkaConcurrentHandler.reset()
17-
yield
18-
KafkaConcurrentHandler.reset()

tests/test_concurrent_processing.py

Lines changed: 49 additions & 99 deletions
Original file line numberDiff line numberDiff line change
@@ -39,43 +39,24 @@ def __init__(self, topic: str = "test-topic", partition: int = 0, offset: int =
3939

4040
@pytest_asyncio.fixture
4141
async def handler() -> typing.AsyncIterator[KafkaConcurrentHandler]:
42-
with patch(
43-
"faststream_concurrent_aiokafka.processing.KafkaBatchCommitter",
44-
MockKafkaBatchCommitter,
45-
):
46-
handler: typing.Final = KafkaConcurrentHandler(
47-
enable_batch_commit=False,
48-
concurrency_limit=0,
49-
)
50-
yield handler
51-
if handler._is_running:
52-
await handler.stop()
42+
handler: typing.Final = KafkaConcurrentHandler(concurrency_limit=0)
43+
yield handler
44+
if handler._is_running:
45+
await handler.stop()
5346

5447

5548
@pytest_asyncio.fixture
5649
async def handler_with_committer() -> typing.AsyncIterator[KafkaConcurrentHandler]:
57-
with patch(
58-
"faststream_concurrent_aiokafka.processing.KafkaBatchCommitter",
59-
MockKafkaBatchCommitter,
60-
):
61-
h: typing.Final = KafkaConcurrentHandler(
62-
enable_batch_commit=True,
63-
commit_batch_timeout_sec=5,
64-
commit_batch_size=10,
65-
)
66-
yield h
67-
if h._is_running:
68-
await h.stop()
50+
h: typing.Final = KafkaConcurrentHandler(committer=MockKafkaBatchCommitter()) # ty: ignore[invalid-argument-type]
51+
yield h
52+
if h._is_running:
53+
await h.stop()
6954

7055

7156
@pytest_asyncio.fixture
7257
async def handler_with_limit() -> typing.AsyncIterator[KafkaConcurrentHandler]:
73-
with patch(
74-
"faststream_concurrent_aiokafka.processing.KafkaBatchCommitter",
75-
MockKafkaBatchCommitter,
76-
):
77-
h: typing.Final = KafkaConcurrentHandler(concurrency_limit=2)
78-
yield h
58+
h: typing.Final = KafkaConcurrentHandler(concurrency_limit=2)
59+
yield h
7960

8061

8162
@pytest.fixture
@@ -88,25 +69,6 @@ def sample_record() -> MockConsumerRecord:
8869
return MockConsumerRecord()
8970

9071

91-
def test_concurrent_singleton_same_instance() -> None:
92-
h1: typing.Final = KafkaConcurrentHandler()
93-
h2: typing.Final = KafkaConcurrentHandler()
94-
95-
assert h1 is h2
96-
97-
98-
def test_concurrent_init_called_several_times() -> None:
99-
expected: typing.Final = 5
100-
first: typing.Final = KafkaConcurrentHandler(concurrency_limit=expected)
101-
second: typing.Final = KafkaConcurrentHandler(concurrency_limit=expected * 2)
102-
103-
assert second is first
104-
assert first.limiter
105-
assert second.limiter
106-
assert isinstance(first.limiter, asyncio.Semaphore)
107-
assert first.limiter is second.limiter
108-
109-
11072
def test_concurrent_init_without_concurrency_limit() -> None:
11173
obj: typing.Final = KafkaConcurrentHandler(concurrency_limit=0)
11274
assert obj.limiter is None
@@ -541,75 +503,63 @@ async def test_concurrent_cancels_observer(handler: KafkaConcurrentHandler, capl
541503

542504

543505
async def test_concurrent_full_lifecycle() -> None:
544-
with patch(
545-
"faststream_concurrent_aiokafka.batch_committer.KafkaBatchCommitter",
546-
MockKafkaBatchCommitter,
547-
):
548-
handler: typing.Final = KafkaConcurrentHandler(enable_batch_commit=True, concurrency_limit=2)
506+
handler: typing.Final = KafkaConcurrentHandler(committer=MockKafkaBatchCommitter(), concurrency_limit=2) # ty: ignore[invalid-argument-type]
549507

550-
await handler.start()
551-
assert handler.is_healthy
508+
await handler.start()
509+
assert handler.is_healthy
552510

553-
processed: typing.Final = []
511+
processed: typing.Final = []
554512

555-
async def process_msg(msg_id: int) -> None:
556-
await asyncio.sleep(0.01)
557-
processed.append(msg_id)
513+
async def process_msg(msg_id: int) -> None:
514+
await asyncio.sleep(0.01)
515+
processed.append(msg_id)
558516

559-
msg: typing.Final = MockKafkaMessage()
560-
record: typing.Final = MockConsumerRecord()
517+
msg: typing.Final = MockKafkaMessage()
518+
record: typing.Final = MockConsumerRecord()
561519

562-
for i in range(5):
563-
await handler.handle_task(process_msg(i), record, msg) # ty: ignore[invalid-argument-type]
520+
for i in range(5):
521+
await handler.handle_task(process_msg(i), record, msg) # ty: ignore[invalid-argument-type]
564522

565-
await handler.wait_for_subtasks()
566-
await handler.stop()
523+
await handler.wait_for_subtasks()
524+
await handler.stop()
567525

568-
assert not handler.is_running
569-
assert len(processed) > 0
526+
assert not handler.is_running
527+
assert len(processed) > 0
570528

571529

572530
async def test_concurrent_message_processing() -> None:
573531
target_value: typing.Final = 5
574-
with patch(
575-
"faststream_concurrent_aiokafka.batch_committer.KafkaBatchCommitter",
576-
MockKafkaBatchCommitter,
577-
):
578-
handler: typing.Final = KafkaConcurrentHandler(concurrency_limit=target_value)
579-
await handler.start()
532+
handler: typing.Final = KafkaConcurrentHandler(concurrency_limit=target_value)
533+
await handler.start()
580534

581-
start_times: typing.Final = []
582-
end_times: typing.Final = []
535+
start_times: typing.Final = []
536+
end_times: typing.Final = []
583537

584-
async def tracked_task(idx: int) -> None:
585-
start_times.append((idx, asyncio.get_event_loop().time()))
586-
await asyncio.sleep(0.05)
587-
end_times.append((idx, asyncio.get_event_loop().time()))
538+
async def tracked_task(idx: int) -> None:
539+
start_times.append((idx, asyncio.get_event_loop().time()))
540+
await asyncio.sleep(0.05)
541+
end_times.append((idx, asyncio.get_event_loop().time()))
588542

589-
msg: typing.Final = MockKafkaMessage()
590-
record: typing.Final = MockConsumerRecord()
543+
msg: typing.Final = MockKafkaMessage()
544+
record: typing.Final = MockConsumerRecord()
591545

592-
for i in range(target_value):
593-
await handler.handle_task(tracked_task(i), record, msg) # ty: ignore[invalid-argument-type]
546+
for i in range(target_value):
547+
await handler.handle_task(tracked_task(i), record, msg) # ty: ignore[invalid-argument-type]
594548

595-
await handler.wait_for_subtasks()
596-
await handler.stop()
549+
await handler.wait_for_subtasks()
550+
await handler.stop()
597551

598-
if len(start_times) == target_value and len(end_times) == target_value:
599-
max_start: typing.Final = max(t for _, t in start_times)
600-
min_end: typing.Final = min(t for _, t in end_times)
601-
assert max_start < min_end
552+
if len(start_times) == target_value and len(end_times) == target_value:
553+
max_start: typing.Final = max(t for _, t in start_times)
554+
min_end: typing.Final = min(t for _, t in end_times)
555+
assert max_start < min_end
602556

603557

604558
async def test_concurrent_signal_handling_integration() -> None:
605-
with patch(
606-
"faststream_concurrent_aiokafka.batch_committer.KafkaBatchCommitter",
607-
MockKafkaBatchCommitter,
608-
):
609-
handler: typing.Final = KafkaConcurrentHandler()
610-
await handler.start()
559+
handler: typing.Final = KafkaConcurrentHandler()
560+
await handler.start()
611561

612-
handler._signal_handler(signal.SIGTERM)
613-
assert handler._stop_task is not None
614-
await handler._stop_task
615-
assert not handler.is_running
562+
handler._signal_handler(signal.SIGTERM)
563+
assert handler._stop_task is not None
564+
await handler._stop_task
565+
assert not handler.is_running

0 commit comments

Comments
 (0)