From a728cb7b514e20a291cc7d03a9184c94b1d8cdbb Mon Sep 17 00:00:00 2001 From: sfw Date: Fri, 20 Mar 2026 15:47:53 -0600 Subject: [PATCH] Cache surplus practice blocks for instant follow-up delivery When the LLM generates multiple practice_problem blocks, only the first is delivered to the learner. Extra blocks are cached and served instantly on the next continue request, eliminating wait time while the next eager generation runs in the background. Surplus entries piggyback on the existing predictive warming invalidation via is_predictive_warm flag. Co-Authored-By: Claude Opus 4.6 --- src/dibble/bootstrap.py | 6 + src/dibble/services/generation_engine.py | 74 ++++ src/dibble/services/surplus_practice_cache.py | 226 +++++++++++++ tests/test_surplus_practice_cache.py | 317 ++++++++++++++++++ 4 files changed, 623 insertions(+) create mode 100644 src/dibble/services/surplus_practice_cache.py create mode 100644 tests/test_surplus_practice_cache.py diff --git a/src/dibble/bootstrap.py b/src/dibble/bootstrap.py index d09ed61..05ae688 100644 --- a/src/dibble/bootstrap.py +++ b/src/dibble/bootstrap.py @@ -29,6 +29,7 @@ from dibble.services.outcome_store import SQLiteOutcomeStore from dibble.services.strand_store import SQLiteStrandStore from dibble.services.generation_engine import GenerationEngine +from dibble.services.surplus_practice_cache import SurplusPracticeCache from dibble.services.generation_mode_calibration import GenerationModeCalibrator from dibble.services.generated_content_store import SQLiteGeneratedContentStore from dibble.services.knowledge_component_store import SQLiteKnowledgeComponentStore @@ -258,12 +259,17 @@ def build_application_services( strategy_signal_service=learner_strategy_signal_service, within_session_adaptation_service=within_session_adaptation_service, ) + surplus_practice_cache = SurplusPracticeCache( + generated_content_store=generated_content_store, + cache_ttl_seconds=settings.generation_cache_ttl_seconds, + ) generation_engine = GenerationEngine( retriever=plugins.retriever, router=router_plugin, provider=plugins.provider, validator=plugins.validator, generated_content_store=generated_content_store, + surplus_practice_cache=surplus_practice_cache, cache_ttl_seconds=settings.generation_cache_ttl_seconds, ) misconception_remediation_outcome_signal_service = ( diff --git a/src/dibble/services/generation_engine.py b/src/dibble/services/generation_engine.py index 00bda1b..f34ad82 100644 --- a/src/dibble/services/generation_engine.py +++ b/src/dibble/services/generation_engine.py @@ -31,6 +31,7 @@ from dibble.services.generation_modes import build_generation_mode_plan from dibble.services.protocols import GeneratedContentStore from dibble.services.runtime_telemetry import log_runtime_event +from dibble.services.surplus_practice_cache import SurplusPracticeCache logger = logging.getLogger(__name__) @@ -44,6 +45,7 @@ def __init__( validator: ValidatorPlugin, moderation_service: ContentModerationService | None = None, generated_content_store: GeneratedContentStore | None = None, + surplus_practice_cache: SurplusPracticeCache | None = None, cache_ttl_seconds: int = 3600, time_provider=monotonic, ) -> None: @@ -53,6 +55,7 @@ def __init__( self.validator = validator self.moderation_service = moderation_service or ContentModerationService() self.generated_content_store = generated_content_store + self.surplus_practice_cache = surplus_practice_cache self.cache_ttl_seconds = max(0, cache_ttl_seconds) self.time_provider = time_provider @@ -89,6 +92,10 @@ def generate( route=route.model_dump(mode="json"), grounding=[item.model_dump(mode="json") for item in grounding], ) + surplus = self._pop_surplus(profile, request) + if surplus is not None: + return surplus.response + cache_key = self._cache_key(profile, request, route, grounding) cached = self._get_cached_content(cache_key=cache_key) if cached is not None: @@ -104,6 +111,7 @@ def generate( return cached.response started_at = self.time_provider() + surplus_blocks: list[GeneratedBlock] = [] request_moderation = self.moderation_service.moderate_request(request) if request_moderation.status == "flagged": blocks = self._moderation_fallback_blocks( @@ -125,6 +133,7 @@ def generate( else: blocks = self.provider.generate(profile, request, route, grounding) blocks = normalize_generated_blocks(blocks) + blocks, surplus_blocks = self._split_surplus(blocks) moderation = self.moderation_service.moderate_blocks(blocks) if moderation.status == "flagged": original_blocks = len(blocks) @@ -158,6 +167,7 @@ def generate( ), ) self._store_generated_content(cache_key=cache_key, content=content) + self._cache_surplus(surplus_blocks, blocks, content, profile, request) log_runtime_event( logger, logging.DEBUG, @@ -178,6 +188,31 @@ def stream_generate( ) -> Iterator[GenerationStreamEvent]: grounding = self._safe_retrieve(profile, request) route = self.router.route(profile, request) + + surplus = self._pop_surplus(profile, request) + if surplus is not None: + yield GenerationStreamEvent( + event="start", + student_id=profile.student_id, + route=surplus.response.route, + grounding=surplus.response.grounding, + ) + for chunk in self._stream_cached_blocks(surplus.response.blocks): + yield GenerationStreamEvent( + event="delta", + student_id=profile.student_id, + chunk=chunk, + ) + yield GenerationStreamEvent( + event="complete", + student_id=profile.student_id, + route=surplus.response.route, + grounding=surplus.response.grounding, + validation_issues=surplus.response.validation_issues, + response=surplus.response, + ) + return + cache_key = self._cache_key(profile, request, route, grounding) cached = self._get_cached_content(cache_key=cache_key) if cached is not None: @@ -213,6 +248,7 @@ def stream_generate( return started_at = self.time_provider() + surplus_blocks: list[GeneratedBlock] = [] request_moderation = self.moderation_service.moderate_request(request) if request_moderation.status == "flagged": blocks = self._moderation_fallback_blocks( @@ -264,6 +300,7 @@ def stream_generate( blocks = normalize_generated_blocks( [block_buffers[index] for index in sorted(block_buffers)] ) + blocks, surplus_blocks = self._split_surplus(blocks) moderation = self.moderation_service.moderate_blocks(blocks) if moderation.status == "flagged": original_blocks = len(blocks) @@ -308,6 +345,7 @@ def stream_generate( ), ) self._store_generated_content(cache_key=cache_key, content=content) + self._cache_surplus(surplus_blocks, blocks, content, profile, request) log_runtime_event( logger, logging.DEBUG, @@ -330,6 +368,42 @@ def stream_generate( response=content.response, ) + def _pop_surplus( + self, profile: LearnerProfile, request: GenerationRequest + ) -> GeneratedContent | None: + if self.surplus_practice_cache is None: + return None + return self.surplus_practice_cache.pop_surplus( + student_id=profile.student_id, + learning_session_id=request.learning_session_id, + ) + + def _split_surplus( + self, blocks: list[GeneratedBlock] + ) -> tuple[list[GeneratedBlock], list[GeneratedBlock]]: + if self.surplus_practice_cache is None: + return blocks, [] + return SurplusPracticeCache.split_practice_blocks(blocks) + + def _cache_surplus( + self, + surplus_blocks: list[GeneratedBlock], + delivery_blocks: list[GeneratedBlock], + content: GeneratedContent, + profile: LearnerProfile, + request: GenerationRequest, + ) -> None: + if not surplus_blocks or self.surplus_practice_cache is None: + return + non_practice = [b for b in delivery_blocks if b.kind != "practice_problem"] + self.surplus_practice_cache.cache_surplus( + surplus_blocks=surplus_blocks, + non_practice_blocks=non_practice, + parent_content=content, + profile=profile, + request=request, + ) + def _build_response( self, profile: LearnerProfile, diff --git a/src/dibble/services/surplus_practice_cache.py b/src/dibble/services/surplus_practice_cache.py new file mode 100644 index 0000000..abc90d8 --- /dev/null +++ b/src/dibble/services/surplus_practice_cache.py @@ -0,0 +1,226 @@ +"""Split multi-block practice responses and cache surplus questions. + +When the LLM generates 2-3 practice_problem blocks in a single response, +only the first is delivered to the learner. The remaining blocks are stored +as individual ``GeneratedContent`` entries so they can be served instantly +on the next continue request, buying time before the next LLM generation. +""" + +from __future__ import annotations + +import logging +from datetime import datetime, timedelta, timezone +from uuid import UUID, uuid4 + +from dibble.models.generation import ( + GeneratedBlock, + GeneratedContent, + GenerationRequest, +) +from dibble.models.profile import LearnerProfile +from dibble.services.protocols import GeneratedContentStore + +logger = logging.getLogger(__name__) + + +class SurplusPracticeCache: + """Manages splitting and caching of surplus practice blocks.""" + + def __init__( + self, + generated_content_store: GeneratedContentStore, + cache_ttl_seconds: int = 3600, + ) -> None: + self.store = generated_content_store + self.cache_ttl_seconds = max(0, cache_ttl_seconds) + + # ------------------------------------------------------------------ + # Public API + # ------------------------------------------------------------------ + + @staticmethod + def split_practice_blocks( + blocks: list[GeneratedBlock], + ) -> tuple[list[GeneratedBlock], list[GeneratedBlock]]: + """Separate *blocks* into delivery blocks and surplus practice blocks. + + Returns ``(delivery, surplus)`` where *delivery* contains all + non-practice blocks plus the **first** ``practice_problem`` block, + and *surplus* contains any remaining ``practice_problem`` blocks. + """ + non_practice: list[GeneratedBlock] = [] + practice: list[GeneratedBlock] = [] + for block in blocks: + if block.kind == "practice_problem": + practice.append(block) + else: + non_practice.append(block) + + if len(practice) <= 1: + return blocks, [] + + delivery = non_practice + practice[:1] + surplus = practice[1:] + return delivery, surplus + + def cache_surplus( + self, + *, + surplus_blocks: list[GeneratedBlock], + non_practice_blocks: list[GeneratedBlock], + parent_content: GeneratedContent, + profile: LearnerProfile, + request: GenerationRequest, + ) -> int: + """Store each surplus practice block as a separate cache entry. + + Each entry wraps the surplus block alongside the original + non-practice blocks (e.g. the summary) so the learner still sees + context when the surplus is served. + + Returns the number of surplus entries stored. + """ + if not surplus_blocks or self.cache_ttl_seconds <= 0: + return 0 + + stored = 0 + now = datetime.now(timezone.utc) + expires_at = now + timedelta(seconds=self.cache_ttl_seconds) + + for index, practice_block in enumerate(surplus_blocks): + cache_key = self._surplus_cache_key( + student_id=profile.student_id, + learning_session_id=request.learning_session_id, + sequence_index=index, + ) + blocks = list(non_practice_blocks) + [practice_block] + response = parent_content.response.model_copy( + update={ + "blocks": blocks, + "generation_id": str(uuid4()), + } + ) + request_context = dict(parent_content.request_context) + request_context["is_surplus_practice"] = True + request_context["is_predictive_warm"] = True + request_context["source_generation_id"] = parent_content.generation_id + request_context["surplus_sequence_index"] = index + + content = GeneratedContent( + generation_id=response.generation_id or str(uuid4()), + student_id=profile.student_id, + content_type=parent_content.content_type, + request_context=request_context, + response=response, + quality=parent_content.quality.model_copy( + update={"cache_hit": False} + ), + created_at=now, + expires_at=expires_at, + ) + self.store.upsert(cache_key=cache_key, content=content) + stored += 1 + + logger.debug( + "Cached %d surplus practice blocks for student %s (session %s)", + stored, + profile.student_id, + request.learning_session_id, + ) + return stored + + def pop_surplus( + self, + *, + student_id: UUID, + learning_session_id: str | None, + ) -> GeneratedContent | None: + """Retrieve and expire the next surplus practice block, if any.""" + cache_key = self._surplus_cache_key( + student_id=student_id, + learning_session_id=learning_session_id, + sequence_index=0, + ) + content = self.store.get_fresh(cache_key=cache_key) + if content is None: + return None + + # Expire the entry so it is not served again. + expired = content.model_copy( + update={"expires_at": datetime.now(timezone.utc)} + ) + self.store.refresh(content=expired) + + # Promote sequence_index=1 → 0 so the next pop finds it. + self._promote_surplus( + student_id=student_id, + learning_session_id=learning_session_id, + ) + + logger.debug( + "Popped surplus practice block %s for student %s", + content.generation_id, + student_id, + ) + return content + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + + def _promote_surplus( + self, + *, + student_id: UUID, + learning_session_id: str | None, + ) -> None: + """Shift surplus entries down by one so index 1 becomes index 0.""" + index = 1 + while True: + old_key = self._surplus_cache_key( + student_id=student_id, + learning_session_id=learning_session_id, + sequence_index=index, + ) + entry = self.store.get_fresh(cache_key=old_key) + if entry is None: + break + # Expire old slot. + self.store.refresh( + content=entry.model_copy( + update={"expires_at": datetime.now(timezone.utc)} + ) + ) + # Re-store at index - 1 with a fresh generation_id to avoid + # unique constraint conflicts. + new_key = self._surplus_cache_key( + student_id=student_id, + learning_session_id=learning_session_id, + sequence_index=index - 1, + ) + new_gen_id = str(uuid4()) + new_context = dict(entry.request_context) + new_context["surplus_sequence_index"] = index - 1 + new_response = entry.response.model_copy( + update={"generation_id": new_gen_id} + ) + promoted = entry.model_copy( + update={ + "generation_id": new_gen_id, + "request_context": new_context, + "response": new_response, + "expires_at": entry.expires_at, + } + ) + self.store.upsert(cache_key=new_key, content=promoted) + index += 1 + + @staticmethod + def _surplus_cache_key( + *, + student_id: UUID, + learning_session_id: str | None, + sequence_index: int, + ) -> str: + session = learning_session_id or "none" + return f"surplus:{student_id}:{session}:{sequence_index}" diff --git a/tests/test_surplus_practice_cache.py b/tests/test_surplus_practice_cache.py new file mode 100644 index 0000000..61d44a1 --- /dev/null +++ b/tests/test_surplus_practice_cache.py @@ -0,0 +1,317 @@ +"""Tests for the surplus practice block cache.""" + +from datetime import datetime, timedelta, timezone +from uuid import UUID, uuid4 + +import pytest + +from dibble.models.generation import ( + AdaptiveRouteDecision, + DeliveryMode, + GeneratedBlock, + GeneratedContent, + GenerationMetadata, + GenerationRequest, + GenerationResponse, + InterventionType, + MultipleChoiceInteraction, + MultipleChoiceOption, +) +from dibble.models.profile import LearnerProfile +from dibble.services.generated_content_store import SQLiteGeneratedContentStore +from dibble.services.surplus_practice_cache import SurplusPracticeCache +from dibble.services.sqlite_connection import create_connection +from dibble.storage import ensure_database + + +STUDENT_ID = UUID("aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa") + + +@pytest.fixture() +def store(tmp_path): + db_path = str(tmp_path / "surplus.db") + ensure_database(db_path) + conn = create_connection(db_path) + return SQLiteGeneratedContentStore(conn) + + +@pytest.fixture() +def cache(store): + return SurplusPracticeCache(store, cache_ttl_seconds=3600) + + +def _practice_block(title: str = "Q1") -> GeneratedBlock: + return GeneratedBlock( + kind="practice_problem", + title=title, + body="Solve this.", + interaction=MultipleChoiceInteraction( + prompt="What is 2+2?", + options=[ + MultipleChoiceOption(option_id="A", label="A", body="3"), + MultipleChoiceOption(option_id="B", label="B", body="4"), + ], + correct_option_id="B", + ), + ) + + +def _summary_block() -> GeneratedBlock: + return GeneratedBlock(kind="summary", title="Summary", body="Context here.") + + +def _parent_content( + blocks: list[GeneratedBlock], + student_id: UUID = STUDENT_ID, +) -> GeneratedContent: + route = AdaptiveRouteDecision( + intervention_type=InterventionType.reteach, + delivery_mode=DeliveryMode.generated, + scaffolding_level="medium", + reasons=["test"], + ) + metadata = GenerationMetadata( + quality_score=0.8, validation_passed=True, grounding_count=0 + ) + gen_id = str(uuid4()) + response = GenerationResponse( + student_id=student_id, + route=route, + blocks=blocks, + curriculum_context=["fractions"], + grounding=[], + safety_notes=[], + generation_id=gen_id, + generation_metadata=metadata, + ) + now = datetime.now(timezone.utc) + return GeneratedContent( + generation_id=gen_id, + student_id=student_id, + content_type="practice_problem", + request_context={ + "target_kc_ids": ["KC-1"], + "target_lo_ids": ["LO-1"], + "learning_session_id": "session-1", + }, + response=response, + quality=metadata, + created_at=now, + expires_at=now + timedelta(hours=1), + ) + + +def _request(student_id: UUID = STUDENT_ID) -> GenerationRequest: + return GenerationRequest( + student_id=student_id, + target_kc_ids=["KC-1"], + learning_session_id="session-1", + ) + + +def _profile(student_id: UUID = STUDENT_ID) -> LearnerProfile: + return LearnerProfile(student_id=student_id, grade_level="7") + + +# ------------------------------------------------------------------ +# split_practice_blocks +# ------------------------------------------------------------------ + + +class TestSplitPracticeBlocks: + def test_single_practice_unchanged(self): + blocks = [_summary_block(), _practice_block("Q1")] + delivery, surplus = SurplusPracticeCache.split_practice_blocks(blocks) + assert delivery == blocks + assert surplus == [] + + def test_two_practice_splits(self): + blocks = [_summary_block(), _practice_block("Q1"), _practice_block("Q2")] + delivery, surplus = SurplusPracticeCache.split_practice_blocks(blocks) + assert len(delivery) == 2 # summary + Q1 + assert delivery[0].kind == "summary" + assert delivery[1].title == "Q1" + assert len(surplus) == 1 + assert surplus[0].title == "Q2" + + def test_three_practice_splits(self): + blocks = [ + _summary_block(), + _practice_block("Q1"), + _practice_block("Q2"), + _practice_block("Q3"), + ] + delivery, surplus = SurplusPracticeCache.split_practice_blocks(blocks) + assert len(delivery) == 2 + assert len(surplus) == 2 + assert surplus[0].title == "Q2" + assert surplus[1].title == "Q3" + + def test_no_practice_unchanged(self): + blocks = [_summary_block(), GeneratedBlock(kind="instruction", title="I", body="text")] + delivery, surplus = SurplusPracticeCache.split_practice_blocks(blocks) + assert delivery == blocks + assert surplus == [] + + def test_only_practice_no_summary(self): + blocks = [_practice_block("Q1"), _practice_block("Q2")] + delivery, surplus = SurplusPracticeCache.split_practice_blocks(blocks) + assert len(delivery) == 1 + assert delivery[0].title == "Q1" + assert len(surplus) == 1 + + +# ------------------------------------------------------------------ +# cache_surplus + pop_surplus +# ------------------------------------------------------------------ + + +class TestCacheAndPop: + def test_cache_and_pop_returns_surplus(self, cache): + summary = _summary_block() + surplus_blocks = [_practice_block("Q2")] + parent = _parent_content([summary, _practice_block("Q1")]) + req = _request() + profile = _profile() + + cache.cache_surplus( + surplus_blocks=surplus_blocks, + non_practice_blocks=[summary], + parent_content=parent, + profile=profile, + request=req, + ) + + popped = cache.pop_surplus( + student_id=STUDENT_ID, + learning_session_id="session-1", + ) + assert popped is not None + assert popped.generation_id != parent.generation_id + practice_blocks = [b for b in popped.response.blocks if b.kind == "practice_problem"] + assert len(practice_blocks) == 1 + assert practice_blocks[0].title == "Q2" + + def test_pop_returns_none_when_empty(self, cache): + popped = cache.pop_surplus( + student_id=STUDENT_ID, + learning_session_id="session-1", + ) + assert popped is None + + def test_pop_consumes_entry(self, cache): + summary = _summary_block() + surplus_blocks = [_practice_block("Q2")] + parent = _parent_content([summary, _practice_block("Q1")]) + + cache.cache_surplus( + surplus_blocks=surplus_blocks, + non_practice_blocks=[summary], + parent_content=parent, + profile=_profile(), + request=_request(), + ) + + first = cache.pop_surplus(student_id=STUDENT_ID, learning_session_id="session-1") + assert first is not None + second = cache.pop_surplus(student_id=STUDENT_ID, learning_session_id="session-1") + assert second is None + + def test_multiple_surplus_served_in_order(self, cache): + summary = _summary_block() + surplus_blocks = [_practice_block("Q2"), _practice_block("Q3")] + parent = _parent_content([summary, _practice_block("Q1")]) + + cache.cache_surplus( + surplus_blocks=surplus_blocks, + non_practice_blocks=[summary], + parent_content=parent, + profile=_profile(), + request=_request(), + ) + + first = cache.pop_surplus(student_id=STUDENT_ID, learning_session_id="session-1") + assert first is not None + q2_blocks = [b for b in first.response.blocks if b.kind == "practice_problem"] + assert q2_blocks[0].title == "Q2" + + second = cache.pop_surplus(student_id=STUDENT_ID, learning_session_id="session-1") + assert second is not None + q3_blocks = [b for b in second.response.blocks if b.kind == "practice_problem"] + assert q3_blocks[0].title == "Q3" + + third = cache.pop_surplus(student_id=STUDENT_ID, learning_session_id="session-1") + assert third is None + + +# ------------------------------------------------------------------ +# Invalidation piggyback +# ------------------------------------------------------------------ + + +class TestInvalidation: + def test_surplus_has_predictive_warm_flag(self, cache): + summary = _summary_block() + surplus_blocks = [_practice_block("Q2")] + parent = _parent_content([summary, _practice_block("Q1")]) + + cache.cache_surplus( + surplus_blocks=surplus_blocks, + non_practice_blocks=[summary], + parent_content=parent, + profile=_profile(), + request=_request(), + ) + + popped = cache.pop_surplus(student_id=STUDENT_ID, learning_session_id="session-1") + assert popped is not None + assert popped.request_context.get("is_predictive_warm") is True + assert popped.request_context.get("is_surplus_practice") is True + + +# ------------------------------------------------------------------ +# Session isolation +# ------------------------------------------------------------------ + + +class TestSessionIsolation: + def test_different_session_not_served(self, cache): + summary = _summary_block() + surplus_blocks = [_practice_block("Q2")] + parent = _parent_content([summary, _practice_block("Q1")]) + + cache.cache_surplus( + surplus_blocks=surplus_blocks, + non_practice_blocks=[summary], + parent_content=parent, + profile=_profile(), + request=_request(), + ) + + popped = cache.pop_surplus( + student_id=STUDENT_ID, + learning_session_id="different-session", + ) + assert popped is None + + +# ------------------------------------------------------------------ +# TTL +# ------------------------------------------------------------------ + + +class TestTTL: + def test_zero_ttl_does_not_cache(self, store): + no_cache = SurplusPracticeCache(store, cache_ttl_seconds=0) + summary = _summary_block() + surplus_blocks = [_practice_block("Q2")] + parent = _parent_content([summary, _practice_block("Q1")]) + + stored = no_cache.cache_surplus( + surplus_blocks=surplus_blocks, + non_practice_blocks=[summary], + parent_content=parent, + profile=_profile(), + request=_request(), + ) + assert stored == 0