Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 47 additions & 69 deletions src/cachekit/decorators/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
)
from ..key_generator import CacheKeyGenerator
from ..l1_cache import get_l1_cache
from ..object_cache import ObjectCache
from ..reliability import CircuitBreakerConfig

# Config import removed - using direct DecoratorConfig integration
Expand Down Expand Up @@ -479,6 +480,10 @@ def create_cache_wrapper(
# FIX: Initialize L1 cache if enabled
_l1_cache = get_l1_cache(namespace or "default") if l1_enabled else None

# L1-only mode: use ObjectCache for raw Python object storage (no serialization).
# This preserves types (tuples, sets, frozensets) that MessagePack would degrade.
_object_cache: ObjectCache | None = ObjectCache(max_entries=256) if _l1_only_mode else None

# Create per-function statistics tracker with lazy session ID generation
# Session ID format: "{process_uuid}:{module}.{function_name}"
# Generated lazily on first use or regenerated after cache_clear()
Expand Down Expand Up @@ -569,41 +574,22 @@ def sync_wrapper(*args: Any, **kwargs: Any) -> Any: # noqa: PLR0912
reset_current_function_stats(token)
return func(*args, **kwargs)

# L1-ONLY MODE: Skip backend initialization entirely
# This is the fix for the sentinel problem: when backend=None is explicitly passed,
# we should NOT try to get a backend from the provider
if _l1_only_mode:
# L1-only mode: Check L1 cache, execute function on miss, store in L1
if _l1_cache and cache_key:
l1_found, l1_bytes = _l1_cache.get(cache_key)
if l1_found and l1_bytes:
# L1 cache hit
try:
# Pass cache_key for AAD verification (required for encryption)
l1_value = operation_handler.serialization_handler.deserialize_data(l1_bytes, cache_key=cache_key)
_stats.record_l1_hit()
reset_current_function_stats(token)
return l1_value
except Exception:
# L1 deserialization failed - invalidate and continue
_l1_cache.invalidate(cache_key)
# L1-ONLY MODE: Store raw Python objects (no serialization).
# Preserves types (tuples, sets, frozensets) that MessagePack would degrade.
if _l1_only_mode and _object_cache:
found, cached_value = _object_cache.get(cache_key)
if found:
_stats.record_l1_hit()
features.clear_correlation_id()
reset_current_function_stats(token)
return cached_value

# L1 cache miss - execute function and store in L1
# Cache miss - execute function and store raw result
_stats.record_miss()
try:
result = func(*args, **kwargs)
# Serialize and store in L1
try:
# Pass cache_key for AAD binding (required for encryption)
serialized_bytes = operation_handler.serialization_handler.serialize_data(
result, args, kwargs, cache_key=cache_key
)
if _l1_cache and cache_key and serialized_bytes:
_l1_cache.put(cache_key, serialized_bytes, redis_ttl=ttl)
_cached_keys.add(cache_key)
except Exception as e:
# Serialization/storage failed but function succeeded - log and return result
logger().debug(f"L1-only mode: serialization/storage failed for {cache_key}: {e}")
_object_cache.put(cache_key, result, ttl=ttl if ttl is not None else 31536000)
_cached_keys.add(cache_key)
return result
finally:
features.clear_correlation_id()
Expand Down Expand Up @@ -913,39 +899,20 @@ async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
)
return await func(*args, **kwargs)

# L1-ONLY MODE: Skip backend initialization entirely
# This is the fix for the sentinel problem: when backend=None is explicitly passed,
# we should NOT try to get a backend from the provider
if _l1_only_mode:
# L1-only mode: Check L1 cache, execute function on miss, store in L1
if _l1_cache and cache_key:
l1_found, l1_bytes = _l1_cache.get(cache_key)
if l1_found and l1_bytes:
# L1 cache hit
try:
# Pass cache_key for AAD verification (required for encryption)
l1_value = operation_handler.serialization_handler.deserialize_data(l1_bytes, cache_key=cache_key)
_stats.record_l1_hit()
return l1_value
except Exception:
# L1 deserialization failed - invalidate and continue
_l1_cache.invalidate(cache_key)

# L1 cache miss - execute function and store in L1
# L1-ONLY MODE: Store raw Python objects (no serialization).
# Preserves types (tuples, sets, frozensets) that MessagePack would degrade.
if _l1_only_mode and _object_cache:
found, cached_value = _object_cache.get(cache_key)
if found:
_stats.record_l1_hit()
features.clear_correlation_id()
return cached_value

# Cache miss - execute function and store raw result
_stats.record_miss()
result = await func(*args, **kwargs)
# Serialize and store in L1
try:
# Pass cache_key for AAD binding (required for encryption)
serialized_bytes = operation_handler.serialization_handler.serialize_data(
result, args, kwargs, cache_key=cache_key
)
if _l1_cache and cache_key and serialized_bytes:
_l1_cache.put(cache_key, serialized_bytes, redis_ttl=ttl)
_cached_keys.add(cache_key)
except Exception as e:
# Serialization/storage failed but function succeeded - log and return result
logger().debug(f"L1-only mode: serialization/storage failed for {cache_key}: {e}")
_object_cache.put(cache_key, result, ttl=ttl if ttl is not None else 31536000)
_cached_keys.add(cache_key)
return result

