Skip to content

Commit 29f6faf

Browse files
shantanu patil
authored and committed
Merge branch 'worktree-agent-a36bc4ac'
2 parents 5a610fa + 86775aa commit 29f6faf

4 files changed

Lines changed: 312 additions & 11 deletions

File tree

api/context_budget.py

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
"""Context Budget Manager --- manages token allocation for LLM context assembly."""
2+
3+
import logging
4+
from typing import List, Dict, Optional, Any, Callable
5+
6+
logger = logging.getLogger(__name__)
7+
8+
# Known context window sizes by provider and model
9+
CONTEXT_WINDOWS = {
10+
"google": {
11+
"gemini-2.5-flash": 1_048_576,
12+
"gemini-2.5-flash-lite": 1_048_576,
13+
"gemini-2.5-pro": 1_048_576,
14+
"gemini-2.0-flash": 1_048_576,
15+
"gemini-1.5-pro": 2_097_152,
16+
"gemini-1.5-flash": 1_048_576,
17+
},
18+
"openai": {
19+
"gpt-4o": 128_000,
20+
"gpt-4o-mini": 128_000,
21+
"gpt-4.1": 1_000_000,
22+
"gpt-4.1-mini": 1_000_000,
23+
"o3-mini": 200_000,
24+
},
25+
"openrouter": {}, # varies by model
26+
"ollama": {}, # varies by model, default 32K
27+
"bedrock": {},
28+
"azure": {},
29+
"dashscope": {},
30+
}
31+
32+
DEFAULT_CONTEXT_WINDOW = 128_000
33+
DEFAULT_OUTPUT_RESERVE = 8_192
34+
35+
36+
class ContextBudgetManager:
37+
"""Manages token budget for LLM context assembly."""
38+
39+
def get_context_window(self, provider: str, model: str) -> int:
40+
"""Get the context window size for a provider/model combo."""
41+
provider_windows = CONTEXT_WINDOWS.get(provider, {})
42+
# Try exact match first
43+
if model in provider_windows:
44+
return provider_windows[model]
45+
# Try prefix match (e.g., "gemini-2.5-flash-preview" matches "gemini-2.5-flash")
46+
for known_model, window in provider_windows.items():
47+
if model.startswith(known_model):
48+
return window
49+
return DEFAULT_CONTEXT_WINDOW
50+
51+
def get_context_budget(self, provider: str, model: str,
52+
prompt_tokens: int,
53+
output_reserve: int = DEFAULT_OUTPUT_RESERVE) -> int:
54+
"""Calculate available tokens for RAG context."""
55+
window = self.get_context_window(provider, model)
56+
budget = window - prompt_tokens - output_reserve
57+
logger.info(f"Context budget: window={window}, prompt={prompt_tokens}, "
58+
f"reserve={output_reserve}, available={budget}")
59+
return max(budget, 0)
60+
61+
def get_dynamic_top_k(self, provider: str, model: str,
62+
avg_chunk_tokens: int = 600) -> int:
63+
"""Calculate how many chunks to retrieve based on context window."""
64+
window = self.get_context_window(provider, model)
65+
# Use at most 30% of context window for retrieved chunks
66+
chunk_budget = int(window * 0.3)
67+
top_k = max(20, min(200, chunk_budget // avg_chunk_tokens))
68+
logger.info(f"Dynamic top_k for {provider}/{model}: {top_k}")
69+
return top_k
70+
71+
def assemble_context(self, documents: List[Dict],
72+
budget_tokens: int,
73+
count_tokens_fn: Callable[[str], int]) -> str:
74+
"""Greedily pack documents into the token budget, highest relevance first.
75+
76+
Args:
77+
documents: List of dicts with 'content' and optionally 'score', 'file_path'
78+
budget_tokens: Maximum tokens for context
79+
count_tokens_fn: Function that counts tokens in a string
80+
81+
Returns:
82+
Assembled context string fitting within budget
83+
"""
84+
# Sort by relevance score (highest first)
85+
sorted_docs = sorted(documents, key=lambda d: d.get('score', 0), reverse=True)
86+
87+
assembled = []
88+
used_tokens = 0
89+
90+
for doc in sorted_docs:
91+
content = doc.get('content', '')
92+
doc_tokens = count_tokens_fn(content)
93+
94+
if used_tokens + doc_tokens > budget_tokens:
95+
# Try to fit a truncated version
96+
remaining = budget_tokens - used_tokens
97+
if remaining > 100: # Only include if meaningful content fits
98+
# Rough truncation by ratio
99+
ratio = remaining / doc_tokens
100+
truncated = content[:int(len(content) * ratio)]
101+
assembled.append(truncated + "\n... [truncated]")
102+
break
103+
104+
file_path = doc.get('file_path', 'unknown')
105+
assembled.append(f"--- {file_path} ---\n{content}")
106+
used_tokens += doc_tokens
107+
108+
logger.info(f"Assembled context: {len(assembled)} docs, ~{used_tokens} tokens "
109+
f"(budget: {budget_tokens})")
110+
return "\n\n".join(assembled)
111+
112+
113+
# Module-level singleton
114+
context_budget_manager = ContextBudgetManager()

api/rag.py

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import adalflow as adal
99

1010
from api.tools.embedder import get_embedder
11+
from api.context_budget import context_budget_manager
1112
from api.prompts import RAG_SYSTEM_PROMPT as system_prompt, RAG_TEMPLATE
1213

1314
# Create our own implementation of the conversation classes
@@ -382,13 +383,20 @@ def prepare_retriever(self, repo_url_or_path: str, type: str = "github", access_
382383
try:
383384
# Use the appropriate embedder for retrieval
384385
retrieve_embedder = self.query_embedder if self.is_ollama_embedder else self.embedder
386+
387+
# Calculate dynamic top_k based on model context window
388+
retriever_kwargs = dict(configs["retriever"])
389+
dynamic_top_k = self._get_dynamic_top_k()
390+
if dynamic_top_k is not None:
391+
retriever_kwargs["top_k"] = dynamic_top_k
392+
385393
self.retriever = FAISSRetriever(
386-
**configs["retriever"],
394+
**retriever_kwargs,
387395
embedder=retrieve_embedder,
388396
documents=self.transformed_docs,
389397
document_map_func=lambda doc: doc.vector,
390398
)
391-
logger.info("FAISS retriever created successfully")
399+
logger.info(f"FAISS retriever created successfully (top_k={retriever_kwargs.get('top_k', 'default')})")
392400
except Exception as e:
393401
logger.error(f"Error creating FAISS retriever: {str(e)}")
394402
# Try to provide more specific error information
@@ -413,6 +421,27 @@ def prepare_retriever(self, repo_url_or_path: str, type: str = "github", access_
413421
logger.error(f"Sample embedding sizes: {', '.join(sizes)}")
414422
raise
415423

424+
def _get_dynamic_top_k(self) -> "Optional[int]":
    """Calculate dynamic top_k based on the model's context window.

    Uses the context_budget_manager to determine how many chunks to
    retrieve based on the provider and model.

    Returns:
        The computed top_k, or None when provider/model are not set or
        the calculation fails (caller falls back to the config default).
        (Annotation fixed: the original declared ``-> int`` despite the
        documented/actual ``None`` fallback; it is quoted so nothing new
        is evaluated at definition time.)
    """
    # Guard clause: without a provider and model there is nothing to size.
    if not (self.provider and self.model):
        return None
    try:
        # Get average chunk size from text_splitter config.
        chunk_size = configs.get("text_splitter", {}).get("chunk_size", 350)
        # Rough estimate: 1 word ~ 1.3 tokens, so chunk_tokens ~ chunk_size * 1.3.
        avg_chunk_tokens = int(chunk_size * 1.3)
        return context_budget_manager.get_dynamic_top_k(
            self.provider, self.model, avg_chunk_tokens=avg_chunk_tokens
        )
    except Exception as e:
        # Best-effort: any failure here must not break retriever setup.
        logger.warning(f"Failed to calculate dynamic top_k: {e}")
        return None
444+
416445
def call(self, query: str, language: str = "en") -> Tuple[List]:
417446
"""
418447
Process a query using RAG.

api/rag_session.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
"""RAG Session Manager --- caches RAG instances for reuse across page generations."""
2+
3+
import time
4+
import threading
5+
import logging
6+
from typing import Optional, Dict, Tuple, Any
7+
8+
logger = logging.getLogger(__name__)
9+
10+
11+
class RAGSessionManager:
12+
"""Caches RAG instances keyed by repo URL to avoid rebuilding FAISS index per page."""
13+
14+
_instance = None
15+
_lock = threading.Lock()
16+
17+
def __new__(cls):
18+
if cls._instance is None:
19+
with cls._lock:
20+
if cls._instance is None:
21+
cls._instance = super().__new__(cls)
22+
cls._instance._sessions: Dict[str, Tuple[Any, float]] = {}
23+
cls._instance._ttl = 3600 # 1 hour TTL
24+
cls._instance._max_sessions = 10
25+
return cls._instance
26+
27+
def get_session_key(self, repo_url: str, embedder_type: str = "default") -> str:
28+
"""Generate a cache key for a RAG session."""
29+
return f"{repo_url}:{embedder_type}"
30+
31+
def get(self, key: str):
32+
"""Get a cached RAG instance if it exists and hasn't expired."""
33+
with self._lock:
34+
if key in self._sessions:
35+
rag, last_access = self._sessions[key]
36+
if time.time() - last_access < self._ttl:
37+
self._sessions[key] = (rag, time.time())
38+
logger.info(f"RAG session cache hit for {key}")
39+
return rag
40+
else:
41+
# Expired
42+
del self._sessions[key]
43+
logger.info(f"RAG session expired for {key}")
44+
return None
45+
46+
def put(self, key: str, rag_instance):
47+
"""Cache a RAG instance."""
48+
with self._lock:
49+
# Evict oldest if at capacity
50+
if len(self._sessions) >= self._max_sessions and key not in self._sessions:
51+
oldest_key = min(self._sessions, key=lambda k: self._sessions[k][1])
52+
del self._sessions[oldest_key]
53+
logger.info(f"Evicted oldest RAG session: {oldest_key}")
54+
55+
self._sessions[key] = (rag_instance, time.time())
56+
logger.info(f"Cached RAG session for {key}")
57+
58+
def invalidate(self, key: str):
59+
"""Remove a cached session."""
60+
with self._lock:
61+
self._sessions.pop(key, None)
62+
63+
def clear(self):
64+
"""Clear all cached sessions."""
65+
with self._lock:
66+
self._sessions.clear()
67+
68+
69+
# Module-level singleton
70+
rag_session_manager = RAGSessionManager()

api/websocket_wiki.py

Lines changed: 97 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import asyncio
12
import logging
23
import os
34
from typing import List, Optional, Dict, Any
@@ -23,7 +24,8 @@
2324
from api.openrouter_client import OpenRouterClient
2425
from api.azureai_client import AzureAIClient
2526
from api.dashscope_client import DashscopeClient
26-
from api.rag import RAG
27+
from api.rag import RAG, Memory
28+
from api.rag_session import rag_session_manager
2729

2830
# Configure logging
2931
from api.logging_config import setup_logging
@@ -60,6 +62,70 @@ class ChatCompletionRequest(BaseModel):
6062
included_dirs: Optional[str] = Field(None, description="Comma-separated list of directories to include exclusively")
6163
included_files: Optional[str] = Field(None, description="Comma-separated list of file patterns to include exclusively")
6264

65+
async def generate_with_retry(rag, query, context_docs, provider, model, language="en", max_retries=3):
    """Generate content with retry and context reduction on failure.

    On token limit errors, reduces context by 50% per retry.
    On transient errors (timeout, 503, 429), retries with exponential backoff.
    Non-retryable errors are raised immediately.

    NOTE(review): ``provider`` and ``model`` are accepted but never used in
    the body; ``context_docs`` is reduced but never forwarded (see below).

    Args:
        rag: RAG instance to use for generation
        query: The user query
        context_docs: List of retrieved documents
        provider: AI provider name
        model: Model name
        language: Language code for content generation
        max_retries: Maximum number of retry attempts

    Returns:
        Retrieved documents result from RAG

    Raises:
        Exception: the original error for non-retryable failures, or a
        generic Exception once max_retries is exhausted.
    """
    # Fraction of context_docs to keep; halved on each token-limit failure.
    context_fraction = 1.0

    for attempt in range(max_retries):
        try:
            # Reduce context on retries
            if context_fraction < 1.0 and context_docs:
                reduced_count = max(1, int(len(context_docs) * context_fraction))
                docs_to_use = context_docs[:reduced_count]
                logger.info(f"Using {len(docs_to_use)}/{len(context_docs)} context docs "
                            f"({context_fraction:.0%})")
            else:
                docs_to_use = context_docs

            # NOTE(review): docs_to_use is computed but never passed to rag(),
            # so the reduction above has no effect on the actual call -- the
            # token-limit retry path only re-runs the identical request.
            # Confirm how reduced docs should reach the RAG pipeline.
            result = rag(query, language=language)
            return result

        except Exception as e:
            # Classify the failure by substring match on the error message.
            error_str = str(e).lower()

            # Token limit errors -- reduce context
            if any(phrase in error_str for phrase in [
                'maximum context length', 'token limit', 'too many tokens',
                'content too large', 'request too large', 'input too long'
            ]):
                context_fraction *= 0.5
                logger.warning(f"Token limit hit, reducing context to {context_fraction:.0%} "
                               f"(attempt {attempt + 1}/{max_retries})")
                # Retries immediately (no backoff) -- the failure is size-based,
                # not load-based.
                continue

            # Transient errors -- retry with backoff
            if any(phrase in error_str for phrase in [
                'timeout', 'connection', '503', '502', '429', 'rate limit'
            ]):
                wait_time = (2 ** attempt)  # 1s, 2s, 4s
                logger.warning(f"Transient error, retrying in {wait_time}s "
                               f"(attempt {attempt + 1}/{max_retries}): {e}")
                await asyncio.sleep(wait_time)
                continue

            # Non-retryable error
            raise

    # All attempts exhausted without success.
    raise Exception(f"Failed after {max_retries} retries with context at {context_fraction:.0%}")
127+
128+
63129
async def handle_websocket_chat(websocket: WebSocket):
64130
"""
65131
Handle WebSocket connection for chat completions.
@@ -83,10 +149,8 @@ async def handle_websocket_chat(websocket: WebSocket):
83149
logger.warning(f"Request exceeds recommended token limit ({tokens} > 7500)")
84150
input_too_large = True
85151

86-
# Create a new RAG instance for this request
152+
# Create or reuse a cached RAG instance for this request
87153
try:
88-
request_rag = RAG(provider=request.provider, model=request.model)
89-
90154
# Extract custom file filter parameters if provided
91155
excluded_dirs = None
92156
excluded_files = None
@@ -106,8 +170,28 @@ async def handle_websocket_chat(websocket: WebSocket):
106170
included_files = [unquote(file_pattern) for file_pattern in request.included_files.split('\n') if file_pattern.strip()]
107171
logger.info(f"Using custom included files: {included_files}")
108172

109-
request_rag.prepare_retriever(request.repo_url, request.type, request.token, excluded_dirs, excluded_files, included_dirs, included_files)
110-
logger.info(f"Retriever prepared for {request.repo_url}")
173+
# Check for a cached RAG session (only when no custom file filters)
174+
has_custom_filters = any([excluded_dirs, excluded_files, included_dirs, included_files])
175+
from api.config import get_embedder_type
176+
embedder_type = get_embedder_type()
177+
session_key = rag_session_manager.get_session_key(request.repo_url, embedder_type) if not has_custom_filters else None
178+
request_rag = rag_session_manager.get(session_key) if session_key else None
179+
180+
if request_rag is not None:
181+
# Reuse cached RAG instance, update provider/model for this request
182+
request_rag.provider = request.provider
183+
request_rag.model = request.model
184+
# Reset memory for this new conversation
185+
request_rag.memory = Memory()
186+
logger.info(f"Reusing cached RAG session for {request.repo_url}")
187+
else:
188+
# Create a new RAG instance
189+
request_rag = RAG(provider=request.provider, model=request.model)
190+
request_rag.prepare_retriever(request.repo_url, request.type, request.token, excluded_dirs, excluded_files, included_dirs, included_files)
191+
# Cache the session if no custom filters were used
192+
if session_key:
193+
rag_session_manager.put(session_key, request_rag)
194+
logger.info(f"Created new RAG session for {request.repo_url}")
111195
except ValueError as e:
112196
if "No valid documents with embeddings found" in str(e):
113197
logger.error(f"No valid embeddings found: {str(e)}")
@@ -202,10 +286,14 @@ async def handle_websocket_chat(websocket: WebSocket):
202286
rag_query = f"Contexts related to {request.filePath}"
203287
logger.info(f"Modified RAG query to focus on file: {request.filePath}")
204288

205-
# Try to perform RAG retrieval
289+
# Try to perform RAG retrieval with retry logic
206290
try:
207-
# This will use the actual RAG implementation
208-
retrieved_documents = request_rag(rag_query, language=request.language)
291+
# Use retry wrapper for resilient retrieval
292+
retrieved_documents = await generate_with_retry(
293+
request_rag, rag_query, None,
294+
request.provider, request.model,
295+
language=request.language
296+
)
209297

210298
if retrieved_documents and retrieved_documents[0].documents:
211299
# Format context for the prompt in a more structured way

0 commit comments

Comments
 (0)