diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..d8d0051 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,24 @@ +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.6.0 + hooks: + - id: trailing-whitespace + - id: end-of-file-fixer + - id: check-yaml + args: [--allow-multiple-documents] + - id: check-toml + - id: check-json + - id: check-added-large-files + args: [--maxkb=500] + - id: check-merge-conflict + - id: detect-private-key + - id: debug-statements + - id: mixed-line-ending + args: [--fix=lf] + + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.4.10 + hooks: + - id: ruff + args: [--fix, --exit-non-zero-on-fix] + - id: ruff-format diff --git a/pyproject.toml b/pyproject.toml index c9c946c..8869841 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -72,3 +72,35 @@ testpaths = ["tests"] norecursedirs = [".git", ".routecode", "venv", ".venv", "build", "dist"] addopts = "--ignore=test.txt --ignore=test_progress.txt" python_files = "test_*.py" + +[tool.ruff] +target-version = "py310" +line-length = 100 +src = ["src"] + +[tool.ruff.lint] +select = [ + "E", # pycodestyle errors + "F", # pyflakes + "I", # isort (import sorting) + "W", # pycodestyle warnings + "UP", # pyupgrade + "B", # flake8-bugbear + "SIM", # flake8-simplify + "TCH", # type-checking imports +] +ignore = [ + "E501", # line too long (handled by formatter) + "B008", # do not perform function calls in argument defaults + "TCH004", # move import out of type-checking block +] +unfixable = [] + +[tool.ruff.lint.isort] +known-first-party = ["routecode"] +extra-standard-library = ["dataclasses"] + +[tool.ruff.format] +quote-style = "double" +indent-style = "space" +docstring-code-format = false diff --git a/src/routecode/agents/base.py b/src/routecode/agents/base.py index 3c1800b..09990e2 100644 --- a/src/routecode/agents/base.py +++ b/src/routecode/agents/base.py @@ -1,5 +1,6 @@ from abc import ABC, abstractmethod 
from typing import List, Dict, Optional, Any, AsyncGenerator +from .types import StreamChunk class AIProvider(ABC): @@ -13,9 +14,17 @@ async def ask( model: str, stream: bool = True, tools: Optional[List[Dict[str, Any]]] = None, - ) -> AsyncGenerator[Dict[str, Any], None]: + ) -> AsyncGenerator[StreamChunk, None]: """ - Send a prompt to the AI provider and return an async generator for the response chunks. + Send a prompt to the AI provider and return an async generator + yielding typed StreamChunk events. + + Chunk types: + - {"type": "text", "content": str} — response text + - {"type": "reasoning", "content": str} — reasoning tokens + - {"type": "tool_call", "tool_call": dict} — function call + - {"type": "usage", "usage": dict} — token usage + - {"type": "error", "content": str} — fatal error """ yield {} # Placeholder for abstract method diff --git a/src/routecode/agents/cloudflare_provider.py b/src/routecode/agents/cloudflare_provider.py new file mode 100644 index 0000000..7bc9bcf --- /dev/null +++ b/src/routecode/agents/cloudflare_provider.py @@ -0,0 +1,175 @@ +import json +import httpx +import os +from typing import List, Dict, Any, Optional, AsyncGenerator +from .base import AIProvider + + +class CloudflareProvider(AIProvider): + """ + Native Cloudflare Workers AI provider bypassing LiteLLM to avoid Pydantic issues. 
+ """ + + def __init__( + self, + api_key: str, + account_id: Optional[str] = None, + base_url: Optional[str] = None, + models: Optional[List[Dict[str, Any]]] = None, + ): + super().__init__(api_key) + self.account_id = account_id or os.environ.get("CLOUDFLARE_ACCOUNT_ID") + self.base_url = base_url + self.models_list = models + + # Unpack JSON key if needed + if api_key.startswith("{") and api_key.endswith("}"): + try: + data = json.loads(api_key) + self.account_id = data.get("CLOUDFLARE_ACCOUNT_ID", self.account_id) + self.api_key = data.get("CLOUDFLARE_API_KEY", self.api_key) + except Exception: + pass + + async def ask( + self, + messages: List[Dict[str, Any]], + model: str, + stream: bool = True, + tools: Optional[List[Dict[str, Any]]] = None, + ) -> AsyncGenerator[Dict[str, Any], None]: + # Sanitize messages for Cloudflare (remove tool calls/results if present) + sanitized_messages = [] + for m in messages: + if m.get("role") in ["user", "system", "assistant"]: + msg = {"role": m["role"], "content": m.get("content") or ""} + sanitized_messages.append(msg) + + url = f"https://api.cloudflare.com/client/v4/accounts/{self.account_id}/ai/run/{model}" + headers = {"Authorization": f"Bearer {self.api_key}"} + + # Cloudflare native API expects 'messages' in the body + payload = {"messages": sanitized_messages, "stream": stream} + + from ..utils.logger import get_logger + logger = get_logger(__name__) + logger.debug(f"Cloudflare Request: POST {url}") + + try: + async with httpx.AsyncClient(timeout=60.0) as client: + if stream: + async with client.stream( + "POST", url, headers=headers, json=payload + ) as response: + if response.status_code != 200: + err_text = await response.aread() + logger.error(f"Cloudflare Error {response.status_code}: {err_text.decode()}") + yield { + "type": "error", + "content": f"Cloudflare error {response.status_code}: {err_text.decode()}", + } + return + + last_text_len = 0 + last_thought_len = 0 + + async for line in 
response.aiter_lines(): + if line.startswith("data:"): + data_str = line[len("data:"):].strip() + if data_str == "[DONE]": + break + try: + logger.debug(f"Raw Cloudflare Chunk: {data_str}") + chunk = json.loads(data_str) + + # Handle OpenAI-compatible streaming format (choices -> delta) + choices = chunk.get("choices", []) + if choices: + delta = choices[0].get("delta", {}) + + # Handle reasoning content + reasoning = delta.get("reasoning_content") + if reasoning: + # If reasoning is cumulative, yield only the new part + if len(reasoning) > last_thought_len: + yield {"type": "thought", "content": reasoning[last_thought_len:]} + last_thought_len = len(reasoning) + else: + # If it's a delta, just yield it + yield {"type": "thought", "content": reasoning} + + content = delta.get("content") + if content: + # If content is cumulative, yield only the new part + if len(content) > last_text_len: + yield {"type": "text", "content": content[last_text_len:]} + last_text_len = len(content) + else: + # If it's a delta, just yield it + yield {"type": "text", "content": content} + continue + + # Fallback for older Workers AI format (often cumulative) + content = chunk.get("response") or chunk.get("text") or chunk.get("result") + if content is None and "result" in chunk and isinstance(chunk["result"], dict): + content = chunk["result"].get("response") or chunk["result"].get("text") + + if content is not None: + content_str = str(content) + if len(content_str) > last_text_len: + yield {"type": "text", "content": content_str[last_text_len:]} + last_text_len = len(content_str) + else: + yield {"type": "text", "content": content_str} + except Exception as e: + logger.error(f"Failed to parse Cloudflare chunk: {e} | Raw: {data_str}") + continue + else: + resp = await client.post(url, headers=headers, json=payload) + if resp.status_code == 200: + data = resp.json() + result = data.get("result", {}) + + # Handle OpenAI-compatible non-streaming format + choices = result.get("choices", []) 
or data.get("choices", []) + if choices: + message = choices[0].get("message", {}) + reasoning = message.get("reasoning_content") + if reasoning: + yield {"type": "thought", "content": reasoning} + content = message.get("content") + if content: + yield {"type": "text", "content": content} + else: + # Fallback for older format + content = result.get("response") or result.get("text") or result.get("result") or data.get("response") + if content: + yield {"type": "text", "content": content} + else: + yield { + "type": "error", + "content": f"Cloudflare error {resp.status_code}: {resp.text}", + } + except Exception as e: + yield {"type": "error", "content": str(e)} + + async def get_models(self) -> List[Dict[str, Any]]: + if not self.account_id: + return self.models_list or [] + + url = f"https://api.cloudflare.com/client/v4/accounts/{self.account_id}/ai/models/search?task=Text%20Generation" + headers = {"Authorization": f"Bearer {self.api_key}"} + try: + async with httpx.AsyncClient(timeout=10.0) as client: + resp = await client.get(url, headers=headers) + if resp.status_code == 200: + data = resp.json() + result_data = data.get("result", []) + return [ + {"id": m["name"], "name": m["name"].split("/")[-1]} + for m in result_data + if "name" in m + ] + except Exception: + pass + return self.models_list or [] diff --git a/src/routecode/agents/litellm_provider.py b/src/routecode/agents/litellm_provider.py index 1a78267..d7546d0 100644 --- a/src/routecode/agents/litellm_provider.py +++ b/src/routecode/agents/litellm_provider.py @@ -2,8 +2,8 @@ import litellm from typing import List, Dict, Any, Optional, AsyncGenerator from .base import AIProvider +from .mapping import resolve_model_name, get_model_list_pattern -# Disable litellm's verbose logging unless needed litellm.set_verbose = False litellm.suppress_debug_info = True @@ -24,8 +24,27 @@ def __init__( self.provider_name = provider_name self.base_url = base_url self.models_list = models - # LiteLLM uses provider-specific 
model names like "anthropic/claude-3-opus-20240229" - # We'll prepend the provider if it's not already there. + + if api_key.startswith("{") and api_key.endswith("}"): + try: + import json + import os + + keys = json.loads(api_key) + for k, v in keys.items(): + os.environ[k] = v + + for k, v in keys.items(): + if "KEY" in k or "TOKEN" in k or "SECRET" in k: + self.api_key = v + break + + if self.base_url: + for k, v in keys.items(): + self.base_url = self.base_url.replace(f"${{{k}}}", v) + self.base_url = self.base_url.replace(f"${k}", v) + except Exception: + pass async def ask( self, @@ -34,30 +53,8 @@ async def ask( stream: bool = True, tools: Optional[List[Dict[str, Any]]] = None, ) -> AsyncGenerator[Dict[str, Any], None]: - # Prepare the model string for LiteLLM - litellm_model = model - if self.provider_name: - # Common LiteLLM provider prefixes - prefixes = [ - "openai/", - "anthropic/", - "gemini/", - "deepseek/", - "openrouter/", - "vertex_ai/", - "groq/", - "mistral/", - ] - has_prefix = any(model.startswith(p) for p in prefixes) - - if not has_prefix: - if self.provider_name == "google": - litellm_model = f"gemini/{model}" - else: - litellm_model = f"{self.provider_name}/{model}" - - # LiteLLM needs the API key. We can pass it directly or set it in environment. - # Passing it in acompletion is safer for multiple providers. 
+ litellm_model = resolve_model_name(self.provider_name, model) + completion_args = { "model": litellm_model, "messages": messages, @@ -70,16 +67,17 @@ async def ask( if self.base_url: completion_args["base_url"] = self.base_url if self.provider_name == "openai": - # For custom OpenAI endpoints, use custom_llm_provider to avoid prefix issues completion_args["custom_llm_provider"] = "openai" - completion_args["model"] = model # Use unprefixed model name + completion_args["model"] = model + elif self.provider_name == "cloudflare": + completion_args["model"] = f"openai/{model}" + completion_args.pop("custom_llm_provider", None) if stream: completion_args["stream_options"] = {"include_usage": True} - if tools: + if tools and self.provider_name != "cloudflare": completion_args["tools"] = tools - # LiteLLM handles the translation of tools to the provider's format. try: response = await litellm.acompletion(**completion_args) @@ -154,45 +152,59 @@ async def ask( yield {"type": "error", "content": str(e)} async def get_models(self) -> List[Dict[str, Any]]: - """ - Return a list of models for the active provider. - Tries to fetch live models from LiteLLM first, falls back to metadata list. - """ - # 1. 
Try to fetch live list from LiteLLM try: - # For custom OpenAI endpoints and OpenRouter, we try a direct fetch if self.base_url and ( self.provider_name == "openai" or self.provider_name == "openrouter" ): - # Some providers don't support the 'openai/*' pattern in get_model_list - # but might work if we fetch directly from /models import httpx async with httpx.AsyncClient(timeout=10.0) as client: - # Try standard OpenAI /models endpoint url = self.base_url.rstrip("/") + "/models" headers = {"Authorization": f"Bearer {self.api_key}"} response = await client.get(url, headers=headers) if response.status_code == 200: data = response.json() - # Standard OpenAI response is {"data": [...]} models_data = data.get("data", []) if models_data: results = [] for m in models_data: m_id = m.get("id") if isinstance(m, dict) else str(m) if m_id: - # Strip prefix for display name (e.g. cohere/command-r -> command-r) display_name = ( m_id.split("/")[-1] if "/" in m_id else m_id ) results.append({"id": m_id, "name": display_name}) return results - # Standard LiteLLM fetch - pattern = f"{self.provider_name}/*" - if self.provider_name == "google": - pattern = "gemini/*" + pattern = get_model_list_pattern(self.provider_name) + + if self.provider_name == "cloudflare": + import os + + account_id = os.environ.get("CLOUDFLARE_ACCOUNT_ID") + api_key = os.environ.get("CLOUDFLARE_API_KEY") or self.api_key + if account_id and api_key: + try: + import httpx + + url = f"https://api.cloudflare.com/client/v4/accounts/{account_id}/ai/models/search?task=Text%20Generation" + headers = {"Authorization": f"Bearer {api_key}"} + async with httpx.AsyncClient(timeout=10.0) as client: + resp = await client.get(url, headers=headers) + if resp.status_code == 200: + data = resp.json() + result_data = data.get("result", []) + if result_data: + return [ + { + "id": m["name"], + "name": m["name"].split("/")[-1], + } + for m in result_data + if "name" in m + ] + except Exception: + pass live_models = await 
asyncio.to_thread( litellm.get_model_list, @@ -210,31 +222,7 @@ async def get_models(self) -> List[Dict[str, Any]]: except Exception: pass - # 2. Fallback to models_list (from models_api.json) if self.models_list: return self.models_list - # 3. Final hardcoded defaults for common providers - defaults = { - "openai": [ - {"id": "gpt-4o", "name": "GPT-4o (Omni)"}, - {"id": "gpt-4o-mini", "name": "GPT-4o Mini"}, - {"id": "gpt-4-turbo", "name": "GPT-4 Turbo"}, - {"id": "o1-preview", "name": "o1-preview (Reasoning)"}, - {"id": "o1-mini", "name": "o1-mini (Reasoning)"}, - ], - "anthropic": [ - {"id": "claude-3-5-sonnet-20240620", "name": "Claude 3.5 Sonnet"}, - {"id": "claude-3-opus-20240229", "name": "Claude 3 Opus"}, - {"id": "claude-3-haiku-20240307", "name": "Claude 3 Haiku"}, - ], - "google": [ - {"id": "gemini-1.5-pro", "name": "Gemini 1.5 Pro"}, - {"id": "gemini-1.5-flash", "name": "Gemini 1.5 Flash"}, - ], - "deepseek": [ - {"id": "deepseek-chat", "name": "DeepSeek Chat"}, - {"id": "deepseek-coder", "name": "DeepSeek Coder"}, - ], - } - return defaults.get(self.provider_name, []) + return [] diff --git a/src/routecode/agents/mapping.py b/src/routecode/agents/mapping.py new file mode 100644 index 0000000..6b841cb --- /dev/null +++ b/src/routecode/agents/mapping.py @@ -0,0 +1,52 @@ +""" +Single source of truth for provider-to-LiteLLM name mapping and model prefix logic. +""" + +PROVIDER_TO_PREFIX = { + "openai": "openai/", + "anthropic": "anthropic/", + "google": "gemini/", + "gemini": "gemini/", + "deepseek": "deepseek/", + "openrouter": "openrouter/", + "vertex_ai": "vertex_ai/", + "groq": "groq/", + "mistral": "mistral/", + "cloudflare": "openai/", +} + +LITELLM_PREFIXES = sorted(set(PROVIDER_TO_PREFIX.values())) + + +def get_model_prefix(provider_name: str) -> str: + """Returns the LiteLLM model prefix for a provider (e.g. 
'openai/', 'gemini/').""" + return PROVIDER_TO_PREFIX.get(provider_name, f"{provider_name}/") + + +def resolve_model_name(provider_name: str, model: str) -> str: + """Prepends the LiteLLM prefix to model if not already present.""" + prefix = get_model_prefix(provider_name) + if model.startswith(tuple(LITELLM_PREFIXES)): + return model + return f"{prefix}{model}" + + +def get_model_list_pattern(provider_name: str) -> str: + """Returns the LiteLLM model list pattern for a provider.""" + return get_model_prefix(provider_name) + "*" + + +def normalize_provider_name(npm: str) -> str: + """Maps an npm-style metadata string to a canonical provider name.""" + npm_lower = npm.lower() + if "openai" in npm_lower: + return "openai" + if "anthropic" in npm_lower: + return "anthropic" + if "google" in npm_lower or "gemini" in npm_lower: + return "google" + if "deepseek" in npm_lower: + return "deepseek" + if "workers-ai" in npm_lower or ("cloudflare" in npm_lower and "gateway" not in npm_lower): + return "cloudflare" + return npm_lower diff --git a/src/routecode/agents/registry.py b/src/routecode/agents/registry.py index 096160b..dee542c 100644 --- a/src/routecode/agents/registry.py +++ b/src/routecode/agents/registry.py @@ -1,8 +1,9 @@ import json from typing import Dict, Any from .litellm_provider import LiteLLMProvider - +from .cloudflare_provider import CloudflareProvider from ..utils.paths import get_resource_path +from .mapping import normalize_provider_name MODELS_API_PATH = get_resource_path("models_api.json") @@ -20,50 +21,50 @@ def _load_registry() -> Dict[str, Any]: REGISTRY_DATA = _load_registry() -class DynamicLiteLLMProvider(LiteLLMProvider): +def get_provider_class(provider_id: str): """ - A LiteLLMProvider that configures itself from models_api.json metadata. + Returns a provider class factory that selects the best implementation + (Native or LiteLLM) based on metadata. 
""" + data = REGISTRY_DATA.get(provider_id, {}) + npm = data.get("npm", "") + base_url = data.get("api") - def __init__(self, api_key: str, provider_id: str): - data = REGISTRY_DATA.get(provider_id, {}) - base_url = data.get("api") - npm = data.get("npm", "") - - # Determine internal LiteLLM provider name from metadata hints - litellm_p = provider_id - if "openai" in npm: - litellm_p = "openai" - elif "anthropic" in npm: - litellm_p = "anthropic" - elif "google" in npm or "gemini" in npm: - litellm_p = "google" - elif "deepseek" in npm: - litellm_p = "deepseek" - - # Populate models list from metadata - models_dict = data.get("models", {}) - models_list = [ - {"id": mid, "name": m.get("name", mid)} for mid, m in models_dict.items() - ] + pid_lower = provider_id.lower() + is_native_cloudflare = ("workers-ai" in pid_lower) or ( + "cloudflare" in pid_lower and "gateway" not in pid_lower + ) + litellm_p = normalize_provider_name(npm) or provider_id - super().__init__(api_key, litellm_p, base_url=base_url, models=models_list) + class DynamicProvider: + def __init__(self, api_key: str): + models_dict = data.get("models", {}) + models_list = [ + {"id": mid, "name": m.get("name", mid)} + for mid, m in models_dict.items() + ] + if is_native_cloudflare: + self.impl = CloudflareProvider( + api_key, base_url=base_url, models=models_list + ) + else: + self.impl = LiteLLMProvider( + api_key, litellm_p, base_url=base_url, models=models_list + ) -def create_provider_factory(pid: str): - """Creates a provider class compatible with the existing PROVIDER_MAP interface.""" + async def ask(self, *args, **kwargs): + async for chunk in self.impl.ask(*args, **kwargs): + yield chunk - class SpecificProvider(DynamicLiteLLMProvider): - def __init__(self, api_key: str): - super().__init__(api_key, pid) + async def get_models(self): + return await self.impl.get_models() - return SpecificProvider + return DynamicProvider -# Build the map dynamically from JSON -PROVIDER_MAP = {pid: 
create_provider_factory(pid) for pid in REGISTRY_DATA.keys()} +PROVIDER_MAP = {pid: get_provider_class(pid) for pid in REGISTRY_DATA.keys()} -# Fallback to standard providers if JSON is missing or empty if not PROVIDER_MAP: class OpenAIProvider(LiteLLMProvider): @@ -74,18 +75,7 @@ class AnthropicProvider(LiteLLMProvider): def __init__(self, api_key: str): super().__init__(api_key, "anthropic") - class GoogleProvider(LiteLLMProvider): - def __init__(self, api_key: str): - super().__init__(api_key, "google") - - class OpenRouterProvider(LiteLLMProvider): - def __init__(self, api_key: str): - super().__init__(api_key, "openrouter") - PROVIDER_MAP = { "openai": OpenAIProvider, "anthropic": AnthropicProvider, - "google": GoogleProvider, - "openrouter": OpenRouterProvider, - "deepseek": lambda k: LiteLLMProvider(k, "deepseek"), } diff --git a/src/routecode/agents/types.py b/src/routecode/agents/types.py new file mode 100644 index 0000000..5e0b5a4 --- /dev/null +++ b/src/routecode/agents/types.py @@ -0,0 +1,37 @@ +""" +Typed stream chunks for the AI provider interface. + +The provider's ask() method yields StreamChunk dicts representing +individual events in the streaming response. 
+""" + +from typing import TypedDict, List, Dict, Any, Optional, Literal + + +class TextChunk(TypedDict): + type: Literal["text"] + content: str + + +class ReasoningChunk(TypedDict): + type: Literal["reasoning"] + content: str + + +class ToolCallChunk(TypedDict): + type: Literal["tool_call"] + tool_call: Dict[str, Any] + + +class UsageChunk(TypedDict): + type: Literal["usage"] + usage: Dict[str, Any] + + +class ErrorChunk(TypedDict): + type: Literal["error"] + content: str + + +StreamChunk = TextChunk | ReasoningChunk | ToolCallChunk | UsageChunk | ErrorChunk +"""Union of all possible stream chunk types yielded by ask().""" diff --git a/src/routecode/ascii_logo.md b/src/routecode/ascii_logo.md new file mode 100644 index 0000000..23c51f8 --- /dev/null +++ b/src/routecode/ascii_logo.md @@ -0,0 +1,20 @@ + ███████████ █████ █████████ █████ +░░███░░░░░███ ░░███ ███░░░░░███ ░░███ + ░███ ░███ ██████ █████ ████ ███████ ██████ ███ ░░░ ██████ ███████ ██████ + ░██████████ ███░░███░░███ ░███ ░░░███░ ███░░███░███ ███░░███ ███░░███ ███░░███ + ░███░░░░░███ ░███ ░███ ░███ ░███ ░███ ░███████ ░███ ░███ ░███░███ ░███ ░███████ + ░███ ░███ ░███ ░███ ░███ ░███ ░███ ███░███░░░ ░░███ ███░███ ░███░███ ░███ ░███░░░ + █████ █████░░██████ ░░████████ ░░█████ ░░██████ ░░█████████ ░░██████ ░░████████░░██████ +░░░░░ ░░░░░ ░░░░░░ ░░░░░░░░ ░░░░░ ░░░░░░ ░░░░░░░░░ ░░░░░░ ░░░░░░░░ ░░░░░░ + + + +second logo + +██████╗ ██████╗ ██╗ ██╗████████╗███████╗ ██████╗ ██████╗ ██████╗ ███████╗ +██╔══██╗██╔═══██╗██║ ██║╚══██╔══╝██╔════╝██╔════╝██╔═══██╗██╔══██╗██╔════╝ +██████╔╝██║ ██║██║ ██║ ██║ █████╗ ██║ ██║ ██║██║ ██║█████╗ +██╔══██╗██║ ██║██║ ██║ ██║ ██╔══╝ ██║ ██║ ██║██║ ██║██╔══╝ +██║ ██║╚██████╔╝╚██████╔╝ ██║ ███████╗╚██████╗╚██████╔╝██████╔╝███████╗ +╚═╝ ╚═╝ ╚═════╝ ╚═════╝ ╚═╝ ╚══════╝ ╚═════╝ ╚═════╝ ╚═════╝ ╚══════╝ + \ No newline at end of file diff --git a/src/routecode/commands/__init__.py b/src/routecode/commands/__init__.py index 5610adb..5c3f306 100644 --- a/src/routecode/commands/__init__.py +++ 
b/src/routecode/commands/__init__.py @@ -28,6 +28,13 @@ ) from .tasks import handle_tasks, handle_task_stop from .memory import handle_remember, handle_forget, handle_memories +from .skills import ( + handle_skill, + handle_skill_create, + handle_skill_find, + handle_skill_manage, + handle_skill_uninstall, +) COMMANDS = { "/help": handle_help, @@ -52,6 +59,11 @@ "/remember": handle_remember, "/forget": handle_forget, "/memories": handle_memories, + "/skill": handle_skill, + "/skill-create": handle_skill_create, + "/skill-find": handle_skill_find, + "/skill-manage": handle_skill_manage, + "/skill-uninstall": handle_skill_uninstall, "/exit": handle_exit, "/update": handle_update, } @@ -81,6 +93,11 @@ def get_command_metadata() -> Dict[str, str]: "/remember ": "Save a memory for future sessions", "/forget ": "Delete a saved memory", "/memories": "List all saved memories", + "/skill [name]": "Invoke a skill interactively or by name", + "/skill-create": "Create a new reusable skill", + "/skill-find": "List all installed skills", + "/skill-manage": "Enable or disable specific skills", + "/skill-uninstall [name]": "Permanently delete an external skill", "/exit": "Exit the session", "/update": "Check for and install RouteCode updates", } diff --git a/src/routecode/commands/config.py b/src/routecode/commands/config.py index 0acebd6..b6c652a 100644 --- a/src/routecode/commands/config.py +++ b/src/routecode/commands/config.py @@ -89,50 +89,121 @@ async def handle_provider(args: List[str], ctx: RouteCodeContext): break if not is_connected: - new_key = await RouteCodeDialog( - title=f"Setup {result.capitalize()}", - text=_ui.get_dialog_text( - f"Paste your {result} API key here and press Enter:", "input" - ), - password=True, - dialog_type="input", - ).run_async() - if new_key: - ctx.config.set_api_key(result, new_key) - await ctx.config.save_async() - print_success( - f"API key for {result} has been saved! You can now select its models in the model menu." 
- ) - break + p_info = models_db.get(result, {}) + env_vars = p_info.get("env", []) + + if len(env_vars) > 1: + # Multiple keys needed + collected_keys = {} + for ev in env_vars: + label = ev.replace("_", " ").title() + # User requested to see the key when adding it + val = await RouteCodeDialog( + title=f"Setup {result.capitalize()}", + text=_ui.get_dialog_text(f"Enter {label}:", "input"), + password=False, + dialog_type="input", + ).run_async() + if not val: + break + collected_keys[ev] = val + + if len(collected_keys) == len(env_vars): + import json + + ctx.config.set_api_key(result, json.dumps(collected_keys)) + await ctx.config.save_async() + print_success(f"Credentials for {result} have been saved!") + break + else: + continue else: - continue + # Single key + new_key = await RouteCodeDialog( + title=f"Setup {result.capitalize()}", + text=_ui.get_dialog_text( + f"Paste your {result} API key here and press Enter:", "input" + ), + password=False, + dialog_type="input", + ).run_async() + if new_key: + ctx.config.set_api_key(result, new_key) + await ctx.config.save_async() + print_success( + f"API key for {result} has been saved! You can now select its models in the model menu." + ) + break + else: + continue else: action = await RouteCodeDialog( title=f"Update {result.capitalize()}", text=_ui.get_dialog_text( - f"{result} is already connected. Do you want to update the API key?", + f"{result} is already connected. 
What would you like to do?", "button", ), - buttons=[("Update", True), ("Back", False)], + buttons=[ + ("Update", "update"), + ("Disconnect", "disconnect"), + ("Back", "back"), + ], dialog_type="button", ).run_async() - if action: - new_key = await RouteCodeDialog( - title=f"Update {result.capitalize()}", + if action == "update": + p_info = models_db.get(result, {}) + env_vars = p_info.get("env", []) + + if len(env_vars) > 1: + collected_keys = {} + for ev in env_vars: + label = ev.replace("_", " ").title() + val = await RouteCodeDialog( + title=f"Update {result.capitalize()}", + text=_ui.get_dialog_text(f"Enter {label}:", "input"), + password=False, + dialog_type="input", + ).run_async() + if not val: + break + collected_keys[ev] = val + + if len(collected_keys) == len(env_vars): + import json + + ctx.config.set_api_key(result, json.dumps(collected_keys)) + await ctx.config.save_async() + print_success(f"Credentials for {result} have been updated.") + break + else: + new_key = await RouteCodeDialog( + title=f"Update {result.capitalize()}", + text=_ui.get_dialog_text( + f"Paste your new {result} API key:", "input" + ), + password=False, + dialog_type="input", + ).run_async() + if new_key: + ctx.config.set_api_key(result, new_key) + await ctx.config.save_async() + print_success(f"API key for {result} has been updated.") + break + elif action == "disconnect": + confirm = await RouteCodeDialog( + title="Confirm Disconnect", text=_ui.get_dialog_text( - f"Paste your new {result} API key:", "input" + f"Are you sure you want to disconnect {result}?", "button" ), - password=True, - dialog_type="input", + buttons=[("Yes", True), ("No", False)], + dialog_type="button", ).run_async() - if new_key: - ctx.config.set_api_key(result, new_key) + if confirm: + del ctx.config.api_keys[result] await ctx.config.save_async() - print_success(f"API key for {result} has been updated.") + print_success(f"Disconnected {result}.") break - else: - continue else: continue @@ -279,11 +350,12 @@ 
def on_favorite(val): if val and ":" in val: p, m = val.split(":", 1) ctx.config.toggle_favorite(p, m) - return _build_values() + return [p, m] in ctx.config.favorites + return False menu.on_favorite = on_favorite - def on_connect_provider_stub(): + def on_connect_provider_stub(val): pass menu.on_connect_provider = on_connect_provider_stub @@ -392,7 +464,7 @@ async def handle_config(args: List[str], ctx: RouteCodeContext): async def handle_theme(args: List[str], ctx: RouteCodeContext): from ..ui import THEMES, apply_theme - from .core import _refresh_screen + from ..ui.renderables import refresh_screen if args: name = args[0] @@ -400,7 +472,7 @@ async def handle_theme(args: List[str], ctx: RouteCodeContext): apply_theme(name) ctx.config.theme = name await ctx.config.save_async() - _refresh_screen(ctx) + refresh_screen(ctx) print_success(f"Theme set to: {name}") else: avail = ", ".join(THEMES.keys()) @@ -425,7 +497,7 @@ def on_theme_hover(theme_name): apply_theme(result) ctx.config.theme = result await ctx.config.save_async() - _refresh_screen(ctx) + refresh_screen(ctx) print_success(f"Theme set to: {result}") else: apply_theme(original_theme) diff --git a/src/routecode/commands/skills.py b/src/routecode/commands/skills.py new file mode 100644 index 0000000..7dbfbc8 --- /dev/null +++ b/src/routecode/commands/skills.py @@ -0,0 +1,210 @@ +from typing import List +from ..ui import print_success, print_error, RouteCodeDialog +from ..core import RouteCodeContext +from ..domain.skills import discover_skills, run_skill + + +async def handle_skill(args: List[str], ctx: RouteCodeContext): + skills = discover_skills(only_enabled=True) + if not skills: + ctx.console.print( + "[dim]No active skills available. 
Check /skill-manage to enable some.[/dim]" + ) + return + + if args: + skill_name = args[0] + if skill_name in skills: + arg_str = " ".join(args[1:]) + ctx.console.print_tool_call(f"Skill({skill_name})", {"args": arg_str}) + result = run_skill(skills[skill_name], arg_str, ctx) + ctx.console.print_tool_result(result) + return + else: + print_error(f"Skill '{skill_name}' not found.") + return + + # No args: show interactive picker + choices = [(name, f"{name}: {s.description}") for name, s in skills.items()] + + result = await RouteCodeDialog( + title="Invoke Skill", + text="Select a skill to execute:", + values=choices, + dialog_type="radio", + ).run_async() + + if result: + # Ask for optional arguments + arg_str = await RouteCodeDialog( + title=f"Arguments for {result}", + text=f"Enter optional arguments for '{result}':", + dialog_type="input", + ).run_async() + + arg_str = arg_str or "" + ctx.console.print_tool_call(f"Skill({result})", {"args": arg_str}) + res = run_skill(skills[result], arg_str, ctx) + ctx.console.print_tool_result(res) + + +async def handle_skill_create(args: List[str], ctx: RouteCodeContext): + name = await RouteCodeDialog( + title="Create New Skill", + text="Enter a name for the skill (e.g., 'deploy-lambda'):", + dialog_type="input", + ).run_async() + if not name: + return + + description = await RouteCodeDialog( + title="Skill Description", + text="Enter a short description of what this skill does:", + dialog_type="input", + ).run_async() + if not description: + return + + prompt = await RouteCodeDialog( + title="Skill Prompt", + text="Enter the system prompt or instructions for this skill:", + dialog_type="input", + ).run_async() + if not prompt: + return + + context = await RouteCodeDialog( + title="Execution Context", + text="Select how this skill should run:", + values=[ + ("inline", "Inline (expands in main chat)"), + ("fork", "Fork (runs as sub-agent)"), + ], + dialog_type="radio", + ).run_async() + if not context: + return + + # Use 
the tool logic to create it + from ..tools.skill import SkillCreatorTool + + tool = SkillCreatorTool() + result = tool._run( + name=name, description=description, prompt=prompt, context=context + ) + + if result.get("success"): + print_success(result["message"]) + else: + print_error(result.get("error", "Failed to create skill.")) + + +def handle_skill_find(args: List[str], ctx: RouteCodeContext): + from rich.table import Table + + skills = discover_skills() + if not skills: + ctx.console.print("[dim]No skills found.[/dim]") + return + + table = Table( + title="Available Skills", show_header=True, header_style="bold magenta" + ) + table.add_column("Name", style="cyan") + table.add_column("Context", style="green") + table.add_column("Description") + + for name, s in skills.items(): + status = "[green]Enabled[/green]" if s.enabled else "[red]Disabled[/red]" + source = "[yellow]Bundled[/yellow]" if s.is_bundled else "[blue]External[/blue]" + table.add_row(name, s.context, source, status, s.description) + + ctx.console.print(table) + + +async def handle_skill_manage(args: List[str], ctx: RouteCodeContext): + from ..config import config + + skills = discover_skills() + if not skills: + print_error("No skills found.") + return + + choices = [] + for name, s in skills.items(): + label = ( + f"{name} ({'Enabled' if s.enabled else 'Disabled'}) - {s.description[:60]}" + ) + choices.append((name, label)) + + result = await RouteCodeDialog( + title="Manage Skills", + text="Select a skill to toggle its enabled status:", + values=choices, + dialog_type="radio", # Using radio for single toggle, or I could use checkboxes if I had them + ).run_async() + + if result: + if result in config.disabled_skills: + config.disabled_skills.remove(result) + print_success(f"Skill '{result}' enabled.") + else: + config.disabled_skills.append(result) + print_success(f"Skill '{result}' disabled.") + config.save() + + +async def handle_skill_uninstall(args: List[str], ctx: RouteCodeContext): + 
import shutil + + skills = discover_skills() + + if not args: + # Show a picker for external skills only + external = [(n, n) for n, s in skills.items() if not s.is_bundled] + if not external: + print_error("No external skills found to uninstall.") + return + + target = await RouteCodeDialog( + title="Uninstall Skill", + text="Select an EXTERNAL skill to permanently delete:", + values=external, + dialog_type="radio", + ).run_async() + else: + target = args[0] + + if not target or target not in skills: + if args: + print_error(f"Skill '{target}' not found.") + return + + skill = skills[target] + if skill.is_bundled: + print_error( + "Cannot uninstall bundled skills. Use /skill-manage to disable them instead." + ) + return + + confirm = await RouteCodeDialog( + title="Confirm Uninstall", + text=f"Are you sure you want to PERMANENTLY DELETE the skill '{target}'?", + dialog_type="confirm", + ).run_async() + + if confirm: + try: + # Skills are in folders now + folder = skill.path.parent + shutil.rmtree(folder) + print_success(f"Skill '{target}' uninstalled successfully.") + + # Also remove from disabled list if present + from ..config import config + + if target in config.disabled_skills: + config.disabled_skills.remove(target) + config.save() + except Exception as e: + print_error(f"Failed to uninstall skill: {e}") diff --git a/src/routecode/config/settings.py b/src/routecode/config/settings.py index ba5c67c..927700a 100644 --- a/src/routecode/config/settings.py +++ b/src/routecode/config/settings.py @@ -2,27 +2,48 @@ from pathlib import Path from typing import Dict, Optional, Any + CONFIG_DIR = Path.home() / ".routecode" CONFIG_FILE = CONFIG_DIR / "config.json" +def _resolve_env(env_var: str, file_value: Any, default: Any) -> Any: + """Explicit tiered precedence: env var > file > default.""" + env_val = os.environ.get(env_var) + if env_val is not None: + return env_val + if file_value is not None: + return file_value + return default + + class Config: + """ + 
Application configuration with explicit tiered precedence: + 1. Environment variables (highest) + 2. Persistent JSON (~/.routecode/config.json) + 3. Hardcoded defaults (lowest) + """ + def __init__(self): from ..utils.storage import AtomicJsonStore - self._provider: str = os.environ.get("ROUTECODE_PROVIDER", "openrouter") - self._model: str = os.environ.get( - "ROUTECODE_MODEL", "anthropic/claude-3.5-sonnet" - ) - self.personality: str = os.environ.get("ROUTECODE_PERSONALITY", "default") - self.theme: str = os.environ.get("ROUTECODE_THEME", "lava") + self.store = AtomicJsonStore(CONFIG_FILE) + self.api_keys: Dict[str, str] = {} + + # Defaults (lowest tier) + self._provider: str = "openrouter" + self._model: str = "anthropic/claude-3.5-sonnet" + self.personality: str = "default" + self.theme: str = "lava" self.allowlist: list = [] self.denylist: list = [] - self.api_keys: Dict[str, str] = {} - self.recent_models: list = [] # List of (provider, model) tuples - self.favorites: list = [] # List of (provider, model) tuples + self.recent_models: list = [] + self.favorites: list = [] + self.disabled_skills: list = [] self.last_update_check: float = 0.0 - self.store = AtomicJsonStore(CONFIG_FILE) + + # Load file (middle tier) then env (highest tier) self._load() self._load_env_keys() @@ -47,29 +68,37 @@ def model(self, value: str): if self._model != value: self._model = value self.add_recent_model(self._provider, value) - from ..core.events import bus - - bus.emit("config.model_changed", model=value) def _load(self): + """Loads config from JSON file, applying env var overrides.""" data = self.store.load() - if data: - if not os.environ.get("ROUTECODE_PROVIDER"): - self._provider = data.get("provider", self._provider) - if not os.environ.get("ROUTECODE_MODEL"): - self._model = data.get("model", self._model) - if not os.environ.get("ROUTECODE_PERSONALITY"): - self.personality = data.get("personality", self.personality) - if not os.environ.get("ROUTECODE_THEME"): - self.theme 
= data.get("theme", self.theme) - self.allowlist = data.get("allowlist", []) - self.denylist = data.get("denylist", []) - self.api_keys = data.get("api_keys", {}) - self.recent_models = data.get("recent_models", []) - self.favorites = data.get("favorites", []) - self.last_update_check = data.get("last_update_check", 0.0) + if not data: + return + + self._provider = _resolve_env( + "ROUTECODE_PROVIDER", data.get("provider"), self._provider + ) + self._model = _resolve_env( + "ROUTECODE_MODEL", data.get("model"), self._model + ) + self.personality = _resolve_env( + "ROUTECODE_PERSONALITY", data.get("personality"), self.personality + ) + self.theme = _resolve_env( + "ROUTECODE_THEME", data.get("theme"), self.theme + ) + self.allowlist = data.get("allowlist", self.allowlist) + self.denylist = data.get("denylist", self.denylist) + self.recent_models = data.get("recent_models", self.recent_models) + self.favorites = data.get("favorites", self.favorites) + self.disabled_skills = data.get("disabled_skills", self.disabled_skills) + self.last_update_check = data.get("last_update_check", self.last_update_check) + + # File API keys — env vars merged on top in _load_env_keys() + self.api_keys = data.get("api_keys", {}) def _load_env_keys(self): + """Merges environment-variable API keys into api_keys dict.""" from ..agents.registry import PROVIDER_MAP for provider in PROVIDER_MAP.keys(): @@ -83,6 +112,12 @@ def _load_env_keys(self): if "opencode-go" not in self.api_keys: self.api_keys["opencode-go"] = opencode_key + def reload(self): + """Re-reads config from disk and re-applies env overrides. 
+ Values that differ from file are overwritten; no events are emitted + (handlers would not be ready during initial load).""" + self._load() + def to_dict(self) -> Dict[str, Any]: """Returns a dictionary representation of the current configuration.""" return { @@ -95,6 +130,7 @@ def to_dict(self) -> Dict[str, Any]: "denylist": self.denylist, "recent_models": self.recent_models, "favorites": self.favorites, + "disabled_skills": self.disabled_skills, "last_update_check": self.last_update_check, } @@ -123,7 +159,7 @@ def add_recent_model(self, provider: str, model: str): if item in self.recent_models: self.recent_models.remove(item) self.recent_models.insert(0, item) - self.recent_models = self.recent_models[:10] # Keep last 10 + self.recent_models = self.recent_models[:10] self.save() def toggle_favorite(self, provider: str, model: str): diff --git a/src/routecode/config/system_prompt.py b/src/routecode/config/system_prompt.py index abe19d2..7198543 100644 --- a/src/routecode/config/system_prompt.py +++ b/src/routecode/config/system_prompt.py @@ -17,7 +17,7 @@ def _build_identity_section() -> str: When responding, start with a block for your internal reasoning and plan, then provide your response.""" -def _build_workspace_section() -> str: +async def _build_workspace_section_async() -> str: import platform from datetime import datetime @@ -31,13 +31,12 @@ def _build_workspace_section() -> str: try: import subprocess - # Try to get git tree first - res = subprocess.run( + res = await asyncio.to_thread( + subprocess.run, ["git", "ls-tree", "-r", "--name-only", "HEAD"], capture_output=True, text=True, check=False, - shell=True if os_name == "Windows" else False, ) if res.returncode == 0: lines = res.stdout.strip().splitlines() @@ -48,7 +47,6 @@ def _build_workspace_section() -> str: else: project_structure = "\n".join(lines) else: - # Fallback to listing current directory files = os.listdir(cwd) project_structure = "\n".join(files[:40]) if len(files) > 40: @@ -284,12 
+282,13 @@ async def compute_system_prompt(ctx: "RouteCodeContext") -> str: _build_configuration_tiers_async(ctx), _build_git_section_async(), _build_context_section_async(), + _build_workspace_section_async(), ) - config_tiers_sect, git_sect, context_sect = dynamic_results + config_tiers_sect, git_sect, context_sect, workspace_sect = dynamic_results sections += [ - _build_workspace_section(), + workspace_sect, _build_tools_section(), _build_skill_section(), _build_env_section(), diff --git a/src/routecode/core/__init__.py b/src/routecode/core/__init__.py index d00027f..096a76d 100644 --- a/src/routecode/core/__init__.py +++ b/src/routecode/core/__init__.py @@ -10,6 +10,18 @@ from .events import bus, EventBus from .context_manager import ContextManager from .path_guard import PathGuard +from .container import AppContainer +from .constants import ( + MAX_TOOL_RESULT_CHARS, + MAX_ATTACHMENT_CHARS, + MAX_FETCH_CHARS, + MAX_MEMORIES, + MAX_MEMORY_CHARS, + MAX_RECENT_MODELS, + MAX_TASK_HISTORY, + MAX_ORCHESTRATOR_TURNS, + SUMMARY_KEEP_COUNT, +) __all__ = [ "RouteCodeContext", @@ -23,4 +35,14 @@ "EventBus", "ContextManager", "PathGuard", + "AppContainer", + "MAX_TOOL_RESULT_CHARS", + "MAX_ATTACHMENT_CHARS", + "MAX_FETCH_CHARS", + "MAX_MEMORIES", + "MAX_MEMORY_CHARS", + "MAX_RECENT_MODELS", + "MAX_TASK_HISTORY", + "MAX_ORCHESTRATOR_TURNS", + "SUMMARY_KEEP_COUNT", ] diff --git a/src/routecode/core/constants.py b/src/routecode/core/constants.py new file mode 100644 index 0000000..015c17a --- /dev/null +++ b/src/routecode/core/constants.py @@ -0,0 +1,22 @@ +""" +Centralized constants for the RouteCode application. +All magic numbers that control limits, truncation, and capacity +should be defined here with clear documentation. 
+""" + +# ── Tool result & file attachment limits ──────────────────────────────── +MAX_TOOL_RESULT_CHARS = 50000 +MAX_ATTACHMENT_CHARS = 50000 +MAX_FETCH_CHARS = 50000 + +# ── Memory limits ─────────────────────────────────────────────────────── +MAX_MEMORIES = 50 +MAX_MEMORY_CHARS = 500 + +# ── UI / Config limits ────────────────────────────────────────────────── +MAX_RECENT_MODELS = 10 + +# ── Task / Orchestrator limits ────────────────────────────────────────── +MAX_TASK_HISTORY = 50 +MAX_ORCHESTRATOR_TURNS = 20 +SUMMARY_KEEP_COUNT = 7 diff --git a/src/routecode/core/container.py b/src/routecode/core/container.py new file mode 100644 index 0000000..14aab0c --- /dev/null +++ b/src/routecode/core/container.py @@ -0,0 +1,148 @@ +import asyncio +from pathlib import Path +from typing import Optional, TYPE_CHECKING + +if TYPE_CHECKING: + from .events import EventBus + from .state import SessionState + from .path_guard import PathGuard + from .memory import MemoryManager + from .tokenizer import TokenizerService + from .context import RouteCodeContext + from ..config import Config + from ..domain.task_manager import TaskManager + + +class AppContainer: + """ + Service container with explicit initialization phases. + + Phase 1: build() — Create session-scoped services, wire module-level globals + Phase 2: initialize() — Load persisted state (async) + Phase 3: validate() — Assert all services ready + Phase 4: start() — Hook into event loop + Phase 5: shutdown() — Clean teardown + + Module-level globals (config, registry, cost_estimator) are imported + and referenced directly rather than duplicated, so code that imports + them continues to work unchanged. 
+ """ + + def __init__(self, config_dir: Path): + self._config_dir = config_dir + self._built = False + self._initialized = False + + # Tier 1: Module-level globals (reference, not duplicate) + self.bus: Optional["EventBus"] = None + self.config: Optional["Config"] = None + + # Tier 2: Session-scoped services (fresh per session) + self.state: Optional["SessionState"] = None + self.path_guard: Optional["PathGuard"] = None + self.memory: Optional["MemoryManager"] = None + + # Tier 3: Lazy runtime services + self._tokenizer: Optional["TokenizerService"] = None + self._task_manager: Optional["TaskManager"] = None + self._orchestrator: Optional["AgentOrchestrator"] = None + self._ctx: Optional["RouteCodeContext"] = None + + def build(self) -> "AppContainer": + from .events import EventBus + from .state import SessionState + from .path_guard import PathGuard + from .memory import MemoryManager + + # Tier 1: EventBus (new instance replaces module-level global) + self.bus = EventBus() + self._connect_bus() + + # Config — reference the existing module-level singleton + from ..config import config as global_config + + self.config = global_config + self.config.store.cleanup_stale_temps() + + # Tier 2: Session-scoped services + self.state = SessionState() + self.path_guard = PathGuard() + self.memory = MemoryManager(self._config_dir / "memory") + + self._built = True + return self + + def _connect_bus(self): + """Replace the module-level EventBus singleton with the container's instance.""" + from . 
import events as events_mod + + events_mod.bus = self.bus + + async def initialize(self) -> "AppContainer": + if not self._built: + raise RuntimeError("build() must be called before initialize()") + await self.memory._load_async() + self._initialized = True + return self + + def validate(self) -> "AppContainer": + if not self._initialized: + raise RuntimeError("initialize() must be called before validate()") + assert self.bus is not None, "EventBus not created" + assert self.config is not None, "Config not created" + assert self.state is not None, "SessionState not created" + assert self.memory is not None, "MemoryManager not created" + return self + + def set_event_loop(self, loop: asyncio.AbstractEventLoop): + self.ctx.loop = loop + + def shutdown(self): + self.bus.clear() + self._orchestrator = None + self._tokenizer = None + self._ctx = None + self._initialized = False + self._built = False + + @property + def tokenizer(self) -> "TokenizerService": + if self._tokenizer is None: + from .tokenizer import TokenizerService + + self._tokenizer = TokenizerService(bus=self.bus) + self.state.bind_tokenizer(self._tokenizer, bus=self.bus) + return self._tokenizer + + @property + def task_manager(self) -> "TaskManager": + if self._task_manager is None: + from ..domain.task_manager import TaskManager + + self._task_manager = TaskManager() + return self._task_manager + + @property + def orchestrator(self) -> "AgentOrchestrator": + if self._orchestrator is None: + from .orchestrator import AgentOrchestrator + + self._orchestrator = AgentOrchestrator(self.ctx) + return self._orchestrator + + @property + def ctx(self) -> "RouteCodeContext": + if self._ctx is None: + from .context import RouteCodeContext + from ..ui.console import console + + self._ctx = RouteCodeContext( + state=self.state, + config=self.config, + console=console, + task_manager=self.task_manager, + memory=self.memory, + path_guard=self.path_guard, + bus=self.bus, + ) + return self._ctx diff --git 
a/src/routecode/core/context.py b/src/routecode/core/context.py index cdb1a95..75fdebc 100644 --- a/src/routecode/core/context.py +++ b/src/routecode/core/context.py @@ -5,11 +5,11 @@ from .state import SessionState from ..config import Config from .memory import MemoryManager - from .path_guard import PathGuard if TYPE_CHECKING: from ..domain.task_manager import TaskManager + from .events import EventBus @dataclass @@ -25,4 +25,5 @@ class RouteCodeContext: task_manager: "TaskManager" memory: MemoryManager path_guard: PathGuard + bus: "EventBus" = None loop: Optional[asyncio.AbstractEventLoop] = None diff --git a/src/routecode/core/context_manager.py b/src/routecode/core/context_manager.py index 7f27689..85c3f1e 100644 --- a/src/routecode/core/context_manager.py +++ b/src/routecode/core/context_manager.py @@ -1,4 +1,5 @@ import asyncio +import logging from typing import List, Dict, Any, TYPE_CHECKING from .events import bus from .state import count_tokens @@ -7,6 +8,8 @@ from .context import RouteCodeContext from .history import ConversationHistory +logger = logging.getLogger(__name__) + class ContextManager: """ @@ -36,18 +39,13 @@ def check_and_compact(self, history: "ConversationHistory", model: str) -> bool: compacted = self.microcompact(history.get_messages()) if len(compacted) < original_len: history.set_messages(compacted) - # Recalculate tokens for the retained history retained_content = " ".join( m.get("content", "") or "" for m in compacted ) if hasattr(self.ctx.state, "_tokenizer") and self.ctx.state._tokenizer: - new_token_count = self.ctx.state._tokenizer.count_tokens( + self.ctx.state._tokenizer.recalculate( retained_content, model ) - self.ctx.state._tokenizer.load_state( - new_token_count, self.ctx.state.estimated_cost - ) - self.ctx.state.tokens_used = new_token_count else: self.ctx.state.tokens_used = count_tokens(retained_content, model) @@ -153,10 +151,10 @@ async def summarize_compact(self, history: "ConversationHistory") -> bool: task_id = 
f"c{abs(hash(compact_prompt)) % 10**7}" task_manager.create("Context compaction", None, task_id) - # Run as async task on the main loop asyncio.create_task( _run_sub_agent_async(compact_prompt, 3, task_id, self.ctx) ) return True except Exception: + logger.exception("Context compaction failed") return False diff --git a/src/routecode/core/event_types.py b/src/routecode/core/event_types.py new file mode 100644 index 0000000..2c71d15 --- /dev/null +++ b/src/routecode/core/event_types.py @@ -0,0 +1,97 @@ +""" +Typed event dataclasses for the RouteCode EventBus. + +Each dataclass represents a distinct event with typed fields. +Use with bus.emit_typed() / bus.on_typed() for static analysis, +IDE autocompletion, and protection against typos. +""" + +from dataclasses import dataclass +from typing import Any, Dict, Optional + + +# ── Config events ─────────────────────────────────────────────────────── + +@dataclass +class ProviderChanged: + provider: str + + +# ── Context management events ─────────────────────────────────────────── + +@dataclass +class ContextThresholdWarning: + usage: float + type: str # "micro" | "full" + + +@dataclass +class ContextCompacted: + type: str # "micro" + saved: int + + +# ── History events ────────────────────────────────────────────────────── + +@dataclass +class HistoryAppended: + message: Dict[str, Any] + + +@dataclass +class HistoryCleared: + pass + + +@dataclass +class HistoryRewound: + count: int + + +@dataclass +class HistoryReset: + pass + + +# ── Session events ────────────────────────────────────────────────────── + +@dataclass +class SessionReset: + pass + + +# ── Task events ───────────────────────────────────────────────────────── + +@dataclass +class TaskCreated: + task_id: str + description: str + + +@dataclass +class TaskCompleted: + task_id: str + description: str + result: Optional[Dict[str, Any]] = None + + +@dataclass +class TaskFailed: + task_id: str + description: str + error: str + + +# ── Tokenizer events 
──────────────────────────────────────────────────── + +@dataclass +class TokenUsageUpdated: + tokens: int + cost: float + + +# ── UI events ─────────────────────────────────────────────────────────── + +@dataclass +class ThemeChanged: + name: str diff --git a/src/routecode/core/events.py b/src/routecode/core/events.py index 3a086b3..b629914 100644 --- a/src/routecode/core/events.py +++ b/src/routecode/core/events.py @@ -1,8 +1,6 @@ -from typing import Callable, Dict, List, Any, Optional +from typing import Callable, Dict, List, Any, Optional, Type import logging import asyncio -import inspect -from weakref import WeakSet logger = logging.getLogger(__name__) @@ -10,20 +8,28 @@ class EventBus: """ A lightweight, thread-safe event bus for decoupling routecode subsystems. - Allows modules to emit events without knowing about their listeners. + + Supports both string-based events (emit/on) and typed dataclass events + (emit_typed/on_typed) for improved static analysis and type safety. """ def __init__(self): self._handlers: Dict[str, List[Callable]] = {} - self._active_tasks = WeakSet() + self._typed_to_name: Dict[type, str] = {} def on(self, event: str, handler: Callable): - """Registers a handler for a specific event.""" + """Registers a handler for a specific event name.""" if event not in self._handlers: self._handlers[event] = [] if handler not in self._handlers[event]: self._handlers[event].append(handler) + def on_typed(self, event_class: type, handler: Callable): + """Registers a handler for a typed event dataclass.""" + name = self._resolve_typed_name(event_class) + self._typed_to_name[event_class] = name + self.on(name, handler) + def off(self, event: str, handler: Callable): """Deregisters a handler for a specific event.""" if event in self._handlers and handler in self._handlers[event]: @@ -36,64 +42,57 @@ def clear(self, event: Optional[str] = None): self._handlers[event] = [] else: self._handlers = {} + self._typed_to_name.clear() def emit(self, event: 
str, **data): """ - Synchronously emits an event. Async handlers are fired as background tasks. + Synchronously emits a string-named event. + All handlers must be synchronous. """ if event not in self._handlers: return for handler in self._handlers[event]: try: - if inspect.iscoroutinefunction(handler): - self._fire_and_forget(handler, data, event) - else: - handler(**data) + handler(**data) except Exception as e: logger.error(f"Error in sync event handler for {event}: {e}") + def emit_typed(self, event): + """ + Synchronously emits a typed dataclass event. + Converts the event fields to kwargs and dispatches to all registered + handlers (both typed and string-based). + """ + name = self._resolve_typed_name(type(event)) + self.emit(name, **vars(event)) + async def emit_async(self, event: str, **data): """ Asynchronously emits an event. Properly awaits all async handlers. + Sync handlers are called directly. """ if event not in self._handlers: return - tasks = [] for handler in self._handlers[event]: try: - if inspect.iscoroutinefunction(handler): - tasks.append(handler(**data)) + if asyncio.iscoroutinefunction(handler): + await handler(**data) else: handler(**data) except Exception as e: logger.error(f"Error in async event handler for {event}: {e}") - if tasks: - results = await asyncio.gather(*tasks, return_exceptions=True) - for res in results: - if isinstance(res, Exception): - logger.error(f"Async error in handler for {event}: {res}") - - def _fire_and_forget( - self, handler: Callable, data: Dict[str, Any], event_name: str - ): - try: - loop = asyncio.get_running_loop() - task = loop.create_task(handler(**data)) - self._active_tasks.add(task) - task.add_done_callback(lambda t: self._active_tasks.discard(t)) - - def _log_error(t): - try: - t.result() - except Exception as e: - logger.error(f"Background task error for {event_name}: {e}") - - task.add_done_callback(_log_error) - except RuntimeError: - pass + @staticmethod + def _resolve_typed_name(event_class: 
type) -> str: + """Derives a stable event name from a typed event class.""" + module = getattr(event_class, "__module__", "") + qualname = getattr(event_class, "__qualname__", event_class.__name__) + if module and module != "builtins": + # core.event_types.ProviderChanged -> core.event_types.ProviderChanged + return f"{module}.{qualname}" + return qualname # Global bus instance diff --git a/src/routecode/core/history.py b/src/routecode/core/history.py index 2e3c41e..7e00b90 100644 --- a/src/routecode/core/history.py +++ b/src/routecode/core/history.py @@ -1,23 +1,33 @@ -from typing import List, Dict, Any, Optional +from collections import deque +from typing import List, Dict, Any, Optional, Deque from .events import bus class ConversationHistory: """ Unified manager for conversation messages. - Wraps a list of messages and provides safe mutation methods. + + Uses a deque with a configurable max size to bound memory growth. + Safe mutation methods emit events for downstream listeners. """ - def __init__(self, messages: Optional[List[Dict[str, Any]]] = None): - self._messages: List[Dict[str, Any]] = messages if messages is not None else [] + def __init__( + self, + messages: Optional[List[Dict[str, Any]]] = None, + maxlen: int = 2000, + ): + self._messages: Deque[Dict[str, Any]] = deque( + messages if messages is not None else [], + maxlen=maxlen, + ) def append(self, message: Dict[str, Any]): self._messages.append(message) bus.emit("history.appended", message=message) def extend(self, messages: List[Dict[str, Any]]): - self._messages.extend(messages) for m in messages: + self._messages.append(m) bus.emit("history.appended", message=m) def clear(self): @@ -25,27 +35,31 @@ def clear(self): bus.emit("history.cleared") def rewind(self, count: int): - """Removes the last N turns/messages.""" + """Removes the last N messages.""" if count <= 0: return - self._messages = self._messages[:-count] + for _ in range(min(count, len(self._messages))): + self._messages.pop() 
bus.emit("history.rewound", count=count) def set_messages(self, messages: Any): """Completely replaces the history.""" if isinstance(messages, ConversationHistory): - self._messages = messages.get_messages() + source = messages.get_messages() else: - self._messages = list(messages) + source = list(messages) + self._messages.clear() + for m in source: + self._messages.append(m) bus.emit("history.reset") def get_messages(self) -> List[Dict[str, Any]]: """Returns the raw list for iteration or API calls.""" - return self._messages + return list(self._messages) def to_list(self) -> List[Dict[str, Any]]: - """Returns a copy of the underlying list.""" - return self._messages[:] + """Returns a shallow copy as a list.""" + return list(self._messages) def snapshot(self) -> List[Dict[str, Any]]: """Alias for to_list().""" @@ -55,6 +69,8 @@ def __len__(self): return len(self._messages) def __getitem__(self, index): + if isinstance(index, slice): + return list(self._messages)[index] return self._messages[index] def __iter__(self): diff --git a/src/routecode/core/orchestrator.py b/src/routecode/core/orchestrator.py index a89784a..4e26b47 100644 --- a/src/routecode/core/orchestrator.py +++ b/src/routecode/core/orchestrator.py @@ -6,6 +6,7 @@ from .history import ConversationHistory from .context_manager import ContextManager from .events import bus +from .constants import MAX_TOOL_RESULT_CHARS, MAX_ORCHESTRATOR_TURNS from ..utils.storage import AtomicJsonStore from ..utils.logger import get_logger @@ -94,8 +95,7 @@ async def run( self, history: ConversationHistory, hooks: Optional[OrchestratorHooks] = None, - max_turns: int = 20, - tool_executor: Optional[Callable] = None, + max_turns: int = MAX_ORCHESTRATOR_TURNS, ): """ Runs the core agent loop: LLM call -> Tool execution -> State update. 
@@ -108,11 +108,16 @@ async def run( hooks = hooks or OrchestratorHooks() tool_executor = tool_executor or self._call_tool_safe - tool_schemas = [ - tool.to_json_schema() - for tool in registry._tools.values() - if tool.name != "task" - ] + tool_schemas = [] + for tool in registry._tools.values(): + if tool.name == "task": + continue + try: + tool_schemas.append(tool.to_json_schema()) + except Exception as e: + logger.error(f"Failed to generate schema for tool '{tool.name}': {str(e)}") + # We skip failing tools to prevent crashing the whole session + continue turn_count = 0 while turn_count < max_turns: @@ -226,7 +231,8 @@ async def _exec(tid=tc_id, n=name, a=args): results = await asyncio.gather(*tasks) for tid, n, res in results: await self._append_tool_result(history, tid, n, res) - await hooks.on_tool_result(n, res, 0.0) + hook_res = res.to_dict() if hasattr(res, 'to_dict') else res + await hooks.on_tool_result(n, hook_res, 0.0) else: # Sequential execution for tc_id, name, args in items: @@ -235,7 +241,8 @@ async def _exec(tid=tc_id, n=name, a=args): res = await tool_executor(name, args) elapsed = time.time() - ts await self._append_tool_result(history, tc_id, name, res) - await hooks.on_tool_result(name, res, elapsed) + hook_res = res.to_dict() if hasattr(res, 'to_dict') else res + await hooks.on_tool_result(name, hook_res, elapsed) except Exception as e: await hooks.on_error(f"Orchestrator error: {str(e)}") @@ -260,35 +267,52 @@ def _partition_tools(self, tool_inputs: list) -> list: return batches async def _call_tool_safe(self, name: str, args: dict) -> Dict[str, Any]: + from ..tools.base import ToolResult + tool = registry.get_tool(name) if not tool: - return {"error": f"Tool not found: {name}"} + return ToolResult(success=False, error=f"Tool not found: {name}") try: - # Most tools are still synchronous, so we run them in a thread to avoid blocking the event loop. 
- return await asyncio.to_thread( - tool.execute, **args, ctx=self.ctx, provider=self.provider - ) + return await tool.execute(ctx=self.ctx, provider=self.provider, **args) except Exception as e: - return {"error": str(e)} + return ToolResult(success=False, error=str(e)) async def _append_tool_result( self, history: ConversationHistory, tc_id: str, name: str, - result: Dict[str, Any], + result: "ToolResult", ): + from ..tools.base import ToolResult + MAX_CHARS = 50000 - content = json.dumps(result) + + if ToolResult.is_error(result): + error_msg = result.error if isinstance(result, ToolResult) else result.get("error", "Unknown error") + history.append( + { + "role": "system", + "content": f"Tool {name} failed: {error_msg}", + } + ) + self.tokenizer.add_usage( + self.tokenizer.count_tokens(error_msg, self.ctx.config.model), + self.ctx.config.model, + ) + return + + payload = result.to_dict() if isinstance(result, ToolResult) else result + content = json.dumps(payload) if len(content) > MAX_CHARS: path = CONFIG_DIR / "tool_results" / f"{tc_id}.json" store = AtomicJsonStore(path) - await store.save_async(result) - result["content"] = ( - f"[Result too large, saved to {path}]\n{result.get('content', '')[:2000]}" + await store.save_async(payload) + payload["content"] = ( + f"[Result too large, saved to {path}]\n{payload.get('content', '')[:2000]}" ) - content = json.dumps(result) + content = json.dumps(payload) history.append( {"role": "tool", "tool_call_id": tc_id, "name": name, "content": content} diff --git a/src/routecode/core/path_guard.py b/src/routecode/core/path_guard.py index 9a69e2b..efe20e1 100644 --- a/src/routecode/core/path_guard.py +++ b/src/routecode/core/path_guard.py @@ -21,22 +21,28 @@ def resolve(self, path: str) -> Tuple[Optional[str], Optional[str]]: """ Resolves a path relative to the current workspace and validates sandboxing. Returns (resolved_absolute_path, error_message). + + Guards against: + - Absolute path traversal (e.g. 
/etc/passwd) + - Prefix attacks (e.g. /my_dir vs /my_dir_secret) + - Symlink escape (e.g. workspace/link -> /etc/passwd) """ ws = self.get_workspace() try: - # Handle both relative and absolute paths joined = os.path.join(ws, path) if not os.path.isabs(path) else path resolved = os.path.abspath(os.path.normpath(joined)) resolved = os.path.realpath(resolved) except (ValueError, OSError): return None, f"Invalid path format: {path}" - # Ensure the resolved path is within the workspace - # We add a trailing separator to avoid 'prefix' attacks (e.g. /my_dir and /my_dir_secret) - ws_sep = ws if ws.endswith(os.sep) else ws + os.sep + # Canonicalize workspace root too, so symlinks in the root are also resolved + ws_canonical = os.path.realpath(ws) + + # Trailing-separator check prevents prefix attacks + ws_sep = ws_canonical if ws_canonical.endswith(os.sep) else ws_canonical + os.sep res_sep = resolved if resolved.endswith(os.sep) else resolved + os.sep - if not res_sep.startswith(ws_sep) and resolved != ws: + if not res_sep.startswith(ws_sep) and resolved != ws_canonical: return None, f"Path escapes workspace sandbox: {path}" return resolved, None diff --git a/src/routecode/core/state.py b/src/routecode/core/state.py index 2946deb..49428bf 100644 --- a/src/routecode/core/state.py +++ b/src/routecode/core/state.py @@ -29,27 +29,28 @@ class SessionState: model: Optional[str] = None workspace_path: Optional[str] = None - def bind_tokenizer(self, tokenizer): + def bind_tokenizer(self, tokenizer, bus=None): self._tokenizer = tokenizer + (bus or self._get_bus()).on( + "tokenizer.usage_updated", self._on_usage_updated + ) + + @staticmethod + def _get_bus(): from .events import bus - bus.on("tokenizer.usage_updated", self._on_usage_updated) + return bus + def _on_usage_updated(self, tokens: int, cost: float, **kwargs): self.tokens_used = tokens self.estimated_cost = cost def get_context_usage(self, model: str) -> float: - """Returns the current context usage percentage.""" + 
"""Returns the current context usage percentage via the bound tokenizer.""" if hasattr(self, "_tokenizer") and self._tokenizer: return self._tokenizer.get_context_usage_percent(model) - - from ..utils.costs import cost_estimator - - _, ctx_limit, _ = cost_estimator.calculate_cost(0, 0, model) - if ctx_limit <= 0: - return 0.0 - return (self.tokens_used / ctx_limit) * 100 + return 0.0 def reset_context_warning(self): """Resets the context warning flag, typically after compaction.""" diff --git a/src/routecode/core/tokenizer.py b/src/routecode/core/tokenizer.py index 2eeaa97..da9448a 100644 --- a/src/routecode/core/tokenizer.py +++ b/src/routecode/core/tokenizer.py @@ -1,16 +1,28 @@ -from typing import Optional -from .events import bus +from typing import Optional, TYPE_CHECKING + +if TYPE_CHECKING: + from .events import EventBus class TokenizerService: """ - Standalone service for tracking token usage and estimating costs. - Decoupled from the main application state to allow independent orchestration. + Single source of truth for token counting, cost estimation, and context usage. + + All token queries pass through this service. State binding via + SessionState.bind_tokenizer() subscribes to usage events emitted here. """ - def __init__(self): + def __init__(self, bus: Optional["EventBus"] = None): self.tokens_used: int = 0 self.estimated_cost: float = 0.0 + self._bus = bus + + def _get_bus(self): + if self._bus is not None: + return self._bus + from .events import bus + + return bus def count_tokens(self, text: str, model: str) -> int: from ..utils.costs import cost_estimator @@ -24,10 +36,6 @@ def add_usage( input_tokens: Optional[int] = None, output_tokens: Optional[int] = None, ): - """ - Records token usage. If precise input/output splits are provided, uses them. - Otherwise treats 'count' as total and estimates 50/50 split. 
- """ from ..utils.costs import cost_estimator if input_tokens is not None and output_tokens is not None: @@ -40,7 +48,7 @@ def add_usage( cost, _, _ = cost_estimator.calculate_cost(count // 2, count // 2, model) self.estimated_cost += cost - bus.emit( + self._get_bus().emit( "tokenizer.usage_updated", tokens=self.tokens_used, cost=self.estimated_cost ) @@ -52,6 +60,22 @@ def get_context_usage_percent(self, model: str) -> float: return 0.0 return (self.tokens_used / ctx_limit) * 100 + def recalculate(self, content: str, model: str): + """Recalculate token count and cost from content string (used after compaction).""" + new_count = self.count_tokens(content, model) + self.load_state(new_count, self._reestimate_cost(new_count, model)) + self._get_bus().emit( + "tokenizer.usage_updated", tokens=self.tokens_used, cost=self.estimated_cost + ) + + def _reestimate_cost(self, token_count: int, model: str) -> float: + from ..utils.costs import cost_estimator + + cost, _, _ = cost_estimator.calculate_cost( + token_count // 2, token_count // 2, model + ) + return cost + def load_state(self, tokens: int, cost: float): self.tokens_used = tokens self.estimated_cost = cost diff --git a/src/routecode/domain/attachments.py b/src/routecode/domain/attachments.py index 5bc0f7a..9927f06 100644 --- a/src/routecode/domain/attachments.py +++ b/src/routecode/domain/attachments.py @@ -42,8 +42,11 @@ def load_attachment(path: str) -> Optional[Dict]: } elif att_type == "pdf": - content = f"[PDF file: {name}. Text extraction requires PyMuPDF. File at: {resolved}]" - return {"type": "text", "name": name, "content": content, "path": resolved} + raise NotImplementedError( + f"PDF text extraction not supported. " + f"Install PyMuPDF (pip install pymupdf) to read PDFs. 
" + f"File at: {resolved}" + ) elif att_type == "text": with open(resolved, "r", encoding="utf-8", errors="replace") as f: diff --git a/src/routecode/domain/git.py b/src/routecode/domain/git.py index 80bab57..57846f6 100644 --- a/src/routecode/domain/git.py +++ b/src/routecode/domain/git.py @@ -41,17 +41,5 @@ async def run_git_async(cmd: str) -> str: def get_git_context() -> str: - """Synchronous fallback that runs the async version.""" - try: - loop = asyncio.get_running_loop() - if loop.is_running(): - # If we're in a loop, we can't easily run it sync without a thread - from concurrent.futures import ThreadPoolExecutor - - with ThreadPoolExecutor() as executor: - return executor.submit(asyncio.run, get_git_context_async()).result() - except RuntimeError: - return asyncio.run(get_git_context_async()) - - # Fallback if loop detection fails or other issues - return "" + """Synchronous fallback — only callable outside a running event loop.""" + return asyncio.run(get_git_context_async()) diff --git a/src/routecode/domain/skills.py b/src/routecode/domain/skills.py index d986386..7dae2a1 100644 --- a/src/routecode/domain/skills.py +++ b/src/routecode/domain/skills.py @@ -29,6 +29,16 @@ def __init__(self, path: Path): self.tools: List[str] = [] self._parse() + @property + def enabled(self) -> bool: + from ..config import config + + return self.name not in config.disabled_skills + + @property + def is_bundled(self) -> bool: + return "bundled_skills" in str(self.path).replace("\\", "/") + def _parse(self): content = self.path.read_text(encoding="utf-8") metadata, body = parse_frontmatter(content) @@ -59,7 +69,7 @@ def _parse(self): _skill_cache_mtime: float = 0.0 -def discover_skills() -> Dict[str, Skill]: +def discover_skills(only_enabled: bool = False) -> Dict[str, Skill]: """ Discovers all available skills from the configured skill directories. Uses MTIME-based caching to avoid expensive filesystem scans. 
@@ -107,11 +117,14 @@ def discover_skills() -> Dict[str, Skill]: _skill_cache = skills _skill_cache_mtime = current_mtime + + if only_enabled: + return {n: s for n, s in skills.items() if s.enabled} return skills def get_skill_prompts() -> str: - skills = discover_skills() + skills = discover_skills(only_enabled=True) if not skills: return "" lines = ["## Available Skills"] diff --git a/src/routecode/domain/task_manager.py b/src/routecode/domain/task_manager.py index 058efcd..7f92a83 100644 --- a/src/routecode/domain/task_manager.py +++ b/src/routecode/domain/task_manager.py @@ -85,22 +85,25 @@ def fail(self, task_id: str, error: str): def kill(self, task_id: str) -> bool: killed = False + worker = None with self._lock: if task_id in self._tasks: record = self._tasks[task_id] if record.status == "running": record.status = "killed" record.completed_at = time.time() - - # If it's an asyncio Task, cancel it immediately - if isinstance(record.worker, asyncio.Task): - try: - record.worker.cancel() - except Exception: - pass + worker = record.worker record.worker = None killed = True + # Cancel outside the lock to avoid deadlock if the task's + # done callback tries to acquire the same lock. 
+ if isinstance(worker, asyncio.Task): + try: + worker.cancel() + except Exception: + pass + if killed: self.prune() return killed diff --git a/src/routecode/main.py b/src/routecode/main.py index 1fb4e95..0550f16 100644 --- a/src/routecode/main.py +++ b/src/routecode/main.py @@ -10,6 +10,40 @@ ) +def _open_debug_window(log_file): + """Opens a separate terminal window that tails the log file in real-time.""" + import os + import sys + + if not log_file.exists(): + log_file.parent.mkdir(parents=True, exist_ok=True) + log_file.write_text("", encoding="utf-8") + + if sys.platform == "win32": + import subprocess + + log_path = str(log_file) + powershell_cmd = f'Get-Content "{log_path}" -Wait -Tail 50' + cmd = f'start "RouteCode Logs" powershell -NoExit -Command {powershell_cmd}' + subprocess.Popen(cmd, shell=True) + elif sys.platform == "darwin": + import subprocess + + terminal = "Terminal" if os.path.exists("/Applications/Utilities/Terminal.app") else "iTerm" + script = f'tell application "{terminal}" to do script "tail -f {log_file}"' + subprocess.Popen(["osascript", "-e", script]) + else: + import subprocess + + terminals = ["x-terminal-emulator", "gnome-terminal", "xterm", "konsole"] + for term in terminals: + try: + subprocess.Popen([term, "-e", f"tail -f {log_file}"]) + break + except FileNotFoundError: + continue + + @app.callback(invoke_without_command=True) def main( ctx: typer.Context, @@ -28,9 +62,11 @@ def main( update: bool = typer.Option( False, "--update", help="Check for and install the latest version of RouteCode" ), + debug: bool = typer.Option( + False, "--debug", "-d", help="Development mode: opens log window at DEBUG level" + ), ): """routecode: An AI assistant for your terminal.""" - # Check for pending updates and apply them before doing anything else from .updater import apply_pending_update apply_pending_update() @@ -38,9 +74,29 @@ def main( if ctx.invoked_subcommand is not None: return - from .utils.logger import setup_logging + from 
.utils.logger import setup_logging, LOG_FILE + + if debug: + from .utils.logger import get_logger - setup_logging() + setup_logging(level="DEBUG") + _open_debug_window(LOG_FILE) + + import sys as _sys + original_excepthook = _sys.excepthook + + def _debug_excepthook(typ, val, tb): + import traceback + get_logger("main").error( + "Unhandled exception:\n%s", + "".join(traceback.format_exception(typ, val, tb)), + ) + original_excepthook(typ, val, tb) + + _sys.excepthook = _debug_excepthook + get_logger("main").debug("Debug mode enabled, log window opened") + else: + setup_logging() if update: from .updater import check_for_update, perform_update diff --git a/src/routecode/tools/base.py b/src/routecode/tools/base.py index 047685c..221b36a 100644 --- a/src/routecode/tools/base.py +++ b/src/routecode/tools/base.py @@ -1,11 +1,68 @@ import asyncio +import logging from abc import ABC, abstractmethod +from dataclasses import dataclass, field from typing import Any, Callable, Dict, List, Optional, Type, TYPE_CHECKING from pydantic import BaseModel if TYPE_CHECKING: from ..core import RouteCodeContext +logger = logging.getLogger(__name__) + + +@dataclass +class ToolResult: + """Typed result from a tool execution.""" + success: bool + content: str = "" + error: Optional[str] = None + metadata: Dict[str, Any] = field(default_factory=dict) + + @staticmethod + def from_value(value: Any) -> "ToolResult": + """Normalizes any tool return value into a ToolResult. 
+ + Handles: + - ToolResult instances (pass-through) + - dicts with 'success'/'error' keys (tool convention) + - raw strings (implicit success) + - other types (wrapped as-is) + """ + if isinstance(value, ToolResult): + return value + if isinstance(value, dict): + if "error" in value: + return ToolResult( + success=False, + error=str(value["error"]), + content=value.get("content", ""), + metadata={k: v for k, v in value.items() if k not in ("error", "content", "success")}, + ) + if "success" in value: + return ToolResult( + success=bool(value.get("success", True)), + content=str(value.get("content", "")), + error=str(value["error"]) if "error" in value else None, + metadata={k: v for k, v in value.items() if k not in ("success", "content", "error")}, + ) + return ToolResult(success=True, content=str(value.get("content", "")), metadata=value) + if isinstance(value, str): + return ToolResult(success=True, content=value) + return ToolResult(success=True, content=str(value)) + + def to_dict(self) -> Dict[str, Any]: + if self.success: + return {"success": True, "content": self.content, **self.metadata} + return {"success": False, "error": self.error} + + @staticmethod + def is_error(result: Any) -> bool: + """Check if a result represents a tool error.""" + if isinstance(result, ToolResult): + return not result.success + return isinstance(result, dict) and "error" in result + class BaseTool(ABC): name: str @@ -46,7 +103,8 @@ async def execute( provider: Optional[Any] = None, **kwargs, ) -> Any: - return await asyncio.to_thread(self._run, ctx=ctx, provider=provider, **kwargs) + raw = await asyncio.to_thread(self._run, ctx=ctx, provider=provider, **kwargs) + return ToolResult.from_value(raw) @abstractmethod def _run(self, **kwargs) -> Any: @@ -84,7 +142,7 @@ async def execute_tool( ) -> Any: tool = self.get_tool(name) if not tool: - return {"error": f"Tool not found: {name}"} + return ToolResult(success=False, error=f"Tool not found: {name}") self.run_pre_hooks(name, args) 
@@ -104,12 +162,13 @@ async def _next(t, a, c): try: result = await pipeline(tool, args, ctx) - self.run_post_hooks(name, result) - return result + tr = ToolResult.from_value(result) + self.run_post_hooks(name, tr.to_dict()) + return tr except Exception as e: - result = {"error": str(e)} - self.run_post_hooks(name, result) - return result + tr = ToolResult(success=False, error=str(e)) + self.run_post_hooks(name, tr.to_dict()) + return tr def register(self, tool: BaseTool): self._tools[tool.name] = tool @@ -134,14 +193,14 @@ def run_pre_hooks(self, name: str, args: Dict): try: fn(name, args) except Exception: - pass + logger.exception("Pre-hook failed for tool '%s'", name) def run_post_hooks(self, name: str, result: Dict): for fn in self._post_hooks: try: fn(name, result) except Exception: - pass + logger.exception("Post-hook failed for tool '%s'", name) def parse_and_validate(self, name: str, arguments: Any) -> Dict[str, Any]: """ diff --git a/src/routecode/tools/skill.py b/src/routecode/tools/skill.py index bf409ed..9bcf463 100644 --- a/src/routecode/tools/skill.py +++ b/src/routecode/tools/skill.py @@ -104,10 +104,14 @@ def _run( return {"success": False, "error": f"Failed to create skill: {str(e)}"} +class EmptySchema(BaseModel): + pass + + class FindSkillsTool(BaseTool): name = "find_skills" description = "List all currently installed skills and their descriptions." - input_schema = BaseModel + input_schema = EmptySchema def prompt(self) -> str: return "- find_skills: List all installed skills and discover your extended capabilities." 
diff --git a/src/routecode/ui/dialogs/base.py b/src/routecode/ui/dialogs/base.py index 9db101d..a511b38 100644 --- a/src/routecode/ui/dialogs/base.py +++ b/src/routecode/ui/dialogs/base.py @@ -1,27 +1,6 @@ -from ..console import _mirror_output import asyncio from typing import Any, Optional -from prompt_toolkit.application.current import get_app -from prompt_toolkit.layout.containers import Float -from ..terminal import TerminalManager - - -def _get_backdrop_ansi() -> str: - """Generates a dimmed ANSI screenshot of the current terminal state.""" - full_ansi = _mirror_output.getvalue() - from prompt_toolkit.output.defaults import create_output - - try: - h = create_output().get_size().rows - except Exception: - import shutil - - h = shutil.get_terminal_size().lines - lines = full_ansi.splitlines() - recent_lines = lines[-h:] - ansi_content = "\n".join(recent_lines) - # Strong dimming for the backdrop - return f"\033[2m\033[38;5;238m{ansi_content}\033[0m" +from .manager import DialogManager def get_dialog_text(main_text: str, dialog_type: str = "radio") -> str: @@ -39,67 +18,25 @@ def get_dialog_text(main_text: str, dialog_type: str = "radio") -> str: class BaseModalLayer: """ Abstract base class for modal dialogs and overlays. - Provides standard lifecycle management for Float injection, - mouse tracking, focus trapping, global dimming state, and cleanup. + Lifecycle management is delegated to DialogManager. + Subclasses implement _build_container() and _get_focus_target(). + + Key bindings resolve the dialog by calling self.future.set_result(value). 
""" def __init__(self): self.future: Optional[asyncio.Future] = None def _build_container(self) -> Any: - """Override to return the prompt_toolkit container (e.g., Shadow) to inject.""" raise NotImplementedError def _get_focus_target(self) -> Any: - """Override to return the widget that should receive focus initially.""" raise NotImplementedError async def run_async(self) -> Any: - import asyncio - - current_app = get_app() - is_injected = ( - current_app - and current_app.is_running - and hasattr(current_app.layout.container, "floats") - ) - - if not is_injected: - # Fallback for headless testing or non-injected states - raise RuntimeError( - "BaseModalLayer must be run within an active Application with floats." - ) - self.future = asyncio.Future() - menu_container = self._build_container() - focus_target = self._get_focus_target() - - menu_float = Float(content=menu_container, transparent=False) - current_app.layout.container.floats.append(menu_float) - previous_focus = current_app.layout.current_window - - if focus_target: - current_app.layout.focus(focus_target) - - if hasattr(current_app, "routecode_repl"): - current_app.routecode_repl.is_modal_open = True - current_app.routecode_repl.update_style() - - current_app.invalidate() - TerminalManager.enable_mouse_tracking() - - try: - return await self.future - finally: - TerminalManager.disable_mouse_tracking() - if menu_float in current_app.layout.container.floats: - current_app.layout.container.floats.remove(menu_float) - if previous_focus: - try: - current_app.layout.focus(previous_focus) - except Exception: - pass - if hasattr(current_app, "routecode_repl"): - current_app.routecode_repl.is_modal_open = False - current_app.routecode_repl.update_style() - current_app.invalidate() + return await DialogManager.run_dialog( + container=self._build_container(), + future=self.future, + focus_target=self._get_focus_target(), + ) diff --git a/src/routecode/ui/dialogs/manager.py b/src/routecode/ui/dialogs/manager.py new 
file mode 100644 index 0000000..8fe584c --- /dev/null +++ b/src/routecode/ui/dialogs/manager.py @@ -0,0 +1,79 @@ +""" +Centralized dialog lifecycle management. + +DialogManager handles Float injection, focus save/restore, modal state, +mouse tracking, and cleanup — so individual dialog classes only need to +implement their layout. +""" + +import asyncio +from typing import Any, Optional +from prompt_toolkit.layout.containers import Float, Container +from prompt_toolkit.application.current import get_app +from ..terminal import TerminalManager + + +class DialogManager: + """ + Manages the lifecycle of a modal dialog Float. + + Usage: + future = asyncio.Future() + result = await DialogManager.run_dialog( + container=dialog_container, + future=future, + focus_target=search_field, + ) + """ + + @staticmethod + async def run_dialog( + container: Container, + future: Optional["asyncio.Future[Any]"] = None, + focus_target: Any = None, + ) -> Any: + """ + Injects a dialog Float, awaits a future, cleans up. + + The dialog signals completion by calling: + future.set_result(value) + + If no future is provided, one is created internally. In that case, + the dialog must have access to resolve it another way (e.g. via + an on_open callback). 
+ """ + future = future or asyncio.Future() + app = get_app() + + if not app or not app.is_running or not hasattr(app.layout.container, "floats"): + raise RuntimeError("Dialog must be run within an active Application with floats") + + dialog_float = Float(content=container, transparent=False) + app.layout.container.floats.append(dialog_float) + previous_focus = app.layout.current_window + + if focus_target: + app.layout.focus(focus_target) + + if hasattr(app, "routecode_repl"): + app.routecode_repl.is_modal_open = True + app.routecode_repl.update_style() + + app.invalidate() + TerminalManager.enable_mouse_tracking() + + try: + return await future + finally: + TerminalManager.disable_mouse_tracking() + if dialog_float in app.layout.container.floats: + app.layout.container.floats.remove(dialog_float) + if previous_focus: + try: + app.layout.focus(previous_focus) + except Exception: + pass + if hasattr(app, "routecode_repl"): + app.routecode_repl.is_modal_open = False + app.routecode_repl.update_style() + app.invalidate() diff --git a/src/routecode/ui/dialogs/palette.py b/src/routecode/ui/dialogs/palette.py index dd268cc..17c4eca 100644 --- a/src/routecode/ui/dialogs/palette.py +++ b/src/routecode/ui/dialogs/palette.py @@ -48,11 +48,10 @@ def _(event): @kb.add("enter") def _(event): if menu_list.values: - self.result = menu_list.current_value + self.result = menu_list.values[menu_list._selected_index][0] if not self.future.done(): self.future.set_result(self.result) - @kb.add("up", eager=True) def _(event): menu_list._selected_index -= 1 event.app.invalidate() @@ -263,7 +262,7 @@ def _(event): @kb.add("enter") def _(event): if menu_list.values: - self.result = menu_list.current_value + self.result = menu_list.values[menu_list._selected_index][0] if not self.future.done(): self.future.set_result(self.result) diff --git a/src/routecode/ui/dialogs/widgets.py b/src/routecode/ui/dialogs/widgets.py index c513eca..0f6fd36 100644 --- a/src/routecode/ui/dialogs/widgets.py +++ 
b/src/routecode/ui/dialogs/widgets.py @@ -11,7 +11,6 @@ from prompt_toolkit.layout.containers import ConditionalContainer from prompt_toolkit.filters import to_filter, has_completions, is_done from prompt_toolkit.layout.scrollable_pane import ScrollOffsets -from prompt_toolkit.layout.dimension import Dimension class HoverCompletionsMenuControl(CompletionsMenuControl): @@ -35,14 +34,58 @@ def mouse_handler(self, mouse_event): return super().mouse_handler(mouse_event) +def _make_list_mouse_handler( + self, + *, + get_index=None, + on_click=None, + can_select=None, + invalidate=False, +): + """ + Builds a mouse handler for list-style widgets. + + Parameters: + get_index: callable(y) -> int|None — resolves mouse y to item index. + on_click: callable(idx) — called on MOUSE_UP. + can_select: callable(idx) -> bool — filters which items accept clicks. + invalidate: if True, calls get_app().invalidate() on scroll. + """ + if get_index is None: + get_index = lambda y: y + + def mouse_handler(mouse_event: MouseEvent) -> None: + y = mouse_event.position.y + if mouse_event.event_type == MouseEventType.MOUSE_MOVE: + idx = get_index(y) + if idx is not None: + self._selected_index = idx + elif mouse_event.event_type == MouseEventType.MOUSE_UP: + idx = get_index(y) + if idx is not None and (can_select is None or can_select(idx)): + self._current_index = idx + if on_click: + on_click(idx) + elif mouse_event.event_type == MouseEventType.SCROLL_UP: + self._selected_index -= 1 + if invalidate: + get_app().invalidate() + elif mouse_event.event_type == MouseEventType.SCROLL_DOWN: + self._selected_index += 1 + if invalidate: + get_app().invalidate() + + return mouse_handler + + class HoverCompletionsMenu(ConditionalContainer): """CompletionsMenu with mouse-hover highlighting support.""" - def __init__( - self, max_height=None, scroll_offset=0, extra_filter=None, z_index=10**8 - ): + def __init__(self, max_height=None, scroll_offset=0, extra_filter=None): extra_filter = 
to_filter(extra_filter if extra_filter is not None else True) + from prompt_toolkit.layout.dimension import Dimension + window = Window( content=HoverCompletionsMenuControl(), width=Dimension(min=8), @@ -63,6 +106,7 @@ def __init__( ) + class HoverRadioList: """A lightweight list widget for dialogs with hover and click support.""" @@ -115,19 +159,14 @@ def _selected_index(self, value): self._current_index = value % len(self.values) def _get_text_fragments(self): - def mouse_handler(mouse_event: MouseEvent) -> None: - if mouse_event.event_type == MouseEventType.MOUSE_MOVE: - self._selected_index = mouse_event.position.y - elif mouse_event.event_type == MouseEventType.MOUSE_UP: - idx = mouse_event.position.y - if idx < len(self.values): - self._current_index = idx - if self._on_enter: - self._on_enter() - elif mouse_event.event_type == MouseEventType.SCROLL_UP: - self._selected_index -= 1 - elif mouse_event.event_type == MouseEventType.SCROLL_DOWN: - self._selected_index += 1 + mouse_handler = _make_list_mouse_handler( + self, + on_click=lambda idx: ( + self._on_enter() + if self._on_enter + else None + ), + ) result = [] menu_width = 40 @@ -245,24 +284,12 @@ def _get_text_fragments(self): y_to_idx = {} current_y = 0 - def mouse_handler(mouse_event: MouseEvent) -> None: - y = mouse_event.position.y - - if mouse_event.event_type == MouseEventType.MOUSE_MOVE: - idx = y_to_idx.get(y) - if idx is not None: - self._selected_index = idx - elif mouse_event.event_type == MouseEventType.MOUSE_UP: - idx = y_to_idx.get(y) - if idx is not None: - self._selected_index = idx - self._handle_enter() - elif mouse_event.event_type == MouseEventType.SCROLL_UP: - self._selected_index -= 1 - get_app().invalidate() - elif mouse_event.event_type == MouseEventType.SCROLL_DOWN: - self._selected_index += 1 - get_app().invalidate() + mouse_handler = _make_list_mouse_handler( + self, + get_index=lambda y: y_to_idx.get(y), + on_click=lambda idx: self._handle_enter(), + invalidate=True, + ) 
result = [] menu_width = 40 @@ -372,30 +399,22 @@ def _get_text_fragments(self): y_to_idx = {} current_y = 0 - def mouse_handler(mouse_event: MouseEvent) -> None: - y = mouse_event.position.y - - if mouse_event.event_type == MouseEventType.MOUSE_MOVE: - idx = y_to_idx.get(y) - if idx is not None: - self._selected_index = idx - elif mouse_event.event_type == MouseEventType.MOUSE_UP: - idx = y_to_idx.get(y) - if idx is not None and not self.values[idx][2]: - self._current_index = idx - if hasattr(self, "_on_enter"): - self._on_enter() - elif mouse_event.event_type == MouseEventType.SCROLL_UP: - self._selected_index -= 1 - elif mouse_event.event_type == MouseEventType.SCROLL_DOWN: - self._selected_index += 1 + mouse_handler = _make_list_mouse_handler( + self, + get_index=lambda y: y_to_idx.get(y), + on_click=lambda idx: ( + self._on_enter() + if hasattr(self, "_on_enter") + else None + ), + can_select=lambda idx: not self.values[idx][2], + ) result = [] menu_width = 58 for i, (value, label, is_header, description, tag) in enumerate(self.values): if is_header: - # Header takes 3 lines: empty, label, empty y_to_idx[current_y] = i y_to_idx[current_y + 1] = i y_to_idx[current_y + 2] = i @@ -403,7 +422,6 @@ def mouse_handler(mouse_event: MouseEvent) -> None: result.append(("class:menu-header", f"\n{label}\n", mouse_handler)) continue - # Item takes 1 line y_to_idx[current_y] = i current_y += 1 selected = i == self._current_index diff --git a/src/routecode/ui/renderables.py b/src/routecode/ui/renderables.py index e2588b5..f5d6e0f 100644 --- a/src/routecode/ui/renderables.py +++ b/src/routecode/ui/renderables.py @@ -123,16 +123,21 @@ def __init__(self, markup: str, *args, **kwargs): def get_logo(): """Return (route_lines, code_lines) for a sleek, modern look.""" - # Refined characters for maximum "Premium" feel route = [ - "█▀▀▄ █▀▀█ █ █ ▀█▀ █▀▀█", - "█▄▄▀ █░░█ █░░█ █ █░▀▀", - "█ █ ▀▀▀▀ ▀▀▀▀ ▀ ▀▀▀▀", + "██████╗ ██████╗ ██╗ ██╗████████╗███████╗", + "██╔══██╗██╔═══██╗██║ 
██║╚══██╔══╝██╔════╝", + "██████╔╝██║ ██║██║ ██║ ██║ █████╗ ", + "██╔══██╗██║ ██║██║ ██║ ██║ ██╔══╝ ", + "██║ ██║╚██████╔╝╚██████╔╝ ██║ ███████╗", + "╚═╝ ╚═╝ ╚═════╝ ╚═════╝ ╚═╝ ╚══════╝", ] code = [ - "█▀▀▀ █▀▀█ █▀▀▄ █▀▀█", - "█░░░ █░░█ █░░█ █▀▀▀", - "▀▀▀▀ ▀▀▀▀ ▀▀▀ ▀▀▀▀", + " ██████╗ ██████╗ ██████╗ ███████╗", + "██╔════╝██╔═══██╗██╔══██╗██╔════╝", + "██║ ██║ ██║██║ ██║█████╗ ", + "██║ ██║ ██║██║ ██║██╔══╝ ", + "╚██████╗╚██████╔╝██████╔╝███████╗", + " ╚═════╝ ╚═════╝ ╚═════╝ ╚══════╝", ] return route, code diff --git a/src/routecode/ui/repl/app.py b/src/routecode/ui/repl/app.py index 2d35c83..2a3bbb9 100644 --- a/src/routecode/ui/repl/app.py +++ b/src/routecode/ui/repl/app.py @@ -3,6 +3,7 @@ import sys import time import asyncio +import traceback from io import StringIO from prompt_toolkit.completion import WordCompleter from prompt_toolkit.styles import DynamicStyle @@ -17,19 +18,17 @@ from prompt_toolkit.keys import Keys from prompt_toolkit.filters import has_focus -from ...core import SessionState, RouteCodeContext, bus, PathGuard +from ...core import bus from ...utils.logger import get_logger from .. 
import console, get_tool_label from ...commands import execute_command, get_command_metadata from ...tools import registry, AuthorizationMiddleware from ...config import config, CONFIG_DIR, compute_system_prompt -from ...domain.task_manager import task_manager -from ...core.orchestrator import AgentOrchestrator from .styles import RouteCodeVt100Output, build_repl_style from .layout import RouteCodeLayout from .handlers import AppHooks - +from .bindings import KeyBindingsMixin from ..dialogs import HoverCompletionsMenu @@ -37,7 +36,7 @@ logger = get_logger(__name__) -class RouteCodeREPL: +class RouteCodeREPL(KeyBindingsMixin): def __init__(self): command_metadata = get_command_metadata() from ...domain.skills import discover_skills @@ -58,16 +57,13 @@ def __init__(self): multiline=False, completer=self.completer, complete_while_typing=True ) - # ── Flags ───────────────────────────────────────────────────────────── self._welcome_mode = True self.is_working = False self.work_start_time = 0 - # Setup redirection for rich console self._output_buffer = StringIO() - self._rich_console = console # The shared console + self._rich_console = console - # Disable litellm's verbose logging import litellm litellm.set_verbose = False @@ -78,7 +74,6 @@ def __init__(self): logging.getLogger("LiteLLM").setLevel(logging.ERROR) logging.getLogger("litellm").setLevel(logging.ERROR) - # Force colors and reasonable width self._rich_console.force_terminal = True self._rich_console.color_system = "truecolor" try: @@ -86,41 +81,10 @@ def __init__(self): except Exception: self._rich_console.width = 120 - # Output interception self._original_print = self._rich_console.print self._rich_console.print = self._intercepted_print self.history_buffer.text = "" - self.style = build_repl_style() - self._set_terminal_background() - - from ...core.memory import MemoryManager - - self.memory = MemoryManager(CONFIG_DIR) - self.state = SessionState() - self.path_guard = PathGuard() - self.ctx = 
RouteCodeContext( - state=self.state, - config=config, - console=self._rich_console, - task_manager=task_manager, - memory=self.memory, - path_guard=self.path_guard, - ) - self.auto_save_counter = 0 - self.logo_animation_count = 0 - self.orchestrator = AgentOrchestrator(self.ctx) - from ...core.audit import audit_hook - - registry.add_post_hook(audit_hook) - registry.add_middleware( - AuthorizationMiddleware(confirm_callback=self._confirm_destructive) - ) - - self._setup_event_handlers() - self._kb = KeyBindings() - self._setup_key_bindings() - self.app = None self.layout_manager = RouteCodeLayout(self) self.is_modal_open = False @@ -157,7 +121,6 @@ def _schedule(): else: self.ctx.loop.call_soon_threadsafe(_schedule) except RuntimeError: - # We are not in an async context, safely delegate to the main loop self.ctx.loop.call_soon_threadsafe(_schedule) def _is_scrolled_to_bottom(self): @@ -169,51 +132,6 @@ def update_style(self): self.style = styles.build_repl_style(is_dimmed=self.is_modal_open) self.request_invalidate() - def _setup_key_bindings(self): - @self._kb.add("c-c") - def _(event): - if getattr(self, "is_working", False): - if hasattr(self, "_current_agent_task") and self._current_agent_task: - self._current_agent_task.cancel() - self._rich_console.print(" [yellow]Agent aborted by user.[/yellow]") - return - - now = time.time() - if self._ctrl_c_press_time and (now - self._ctrl_c_press_time) < 3.0: - event.app.exit() - else: - self._ctrl_c_press_time = now - self.toast_message = "Press Ctrl+C again to exit" - self.request_invalidate() - - def clear_toast(): - if ( - self.toast_message - and time.time() - self._ctrl_c_press_time >= 2.9 - ): - self.toast_message = None - self.request_invalidate() - - if getattr(self, "ctx", None) and getattr(self.ctx, "loop", None): - self.ctx.loop.call_later(3.0, clear_toast) - - @self._kb.add("enter", filter=has_focus(self.input_buffer)) - def _(event): - text = self.input_buffer.text.strip() - self.input_buffer.reset() - 
if text: - if self._welcome_mode: - self._switch_to_session_mode() - self._current_agent_task = asyncio.create_task(self.handle_input(text)) - - @self._kb.add(Keys.ScrollUp) - def _(event): - self.history_buffer.cursor_up(count=3) - win = self.layout_manager.history_main - if win: - win.vertical_scroll = max(0, win.vertical_scroll - 3) - event.app.invalidate() - @self._kb.add(Keys.ScrollDown) def _(event): self.history_buffer.cursor_down(count=3) @@ -286,7 +204,6 @@ async def _on_turn_complete(count, **kwargs): await handle_save(["auto"], self.ctx) - bus.on("session.turn_complete", _on_turn_complete) bus.on("session.reset", self._on_session_reset) bus.on("ui.theme_changed", lambda **kwargs: self._on_theme_changed()) @@ -328,13 +245,10 @@ def _on_session_reset(self, **kwargs): self.request_invalidate() def _on_theme_changed(self): - self._set_terminal_background() self.style = build_repl_style() + self._original_print = self._rich_console._instance.print self.request_invalidate() - def _set_terminal_background(self): - self.style = build_repl_style() - def _on_resize(self): try: self._rich_console.width = os.get_terminal_size().columns @@ -342,9 +256,58 @@ def _on_resize(self): pass async def run(self): - self.ctx.loop = asyncio.get_running_loop() + loop = asyncio.get_running_loop() + + # Capture asyncio task exceptions that would otherwise be silently dropped + original_handler = loop.get_exception_handler() + + def _task_exc_handler(loop, context): + logger.error("Asyncio exception: %s", context) + if original_handler: + original_handler(loop, context) + + loop.set_exception_handler(_task_exc_handler) + + # ── Phase 1: Build services (dependency-ordered, synchronous) ── + from ...core import AppContainer + + self.container = AppContainer(CONFIG_DIR) + self.container.build() + + # Swap the rich console reference to our intercepted one + self.container.ctx.console = self._rich_console + self.ctx = self.container.ctx + self.state = self.container.state + self.memory 
= self.container.memory + self.ctx.loop = loop + self.auto_save_counter = 0 + self.logo_animation_count = 0 + + # ── Phase 2: Initialize (async — load persisted state) ─────── + await self.container.initialize() + + # ── Phase 3: Validate (assert all services ready) ───────────── + self.container.validate() + + # ── Phase 4: Wire ───────────────────────────────────────────── + from ..theme import apply_theme + + apply_theme(config.theme) + self.style = build_repl_style() + self.orchestrator = self.container.orchestrator + + from ...core.audit import audit_hook - # Pre-build both layouts + registry.add_post_hook(audit_hook) + registry.add_middleware( + AuthorizationMiddleware(confirm_callback=self._confirm_destructive) + ) + + self._setup_event_handlers() + self._kb = KeyBindings() + self._setup_key_bindings() + + # ── Phase 5: Build UI layout ────────────────────────────────── self._welcome_container = self.layout_manager.build_welcome_layout() self._session_container = self.layout_manager.build_session_layout() @@ -379,14 +342,15 @@ async def run(self): self.app.routecode_repl = self - self.app.routecode_repl = self - - # Start periodic refresh loop for animations + # ── Phase 6: Start ──────────────────────────────────────────── asyncio.create_task(self._periodic_refresh_loop()) - # Start background update check (fires after 3s delay) asyncio.create_task(self._check_for_updates()) - await self.app.run_async() + try: + await self.app.run_async() + finally: + # ── Phase 7: Shutdown ──────────────────────────────── + self.container.shutdown() def _get_active_layout(self): if self._welcome_mode: @@ -403,13 +367,17 @@ def _switch_to_session_mode(self): async def handle_input(self, text): self.history_buffer.cursor_position = len(self.history_buffer.text) - if text.startswith("/"): - if await execute_command(text, self.ctx): - pass + try: + if text.startswith("/"): + if await execute_command(text, self.ctx): + pass + else: + self._rich_console.print(f" [error]✘[/error] 
Unknown command: {text}") else: - self._rich_console.print(f" [error]✘[/error] Unknown command: {text}") - else: - await self.process_agent_request(text) + await self.process_agent_request(text) + except Exception as e: + logger.error("handle_input crashed:\n%s", traceback.format_exc()) + self._rich_console.print(f" [error]✘[/error] Internal error: {e}") async def process_agent_request(self, user_input: str): if not self.orchestrator.provider: diff --git a/src/routecode/ui/repl/bindings.py b/src/routecode/ui/repl/bindings.py new file mode 100644 index 0000000..e1c9b51 --- /dev/null +++ b/src/routecode/ui/repl/bindings.py @@ -0,0 +1,103 @@ +""" +Key bindings for the RouteCode REPL. + +Extracted from app.py to reduce RouteCodeREPL to ~250 lines. +RouteCodeREPL inherits from KeyBindingsMixin. +""" + +import asyncio +import time +from prompt_toolkit.keys import Keys +from prompt_toolkit.filters import has_focus +from ...utils.logger import get_logger + +logger = get_logger(__name__) + + +class KeyBindingsMixin: + """Mixin that provides _setup_key_bindings() for RouteCodeREPL.""" + + def _setup_key_bindings(self): + @self._kb.add("c-c") + def _(event): + if getattr(self, "is_working", False): + if hasattr(self, "_current_agent_task") and self._current_agent_task: + self._current_agent_task.cancel() + self._rich_console.print(" [yellow]Agent aborted by user.[/yellow]") + return + + now = time.time() + if self._ctrl_c_press_time and (now - self._ctrl_c_press_time) < 3.0: + event.app.exit() + else: + self._ctrl_c_press_time = now + self.toast_message = "Press Ctrl+C again to exit" + self.request_invalidate() + + def clear_toast(): + if ( + self.toast_message + and time.time() - self._ctrl_c_press_time >= 2.9 + ): + self.toast_message = None + self.request_invalidate() + + if getattr(self, "ctx", None) and getattr(self.ctx, "loop", None): + self.ctx.loop.call_later(3.0, clear_toast) + + @self._kb.add("enter", filter=has_focus(self.input_buffer)) + def _(event): + text = 
self.input_buffer.text.strip() + self.input_buffer.reset() + if text: + if self._welcome_mode: + self._switch_to_session_mode() + self._current_agent_task = asyncio.create_task( + self.handle_input(text) + ) + + @self._kb.add(Keys.ScrollUp) + def _(event): + self.history_buffer.cursor_up(count=3) + win = self.layout_manager.history_main + if win: + win.vertical_scroll = max(0, win.vertical_scroll - 3) + event.app.invalidate() + + @self._kb.add(Keys.ScrollDown) + def _(event): + self.history_buffer.cursor_down(count=3) + win = self.layout_manager.history_main + if win and win.render_info: + max_scroll = max( + 0, + win.render_info.ui_content.line_count + - win.render_info.window_height, + ) + win.vertical_scroll = min(max_scroll, win.vertical_scroll + 3) + elif win: + win.vertical_scroll += 3 + event.app.invalidate() + + @self._kb.add(Keys.PageUp) + def _(event): + self.history_buffer.cursor_up(count=15) + win = self.layout_manager.history_main + if win: + win.vertical_scroll = max(0, win.vertical_scroll - 15) + event.app.invalidate() + + @self._kb.add(Keys.PageDown) + def _(event): + self.history_buffer.cursor_down(count=15) + win = self.layout_manager.history_main + if win and win.render_info: + max_scroll = max( + 0, + win.render_info.ui_content.line_count + - win.render_info.window_height, + ) + win.vertical_scroll = min(max_scroll, win.vertical_scroll + 15) + elif win: + win.vertical_scroll += 15 + event.app.invalidate() diff --git a/src/routecode/ui/repl/handlers.py b/src/routecode/ui/repl/handlers.py index 92bab17..398669a 100644 --- a/src/routecode/ui/repl/handlers.py +++ b/src/routecode/ui/repl/handlers.py @@ -280,6 +280,13 @@ async def on_turn_complete(self, full_response, tool_calls): self._remove_cursor() self._remove_thinking() + # Auto-save every 5 turns + from ..commands import handle_save + + self.repl.auto_save_counter += 1 + if self.repl.auto_save_counter > 0 and self.repl.auto_save_counter % 5 == 0: + await handle_save(["auto"], self.repl.ctx) + 
if self._in_thought: if self._stream_buffer: formatted = escape(self._stream_buffer).replace("\n", "\n[dim]│[/dim] ") diff --git a/src/routecode/ui/repl/layout.py b/src/routecode/ui/repl/layout.py index 1f33a5d..47512af 100644 --- a/src/routecode/ui/repl/layout.py +++ b/src/routecode/ui/repl/layout.py @@ -6,7 +6,7 @@ from prompt_toolkit.widgets import Frame from prompt_toolkit.filters import Condition from .styles import SimpleAnsiLexer -from ..theme import THEME_ACCENTS, _current_theme_name +from ..theme import get_theme_accent from ..renderables import get_logo, format_duration from ... import __version__ @@ -232,6 +232,7 @@ def build_session_layout(self): left_pane = HSplit( [ history_window, + Window(style="class:history"), # Expanding spacer to push input_area to the bottom input_area, ] ) @@ -246,7 +247,7 @@ def build_session_layout(self): # ── Text Generators ─────────────────────────────────────────────────────── def _get_logo_formatted(self): - accent = THEME_ACCENTS.get(_current_theme_name, "#ffaf00") + accent = get_theme_accent() dim_color = "#888899" # vibrant silver for "route" gap = " " # 3-space gap between words result = [] @@ -264,7 +265,7 @@ def _get_logo_formatted(self): return result def _get_welcome_model_line(self): - accent = THEME_ACCENTS.get(_current_theme_name, "#ffaf00") + accent = get_theme_accent() return [ (f"fg:{accent} bold", f"{self.repl.ctx.config.provider.title()}"), ("fg:#555566", " · "), @@ -272,7 +273,7 @@ def _get_welcome_model_line(self): ] def _get_welcome_tip(self): - accent = THEME_ACCENTS.get(_current_theme_name, "#ffaf00") + accent = get_theme_accent() return [ (f"fg:{accent}", "● "), (f"fg:{accent} bold", "Tip "), @@ -282,7 +283,7 @@ def _get_welcome_tip(self): ] def _get_input_model_line(self): - accent = THEME_ACCENTS.get(_current_theme_name, "#ffaf00") + accent = get_theme_accent() return [ ( f"bg:#22222a fg:{accent} bold", @@ -307,7 +308,7 @@ def _get_session_footer_left(self): spinner = frames[frame_idx] dur_str = 
format_duration(duration) - accent = THEME_ACCENTS.get(_current_theme_name, "#ffaf00") + accent = get_theme_accent() res.extend( [ ("fg:#555566", " · "), @@ -318,7 +319,7 @@ def _get_session_footer_left(self): return res def _get_session_footer_right(self): - accent = THEME_ACCENTS.get(_current_theme_name, "#ffaf00") + accent = get_theme_accent() cwd = os.path.basename(os.getcwd()) or "~" base = [] diff --git a/src/routecode/ui/repl/styles.py b/src/routecode/ui/repl/styles.py index f2585b6..b824ade 100644 --- a/src/routecode/ui/repl/styles.py +++ b/src/routecode/ui/repl/styles.py @@ -3,7 +3,7 @@ from prompt_toolkit.formatted_text import ANSI from prompt_toolkit.output.vt100 import Vt100_Output -from ..theme import get_theme_bg, THEME_ACCENTS, _current_theme_name +from ..theme import get_theme_bg, get_theme_accent from ...utils.helpers import parse_hex_color @@ -90,7 +90,7 @@ def _is_modal_open() -> bool: def build_repl_style(is_dimmed: bool = False): bg = get_theme_bg() active_bg = _get_active_bg(is_dimmed) - accent = THEME_ACCENTS.get(_current_theme_name, "#ffaf00") + accent = get_theme_accent() if is_dimmed: sidebar_bg = "#08080a" diff --git a/src/routecode/utils/costs.py b/src/routecode/utils/costs.py index 8d2ccaf..30317a0 100644 --- a/src/routecode/utils/costs.py +++ b/src/routecode/utils/costs.py @@ -1,6 +1,6 @@ import litellm import re -from typing import Tuple +from typing import Tuple, Optional from .logger import get_logger from ..config.models_db import get_model_pricing @@ -16,17 +16,13 @@ class CostEstimator: def __init__( self, default_input_price: float = 2.0, default_output_price: float = 10.0 ): - # Default prices per 1M tokens self.default_input_price = default_input_price self.default_output_price = default_output_price + self._model_info_cache: dict = {} + self._model_info_failed: set = set() def count_tokens(self, text: str, model: str) -> int: - """ - Calculates token count using litellm's token_counter, - which uses the appropriate tokenizer 
for the given model. - """ try: - # LiteLLM handles different tokenizers (tiktoken, anthropic, etc.) return litellm.token_counter(model=model, text=text) except Exception as e: logger.debug( @@ -34,36 +30,43 @@ def count_tokens(self, text: str, model: str) -> int: model, e, ) - # Last resort fallback if litellm fails or model is unknown - # Provides a better estimate for code/JSON than simple split() tokens = re.findall(r"\w+|[^\w\s]", text) return len(tokens) + def _get_model_info(self, model: str) -> Optional[dict]: + """Cached wrapper around litellm.get_model_info to avoid blocking + the event loop with repeated synchronous network calls.""" + if model in self._model_info_cache: + return self._model_info_cache[model] + if model in self._model_info_failed: + return None + + try: + info = litellm.get_model_info(model) + self._model_info_cache[model] = info + return info + except Exception as e: + logger.debug( + "Failed to get model info from LiteLLM for %s: %s", model, e + ) + self._model_info_failed.add(model) + return None + def calculate_cost( self, input_tokens: int, output_tokens: int, model: str ) -> Tuple[float, int, float]: - """ - Calculates the cost based on input and output tokens for a specific model. - Returns (estimated_cost, context_limit, input_cost_per_1m). 
- """ - # Try to get pricing from litellm first - try: - model_info = litellm.get_model_info(model) - if model_info: - input_price = model_info.get("input_cost_per_token") - output_price = model_info.get("output_cost_per_token") - context_limit = model_info.get("max_tokens", 32000) + model_info = self._get_model_info(model) + if model_info: + input_price = model_info.get("input_cost_per_token") + output_price = model_info.get("output_cost_per_token") + context_limit = model_info.get("max_tokens", 32000) - if input_price is not None and output_price is not None: - cost = (input_tokens * input_price) + (output_tokens * output_price) - # Normalize to per 1M tokens for the third return value - return cost, context_limit, input_price * 1_000_000 - except Exception as e: - logger.debug("Failed to get model info from LiteLLM for %s: %s", model, e) + if input_price is not None and output_price is not None: + cost = (input_tokens * input_price) + (output_tokens * output_price) + return cost, context_limit, input_price * 1_000_000 # Fallback to internal models_db input_price_m, output_price_m, context_limit = get_model_pricing(model) - # models_db prices are per 1M tokens cost = (input_tokens * input_price_m / 1_000_000) + ( output_tokens * output_price_m / 1_000_000 ) diff --git a/src/routecode/utils/errors.py b/src/routecode/utils/errors.py index 49747f4..5eb9e01 100644 --- a/src/routecode/utils/errors.py +++ b/src/routecode/utils/errors.py @@ -1,5 +1,5 @@ import httpx -from typing import Dict +from typing import Dict, Optional class ClassifiedError: @@ -82,6 +82,32 @@ def to_message(self) -> Dict[str, str]: } +# Shared keyword-to-category matching — single source of truth for both +# classify_http_error() and classify_exception(). 
+_KEYWORD_MATCHERS = [ + ("auth", ["api key", "unauthorized", "unauth", "auth failed", "401", "403"]), + ("insufficient_quota", ["credit", "quota", "balance", "402", "insufficient_quota"]), + ("rate_limit", ["rate limit", "rate_limit", "429", "too many requests"]), + ("timeout", ["timeout", "timed out", "408"]), + ("prompt_too_long", ["too long", "context length", "maximum context", "413"]), + ("connection", ["connect", "connection", "dns", "name resolution"]), +] + + +def _match_keywords(text: str) -> Optional[str]: + """Checks a lowercase text string against ERROR_CATEGORY keyword matchers. + Returns the category key if a match is found, or None.""" + for category, keywords in _KEYWORD_MATCHERS: + for kw in keywords: + if kw in text: + return category + return None + + +def _model_not_found_in(text: str) -> bool: + return "model" in text and ("not found" in text or "not support" in text) + + def classify_http_error(status_code: int, body: str = "") -> ClassifiedError: body_lower = body.lower() @@ -92,7 +118,7 @@ def classify_http_error(status_code: int, body: str = "") -> ClassifiedError: elif status_code == 402: return ERROR_CATEGORIES["insufficient_quota"] elif status_code == 404: - if "model" in body_lower: + if _model_not_found_in(body_lower): return ERROR_CATEGORIES["model_not_found"] return ERROR_CATEGORIES["bad_request"] elif status_code == 413: @@ -102,14 +128,9 @@ def classify_http_error(status_code: int, body: str = "") -> ClassifiedError: elif status_code >= 500: return ERROR_CATEGORIES["server_error"] elif status_code >= 400: - if "credit" in body_lower or "quota" in body_lower or "balance" in body_lower: - return ERROR_CATEGORIES["insufficient_quota"] - if "timeout" in body_lower or "timed out" in body_lower: - return ERROR_CATEGORIES["timeout"] - if "model" in body_lower and ( - "not found" in body_lower or "not supported" in body_lower - ): - return ERROR_CATEGORIES["model_not_found"] + cat = _match_keywords(body_lower) + if cat: + return 
ERROR_CATEGORIES[cat] return ERROR_CATEGORIES["bad_request"] else: return ERROR_CATEGORIES["unknown"] @@ -130,19 +151,11 @@ def classify_exception(e: Exception) -> ClassifiedError: return ERROR_CATEGORIES["connection"] msg = str(e).lower() - if "timeout" in msg or "timed out" in msg: - return ERROR_CATEGORIES["timeout"] - if "rate limit" in msg or "429" in msg: - return ERROR_CATEGORIES["rate_limit"] - if "api key" in msg or "unauthorized" in msg or "401" in msg or "403" in msg: - return ERROR_CATEGORIES["auth"] - if "credit" in msg or "quota" in msg or "balance" in msg or "402" in msg: - return ERROR_CATEGORIES["insufficient_quota"] - if "model" in msg and ("not found" in msg or "not support" in msg): + if _model_not_found_in(msg): return ERROR_CATEGORIES["model_not_found"] - if "too long" in msg or "context length" in msg or "maximum context" in msg: - return ERROR_CATEGORIES["prompt_too_long"] - if "connect" in msg or "connection" in msg or "dns" in msg: - return ERROR_CATEGORIES["connection"] + + cat = _match_keywords(msg) + if cat: + return ERROR_CATEGORIES[cat] return ERROR_CATEGORIES["unknown"] diff --git a/src/routecode/utils/logger.py b/src/routecode/utils/logger.py index 3bcaced..bd2b803 100644 --- a/src/routecode/utils/logger.py +++ b/src/routecode/utils/logger.py @@ -12,24 +12,24 @@ def setup_logging(level=logging.INFO): if not CONFIG_DIR.exists(): CONFIG_DIR.mkdir(parents=True, exist_ok=True) - # Clear existing handlers if any root = logging.getLogger() if root.handlers: for handler in root.handlers[:]: root.removeHandler(handler) + numeric_level = getattr(logging, level.upper()) if isinstance(level, str) else level + logging.basicConfig( - level=level, + level=numeric_level, format="%(asctime)s [%(levelname)s] %(name)s: %(message)s", handlers=[logging.FileHandler(LOG_FILE, encoding="utf-8")], ) - # Suppress verbose third-party logs logging.getLogger("httpx").setLevel(logging.WARNING) logging.getLogger("httpcore").setLevel(logging.WARNING) 
logging.getLogger("prompt_toolkit").setLevel(logging.ERROR) - logging.info("Logging initialized. File: %s", LOG_FILE) + logging.info("Logging initialized. File: %s Level: %s", LOG_FILE, level) def get_logger(name: str): diff --git a/src/routecode/utils/storage.py b/src/routecode/utils/storage.py index 0c955cc..eb35b92 100644 --- a/src/routecode/utils/storage.py +++ b/src/routecode/utils/storage.py @@ -11,6 +11,8 @@ class AtomicJsonStore: Unified, crash-safe JSON persistence layer. Ensures that writes are atomic by writing to a temporary file and renaming it to the target path. + + Supports context manager protocol for lifecycle scoping. """ def __init__(self, path: Path): @@ -20,6 +22,22 @@ def __init__(self, path: Path): def _ensure_dir(self): self.path.parent.mkdir(parents=True, exist_ok=True) + # ── Context manager protocol ───────────────────────────────────────── + + def __enter__(self): + return self + + def __exit__(self, *args): + self.cleanup_stale_temps() + + async def __aenter__(self): + return self + + async def __aexit__(self, *args): + self.cleanup_stale_temps() + + # ── Core API ───────────────────────────────────────────────────────── + def load(self, default: Optional[Dict] = None) -> Dict[str, Any]: """Loads JSON data from the file. Returns default if file doesn't exist or is invalid.""" if not self.path.exists(): @@ -65,7 +83,7 @@ async def save_async(self, data: Dict[str, Any]): await asyncio.to_thread(os.replace, tmp_path, self.path) except Exception as e: if tmp_path.exists(): - await asyncio.to_thread(tmp_path.unlink, missing_ok=True) + await asyncio.to_thread(tmp_path.unlink) raise e def exists(self) -> bool: @@ -74,3 +92,23 @@ def exists(self) -> bool: def delete(self): if self.path.exists(): self.path.unlink() + + # ── Cleanup ────────────────────────────────────────────────────────── + + def cleanup_stale_temps(self, base_path: Optional[Path] = None): + """ + Removes orphaned .tmp files left behind by crashed sessions. 
+ Call once at startup to clean up from previous incomplete writes. + + Args: + base_path: Directory to scan. Defaults to the parent directory + of this store's path. + """ + search_dir = base_path or self.path.parent + if not search_dir.exists(): + return + for tmp_file in search_dir.glob("*.tmp"): + try: + tmp_file.unlink() + except Exception: + pass diff --git a/tests/test_cloudflare_native.py b/tests/test_cloudflare_native.py new file mode 100644 index 0000000..c18a6fa --- /dev/null +++ b/tests/test_cloudflare_native.py @@ -0,0 +1,69 @@ +import asyncio +import os +import json +import sys +import logging +from pathlib import Path +from typing import List, Dict, Any, Optional + +# Add src to path +sys.path.append(str(Path(__file__).parent.parent / "src")) + +from routecode.agents.cloudflare_provider import CloudflareProvider +from routecode.utils.logger import setup_logging + +async def test_cloudflare(model_id: Optional[str] = None): + setup_logging(logging.DEBUG) + api_key = os.environ.get("CLOUDFLARE_API_KEY") + account_id = os.environ.get("CLOUDFLARE_ACCOUNT_ID") + + if not api_key or not account_id: + print("Error: CLOUDFLARE_API_KEY and CLOUDFLARE_ACCOUNT_ID environment variables must be set.") + return + + print(f"Testing Cloudflare Native Provider...") + print(f"Account ID: {account_id}") + + provider = CloudflareProvider(api_key, account_id=account_id) + + if not model_id: + # Test model list + print("\nFetching models...") + models = await provider.get_models() + print(f"Found {len(models)} models.") + + # Default model selection + model_id = "@cf/meta/llama-3-8b-instruct" + if any("moonshot" in m["id"] for m in models): + # Sort to get highest version first or 2.6 specifically + moonshots = [m["id"] for m in models if "moonshot" in m["id"]] + moonshots.sort(reverse=True) + model_id = moonshots[0] + if any("2.6" in m for m in moonshots): + model_id = [m for m in moonshots if "2.6" in m][0] + + print(f"\nTesting streaming ask with model: {model_id}") + 
messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Hello! Reply with a short joke."} + ] + + full_response = "" + print("Response: ", end="", flush=True) + async for chunk in provider.ask(messages, model_id, stream=True): + if chunk["type"] == "text": + print(chunk["content"], end="", flush=True) + full_response += chunk["content"] + elif chunk["type"] == "error": + print(f"\nError: {chunk['content']}") + break + + print("\n\nTest Complete.") + if full_response: + print("Status: SUCCESS") + else: + print("Status: FAILED (Empty response)") + +if __name__ == "__main__": + mid = sys.argv[1] if len(sys.argv) > 1 else None + asyncio.run(test_cloudflare(mid))