Skip to content

Commit 4ef36c8

Browse files
authored
fix(execute_class): add async lock to prevent double deploy (#274)
* fix(execute_class): add async lock to _ensure_initialized to prevent double deploy

  Without a lock, concurrent calls to _ensure_initialized both pass the `_initialized` check and both call get_or_deploy_resource, wasting resources and orphaning one stub. Uses double-checked locking: fast-path check before lock acquisition, second check inside the lock.

  Closes AE-2370

* fix(execute_class): address PR review feedback
  - Replace misleading carried-over comment with accurate description
  - Add inline comments explaining double-checked locking pattern
  - Add failure-path test: deploy exception releases lock, allows retry

* fix(tests): address PR #274 review feedback
  - Rename TestNEW1_EnsureInitializedRace to TestEnsureInitializedRace
  - Update docstring to describe regression guard, not pre-fix state
  - Replace asyncio.sleep(0.05) with explicit deploy_entered Event
1 parent 65aa1c9 commit 4ef36c8

3 files changed

Lines changed: 161 additions & 10 deletions

File tree

src/runpod_flash/execute_class.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
prevent memory leaks through LRU eviction.
77
"""
88

9+
import asyncio
910
import hashlib
1011
import inspect
1112
import logging
@@ -216,6 +217,7 @@ def __init__(self, *args, **kwargs):
216217
f"{cls.__name__}_{uuid.uuid4().hex[:UUID_FALLBACK_LENGTH]}"
217218
)
218219
self._initialized = False
220+
self._init_lock = asyncio.Lock()
219221

220222
# Generate cache key and get class code
221223
self._cache_key = get_class_cache_key(cls, args, kwargs)
@@ -224,20 +226,23 @@ def __init__(self, *args, **kwargs):
224226
)
225227

226228
async def _ensure_initialized(self):
227-
"""Ensure the remote instance is created."""
229+
"""Ensure the remote instance is created exactly once, even under concurrent calls."""
230+
# Fast path: already initialized, no lock needed.
228231
if self._initialized:
229232
return
230233

231-
# Get remote resource
232-
resource_manager = ResourceManager()
233-
remote_resource = await resource_manager.get_or_deploy_resource(
234-
self._resource_config
235-
)
236-
self._stub = stub_resource(remote_resource)
234+
# Slow path: acquire lock and re-check to prevent double deployment
235+
# when multiple coroutines race past the fast-path check.
236+
async with self._init_lock:
237+
if self._initialized:
238+
return
237239

238-
# Create the remote instance by calling a method (which will trigger instance creation)
239-
# We'll do this on first method call
240-
self._initialized = True
240+
resource_manager = ResourceManager()
241+
remote_resource = await resource_manager.get_or_deploy_resource(
242+
self._resource_config
243+
)
244+
self._stub = stub_resource(remote_resource)
245+
self._initialized = True
241246

242247
def __getattr__(self, name):
243248
"""Dynamically create method proxies for all class methods."""

tests/bug_probes/__init__.py

Whitespace-only changes.
Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
"""Bug probe tests for execute_class.py race conditions."""
2+
3+
import asyncio
4+
from unittest.mock import AsyncMock, MagicMock, patch
5+
6+
import pytest
7+
8+
9+
class TestEnsureInitializedRace:
10+
"""AE-2370: concurrent _ensure_initialized calls must not double-deploy.
11+
12+
Regression guard: without the async lock added in this PR, two concurrent
13+
    calls both get past the `if self._initialized` fast-path check and both call
14+
get_or_deploy_resource, causing a double deploy and orphaning one stub.
15+
"""
16+
17+
@pytest.fixture
18+
def wrapper_instance(self):
19+
"""Create a RemoteClassWrapper instance with mocked dependencies."""
20+
21+
class FakeModel:
22+
__name__ = "FakeModel"
23+
24+
def predict(self, x):
25+
return x
26+
27+
resource_config = MagicMock()
28+
29+
with (
30+
patch("runpod_flash.execute_class.get_class_cache_key", return_value="key"),
31+
patch(
32+
"runpod_flash.execute_class.get_or_cache_class_data",
33+
return_value="code",
34+
),
35+
):
36+
from runpod_flash.execute_class import create_remote_class
37+
38+
wrapper_cls = create_remote_class(
39+
cls=FakeModel,
40+
resource_config=resource_config,
41+
dependencies=None,
42+
system_dependencies=None,
43+
accelerate_downloads=False,
44+
)
45+
instance = wrapper_cls()
46+
47+
return instance
48+
49+
@pytest.mark.asyncio
50+
async def test_concurrent_calls_deploy_only_once(self, wrapper_instance):
51+
"""Two concurrent _ensure_initialized calls must call get_or_deploy_resource exactly once."""
52+
deploy_call_count = 0
53+
deploy_entered = asyncio.Event()
54+
gate = asyncio.Event()
55+
56+
async def slow_deploy(config):
57+
nonlocal deploy_call_count
58+
deploy_call_count += 1
59+
deploy_entered.set()
60+
await gate.wait()
61+
return MagicMock()
62+
63+
with (
64+
patch("runpod_flash.execute_class.ResourceManager") as mock_rm_cls,
65+
patch("runpod_flash.execute_class.stub_resource", return_value=MagicMock()),
66+
):
67+
mock_rm = MagicMock()
68+
mock_rm.get_or_deploy_resource = slow_deploy
69+
mock_rm_cls.return_value = mock_rm
70+
71+
task1 = asyncio.create_task(wrapper_instance._ensure_initialized())
72+
task2 = asyncio.create_task(wrapper_instance._ensure_initialized())
73+
74+
await deploy_entered.wait()
75+
gate.set()
76+
77+
await asyncio.gather(task1, task2)
78+
79+
assert deploy_call_count == 1, (
80+
f"get_or_deploy_resource called {deploy_call_count} times, expected 1. "
81+
"Race condition: concurrent calls both passed the initialized check."
82+
)
83+
84+
@pytest.mark.asyncio
85+
async def test_initialized_flag_set_after_deploy(self, wrapper_instance):
86+
"""After _ensure_initialized completes, _initialized must be True."""
87+
with (
88+
patch("runpod_flash.execute_class.ResourceManager") as mock_rm_cls,
89+
patch("runpod_flash.execute_class.stub_resource", return_value=MagicMock()),
90+
):
91+
mock_rm = MagicMock()
92+
mock_rm.get_or_deploy_resource = AsyncMock(return_value=MagicMock())
93+
mock_rm_cls.return_value = mock_rm
94+
95+
await wrapper_instance._ensure_initialized()
96+
97+
assert wrapper_instance._initialized is True
98+
99+
@pytest.mark.asyncio
100+
async def test_second_call_skips_deploy(self, wrapper_instance):
101+
"""Once initialized, subsequent calls must not call get_or_deploy_resource."""
102+
with (
103+
patch("runpod_flash.execute_class.ResourceManager") as mock_rm_cls,
104+
patch("runpod_flash.execute_class.stub_resource", return_value=MagicMock()),
105+
):
106+
mock_rm = MagicMock()
107+
mock_rm.get_or_deploy_resource = AsyncMock(return_value=MagicMock())
108+
mock_rm_cls.return_value = mock_rm
109+
110+
await wrapper_instance._ensure_initialized()
111+
mock_rm.get_or_deploy_resource.assert_awaited_once()
112+
113+
await wrapper_instance._ensure_initialized()
114+
mock_rm.get_or_deploy_resource.assert_awaited_once()
115+
116+
@pytest.mark.asyncio
117+
async def test_deploy_failure_releases_lock_and_allows_retry(
118+
self, wrapper_instance
119+
):
120+
"""If deploy fails, the lock must be released and a subsequent call must retry."""
121+
call_count = 0
122+
123+
async def failing_then_succeeding_deploy(config):
124+
nonlocal call_count
125+
call_count += 1
126+
if call_count == 1:
127+
raise ConnectionError("transient failure")
128+
return MagicMock()
129+
130+
with (
131+
patch("runpod_flash.execute_class.ResourceManager") as mock_rm_cls,
132+
patch("runpod_flash.execute_class.stub_resource", return_value=MagicMock()),
133+
):
134+
mock_rm = MagicMock()
135+
mock_rm.get_or_deploy_resource = failing_then_succeeding_deploy
136+
mock_rm_cls.return_value = mock_rm
137+
138+
with pytest.raises(ConnectionError, match="transient failure"):
139+
await wrapper_instance._ensure_initialized()
140+
141+
assert not wrapper_instance._initialized
142+
143+
# Retry should succeed
144+
await wrapper_instance._ensure_initialized()
145+
assert wrapper_instance._initialized
146+
assert call_count == 2

0 commit comments

Comments
 (0)