11import asyncio
2+ import inspect
23from unittest .mock import AsyncMock
34
45import pytest
@@ -87,11 +88,11 @@ async def test_run_async_function_with_semaphore(use_semaphore_tuple: tuple[bool
8788 """Test that run_async_function_with_semaphore correctly manages the semaphore."""
8889 use_semaphore = use_semaphore_tuple [0 ]
8990 mock_async_func = AsyncMock (return_value = 'result' )
90- mock_semaphore = AsyncMock () if use_semaphore else None
91+ mock_semaphore : AsyncMock | None = AsyncMock () if use_semaphore else None
9192
9293 # If a semaphore is provided, it should be used via async with
9394 if use_semaphore :
94- assert mock_semaphore is not None # Type narrowing for mypy
95+ assert mock_semaphore is not None
9596 mock_semaphore .__aenter__ = AsyncMock ()
9697 mock_semaphore .__aexit__ = AsyncMock ()
9798
@@ -102,7 +103,7 @@ async def test_run_async_function_with_semaphore(use_semaphore_tuple: tuple[bool
102103
103104 # If a semaphore was provided, verify it was acquired and released
104105 if use_semaphore :
105- assert mock_semaphore is not None # Type narrowing for mypy
106+ assert mock_semaphore is not None
106107 mock_semaphore .__aenter__ .assert_called_once ()
107108 mock_semaphore .__aexit__ .assert_called_once ()
108109
@@ -131,30 +132,20 @@ async def call(self, *args: object, **kwargs: object) -> str:
131132 async def test_semaphore_limits_concurrency (self , concurrency : int , num_tasks : int ) -> None :
132133 """Test that AsyncResource correctly limits concurrency using its semaphore."""
133134 resource = self .TestResource (concurrency = concurrency )
134- resource .call_mock .return_value = 'test_result'
135-
136- # Track the number of concurrent executions
137135 max_concurrent = 0
138136 current_concurrent = 0
139- original_aenter = resource .semaphore .__aenter__
140137
141- async def tracking_aenter ( self : asyncio . Semaphore ) -> asyncio . Semaphore :
138+ async def tracked_call ( * _args : object , ** _kwargs : object ) -> str :
142139 nonlocal current_concurrent , max_concurrent
143- await original_aenter ()
144140 current_concurrent += 1
145141 max_concurrent = max (max_concurrent , current_concurrent )
146- return self
147-
148- original_aexit = resource .semaphore .__aexit__
149-
150- async def tracking_aexit (_self : asyncio .Semaphore , * args : object ) -> object :
151- nonlocal current_concurrent
152- current_concurrent -= 1
153- return await original_aexit (* args )
142+ try :
143+ await asyncio .sleep (0.01 )
144+ return 'test_result'
145+ finally :
146+ current_concurrent -= 1
154147
155- # Replace the enter and exit methods to track concurrency
156- resource .semaphore .__aenter__ = tracking_aenter .__get__ (resource .semaphore )
157- resource .semaphore .__aexit__ = tracking_aexit .__get__ (resource .semaphore )
148+ resource .call_mock .side_effect = tracked_call
158149
159150 # Create and gather multiple tasks
160151 tasks = [resource .task (f'arg{ i } ' ) for i in range (num_tasks )]
@@ -174,6 +165,4 @@ async def tracking_aexit(_self: asyncio.Semaphore, *args: object) -> object:
174165 @pytest .mark .asyncio
175166 async def test_abstract_call_method (self ) -> None :
176167 """Test that AsyncResource.call is abstract and must be implemented."""
177- # We can't instantiate AsyncResource directly because it's abstract
178- with pytest .raises (TypeError , match = r'abstract method' ):
179- AsyncResource ()
168+ assert inspect .isabstract (AsyncResource )
0 commit comments