# L1+L2 MODE: Original behavior with backend initialization
Expand Down Expand Up @@ -1314,7 +1281,9 @@ def invalidate_cache(*args: Any, **kwargs: Any) -> None:
# Snapshot prevents RuntimeError if another thread adds during iteration
keys_snapshot = set(_cached_keys)
for key in keys_snapshot:
if _l1_cache:
if _object_cache:
_object_cache.delete(key)
elif _l1_cache:
_l1_cache.invalidate(key)
if _backend and not _l1_only_mode:
invalidator.set_backend(_backend)
Expand All @@ -1329,7 +1298,9 @@ def invalidate_cache(*args: Any, **kwargs: Any) -> None:
# Single-key invalidation (specific args provided, or zero-param function)
cache_key = operation_handler.get_cache_key(func, args, kwargs, namespace, integrity_checking)

if _l1_cache and cache_key:
if _object_cache and cache_key:
_object_cache.delete(cache_key)
elif _l1_cache and cache_key:
_l1_cache.invalidate(cache_key)
_cached_keys.discard(cache_key)

Expand All @@ -1355,7 +1326,9 @@ async def ainvalidate_cache(*args: Any, **kwargs: Any) -> None:
if not args and not kwargs and _func_has_params:
keys_snapshot = set(_cached_keys)
for key in keys_snapshot:
if _l1_cache:
if _object_cache:
_object_cache.delete(key)
elif _l1_cache:
_l1_cache.invalidate(key)
if _backend and not _l1_only_mode:
invalidator.set_backend(_backend)
Expand All @@ -1370,7 +1343,9 @@ async def ainvalidate_cache(*args: Any, **kwargs: Any) -> None:
# Single-key invalidation (specific args provided, or zero-param function)
cache_key = operation_handler.get_cache_key(func, args, kwargs, namespace, integrity_checking)

if _l1_cache and cache_key:
if _object_cache and cache_key:
_object_cache.delete(cache_key)
elif _l1_cache and cache_key:
_l1_cache.invalidate(cache_key)
_cached_keys.discard(cache_key)

Expand Down Expand Up @@ -1432,9 +1407,12 @@ def thread_func(x): ...
def cache_clear() -> None:
"""Clear cache statistics and invalidate all cached entries."""
_stats.clear()
# Also invalidate actual cache entries
if inspect.iscoroutinefunction(func):
raise TypeError("cache_clear() cannot clear cache for async functions. Use 'await fn.ainvalidate_cache()' instead.")
# In L1-only mode, invalidation is synchronous (no backend I/O needed)
# so cache_clear() works for both sync and async functions.
if inspect.iscoroutinefunction(func) and not _l1_only_mode:
raise TypeError(
"cache_clear() cannot clear cache for async functions with a backend. Use 'await fn.ainvalidate_cache()' instead."
)
invalidate_cache()

if inspect.iscoroutinefunction(func):
Expand Down
151 changes: 89 additions & 62 deletions tests/unit/test_cache_clear_async.py
Original file line number Diff line number Diff line change
@@ -1,101 +1,128 @@
"""Bug #49: cache_clear() broken for async-decorated functions.
"""Bug #49/#76: cache_clear() behavior for async-decorated functions.

Symptom: Calling cache_clear() on an async-decorated function creates
an unawaited coroutine (ainvalidate_cache()) that gets GC'd silently.
The cache is never cleared and Python emits RuntimeWarning.
History:
- #49: cache_clear() on async created an unawaited coroutine (fixed by raising TypeError)
- #76: TypeError is unnecessary in L1-only mode (no backend I/O needed)

Fix: cache_clear() is sync -- it cannot await. Raise TypeError telling
the user to use 'await fn.ainvalidate_cache()' instead.
Current behavior:
- L1-only mode (backend=None): cache_clear() works synchronously for both sync and async
- With backend: cache_clear() raises TypeError for async (must use ainvalidate_cache)
"""

import asyncio
from unittest.mock import patch
from unittest.mock import MagicMock, patch

import pytest

from cachekit.decorators import cache


class TestCacheClearAsyncBug:
"""Regression tests for GitHub Issue #49."""
class TestCacheClearAsyncL1Only:
"""cache_clear() works synchronously in L1-only mode (#76 fix)."""

def test_cache_clear_raises_type_error_for_async_function(self):
"""cache_clear() on an async function must raise TypeError.
def test_cache_clear_works_for_async_l1_only(self):
"""cache_clear() on async + backend=None must NOT raise.

BUG REPRODUCTION: Previously, cache_clear() called ainvalidate_cache()
without awaiting, creating a dangling coroutine that was silently GC'd.
In L1-only mode, invalidation is synchronous (no backend I/O),
so cache_clear() can safely clear without awaiting.
"""
with patch("cachekit.decorators.wrapper.get_backend_provider") as mock_provider:
mock_provider.return_value.get_backend.side_effect = RuntimeError("Should not be called!")

@cache(backend=None)
async def async_func(x: int) -> int:
return x * 2
@cache(backend=None)
async def async_func(x: int) -> int:
return x * 2

with pytest.raises(TypeError, match="cache_clear\\(\\) cannot clear cache for async functions"):
async_func.cache_clear()
# Should NOT raise TypeError
async_func.cache_clear()

def test_cache_clear_error_message_suggests_ainvalidate(self):
"""TypeError message must tell the user what to use instead."""
with patch("cachekit.decorators.wrapper.get_backend_provider") as mock_provider:
mock_provider.return_value.get_backend.side_effect = RuntimeError("Should not be called!")
def test_cache_clear_actually_clears_async_l1_only(self):
"""cache_clear() must actually clear cached entries for async L1-only."""
call_count = 0

@cache(backend=None)
async def async_func(x: int) -> int:
return x * 2
@cache(backend=None, ttl=300, namespace="test_clear_async_l1")
async def async_func(x: int) -> int:
nonlocal call_count
call_count += 1
return x * 2

with pytest.raises(TypeError, match="await fn.ainvalidate_cache\\(\\)"):
async_func.cache_clear()
async def run():
nonlocal call_count

await async_func(5)
assert call_count == 1

await async_func(5) # cached
assert call_count == 1

async_func.cache_clear()

await async_func(5) # recomputed
assert call_count == 2

asyncio.run(run())

def test_cache_clear_does_not_raise_for_sync_function(self):
"""Sync cache_clear() must NOT raise TypeError -- no regression.

This test verifies that the async fix does not break sync cache_clear().
We only verify it runs without raising, not full invalidation behavior
(which depends on key generation with no args -- a separate concern).
class TestCacheClearAsyncWithBackend:
"""cache_clear() raises TypeError when a backend is involved."""

def test_cache_clear_raises_type_error_with_backend(self):
"""cache_clear() on async with a real backend must raise TypeError.

When a backend exists, invalidation requires async I/O (delete from Redis).
cache_clear() is sync, so it cannot safely invalidate L2.
"""
mock_backend = MagicMock()

@cache(backend=mock_backend)
async def async_func(x: int) -> int:
return x * 2

with pytest.raises(TypeError, match="cache_clear\\(\\) cannot clear cache for async functions with a backend"):
async_func.cache_clear()


class TestCacheClearSync:
"""cache_clear() always works for sync functions (no regression)."""

def test_cache_clear_does_not_raise_for_sync_function(self):
"""Sync cache_clear() must NOT raise TypeError."""
with patch("cachekit.decorators.wrapper.get_backend_provider") as mock_provider:
mock_provider.return_value.get_backend.side_effect = RuntimeError("Should not be called!")

@cache(backend=None)
def sync_func(x: int) -> int:
return x * 2

# cache_clear() should NOT raise TypeError for sync functions
sync_func.cache_clear() # No exception = pass

def test_async_ainvalidate_cache_still_works(self):
"""The recommended path (ainvalidate_cache) must still work for async."""
with patch("cachekit.decorators.wrapper.get_backend_provider") as mock_provider:
mock_provider.return_value.get_backend.side_effect = RuntimeError("Should not be called!")

call_count = 0
class TestAsyncInvalidateCacheL1Only:
"""ainvalidate_cache() works for async functions in L1-only mode."""

@cache(backend=None)
async def async_func(x: int) -> int:
nonlocal call_count
call_count += 1
return x * 2
def test_async_ainvalidate_cache_l1_only(self):
"""ainvalidate_cache() clears entries for async L1-only functions."""
call_count = 0

@cache(backend=None, namespace="test_ainvalidate_works")
async def async_func(x: int) -> int:
nonlocal call_count
call_count += 1
return x * 2

async def run_test():
nonlocal call_count
async def run_test():
nonlocal call_count

# Populate cache
result1 = await async_func(5)
assert result1 == 10
assert call_count == 1
result1 = await async_func(5)
assert result1 == 10
assert call_count == 1

# Cached hit
result2 = await async_func(5)
assert result2 == 10
assert call_count == 1
result2 = await async_func(5)
assert result2 == 10
assert call_count == 1

# Use the correct async invalidation path
await async_func.ainvalidate_cache(5)
await async_func.ainvalidate_cache(5)

# After invalidation, function should re-execute
result3 = await async_func(5)
assert result3 == 10
assert call_count == 2
result3 = await async_func(5)
assert result3 == 10
assert call_count == 2

asyncio.run(run_test())
asyncio.run(run_test())
Loading