From 47a8de75777456547f065407a7da90abcc6cf12e Mon Sep 17 00:00:00 2001 From: ANAS Date: Sun, 17 May 2026 21:57:06 +0100 Subject: [PATCH] feat: improving core SDK modules and CLI scaffold for multi-provider AI agent interactions and workspace session management --- Cargo.lock | 10 + apps/cli/Cargo.toml | 1 + apps/cli/src/main.rs | 60 +- apps/cli/src/ui/components.rs | 76 ++ apps/cli/src/ui/logic.rs | 169 +++ apps/cli/src/ui/menus.rs | 286 +++++ apps/cli/src/ui/mod.rs | 1888 +++++++++++++++-------------- apps/cli/src/ui/session.rs | 382 ++++++ apps/cli/src/ui/welcome.rs | 151 +++ cli_build_debug.bat | 3 +- libs/sdk/Cargo.toml | 5 +- libs/sdk/src/agents/anthropic.rs | 120 +- libs/sdk/src/agents/cloudflare.rs | 28 +- libs/sdk/src/agents/gemini.rs | 201 ++- libs/sdk/src/agents/mod.rs | 2 +- libs/sdk/src/agents/openai.rs | 28 +- libs/sdk/src/agents/opencode.rs | 182 ++- libs/sdk/src/agents/openrouter.rs | 77 +- libs/sdk/src/agents/traits.rs | 1 + libs/sdk/src/agents/types.rs | 25 +- libs/sdk/src/agents/utils.rs | 141 ++- libs/sdk/src/core/config.rs | 21 + libs/sdk/src/core/message.rs | 24 +- libs/sdk/src/core/orchestrator.rs | 181 ++- libs/sdk/src/tools/file_ops.rs | 37 +- libs/sdk/src/tools/navigation.rs | 190 ++- libs/sdk/src/utils/costs.rs | 138 ++- libs/sdk/src/utils/storage.rs | 239 +++- libs/sdk/src/utils/tokens.rs | 65 +- 29 files changed, 3486 insertions(+), 1245 deletions(-) create mode 100644 apps/cli/src/ui/components.rs create mode 100644 apps/cli/src/ui/logic.rs create mode 100644 apps/cli/src/ui/menus.rs create mode 100644 apps/cli/src/ui/session.rs create mode 100644 apps/cli/src/ui/welcome.rs diff --git a/Cargo.lock b/Cargo.lock index f0af383..8020d1d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -603,6 +603,12 @@ dependencies = [ "wasip3", ] +[[package]] +name = "glob" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0cc23270f6e1808e30a928bdc84dea0b9b4136a8bc82338574f23baf47bbd280" + [[package]] name = "h2" version = "0.4.14" @@ -1432,6 +1438,7 @@ dependencies = [ "simplelog", "tokio", "tui-textarea", + "unicode-width", ] [[package]] @@ -1445,7 +1452,10 @@ dependencies = [ "dirs", "env_logger", "futures", + "glob", "log", + "once_cell", + "regex", "reqwest", "serde", "serde_json", diff --git a/apps/cli/Cargo.toml b/apps/cli/Cargo.toml index bb3e556..50cf69b 100644 --- a/apps/cli/Cargo.toml +++ b/apps/cli/Cargo.toml @@ -19,3 +19,4 @@ log = "0.4" chrono = { version = "0.4", features = ["serde"] } async-trait = "0.1" simplelog = "0.12" +unicode-width = "0.1" diff --git a/apps/cli/src/main.rs b/apps/cli/src/main.rs index bbbe8b2..ae474d4 100644 --- a/apps/cli/src/main.rs +++ b/apps/cli/src/main.rs @@ -41,15 +41,16 @@ pub enum Commands { mod ui; use crossterm::{ - event::{DisableMouseCapture, EnableMouseCapture}, + event::{EnableBracketedPaste, DisableBracketedPaste}, execute, + style::Print, terminal::{disable_raw_mode, enable_raw_mode, EnterAlternateScreen, LeaveAlternateScreen}, }; use ratatui::{backend::CrosstermBackend, Terminal}; use routecode_sdk::core::AgentOrchestrator; use routecode_sdk::tools::bash::BashTool; use routecode_sdk::tools::file_ops::{FileEditTool, FileReadTool, FileWriteTool}; -use routecode_sdk::tools::navigation::{GrepTool, LsTool}; +use routecode_sdk::tools::navigation::{GrepTool, LsTool, TreeTool}; use routecode_sdk::tools::ToolRegistry; use std::io; use std::process::Command; @@ -72,35 +73,53 @@ async fn main() -> anyhow::Result<()> { let log_level = if cli.debug { LevelFilter::Debug } else { LevelFilter::Info }; - CombinedLogger::init(vec![ + let loggers: Vec> = vec![ WriteLogger::new( log_level, ConfigBuilder::new().set_time_format_rfc3339().build(), File::create(&log_path)?, ), - ])?; + ]; + + // Only use TermLogger if we are NOT about to enter TUI mode immediately, + // or if it's a headless run. But for now, let's just use it for the very beginning. + // Actually, simplelog doesn't support easy removal, so we'll just use WriteLogger + // and manual printlns for the very early stages if needed. + + CombinedLogger::init(loggers)?; + + log::info!("Starting RouteCode v{}", env!("CARGO_PKG_VERSION")); if cli.debug { log::debug!("Debug mode active. Spawning log window..."); // Spawn a new terminal window to tail the log file #[cfg(target_os = "windows")] { - let _ = Command::new("cmd") - .args(["/C", "start", "powershell", "-NoExit", "-Command", &format!("Get-Content -Path '{}' -Wait", log_path.display())]) - .spawn(); + if let Err(e) = Command::new("cmd") + .args(["/C", "start", "powershell", "-NoExit", "-Command", &format!("Get-Content -Path \"{}\" -Wait", log_path.display())]) + .spawn() + { + log::warn!("Failed to spawn debug log window: {}", e); + } } #[cfg(target_os = "macos")] { - let _ = Command::new("osascript") + if let Err(e) = Command::new("osascript") .args(["-e", &format!("tell application \"Terminal\" to do script \"tail -f '{}'\"", log_path.display())]) - .spawn(); + .spawn() + { + log::warn!("Failed to spawn debug log window: {}", e); + } } #[cfg(target_os = "linux")] { // Try common terminal emulators - let _ = Command::new("x-terminal-emulator") + if let Err(e) = Command::new("x-terminal-emulator") .args(["-e", "tail", "-f", &log_path.display().to_string()]) - .spawn(); + .spawn() + { + log::warn!("Failed to spawn debug log window: {}", e); + } } } @@ -148,6 +167,7 @@ async fn main() -> anyhow::Result<()> { tool_registry.register(Arc::new(FileEditTool)); tool_registry.register(Arc::new(BashTool)); tool_registry.register(Arc::new(LsTool)); + tool_registry.register(Arc::new(TreeTool)); tool_registry.register(Arc::new(GrepTool)); let tool_registry = Arc::new(tool_registry); @@ -161,7 +181,7 @@ async fn main() -> anyhow::Result<()> { // Setup terminal enable_raw_mode()?; let mut stdout = io::stdout(); - execute!(stdout, EnterAlternateScreen, EnableMouseCapture)?; + execute!(stdout, EnterAlternateScreen, Print("\x1b[?1003h\x1b[?1006h"), EnableBracketedPaste)?; let backend = CrosstermBackend::new(stdout); let mut terminal = Terminal::new(backend)?; @@ -177,8 +197,19 @@ async fn main() -> anyhow::Result<()> { app.current_model = session.model; let mut u = app.orchestrator.usage.lock().await; *u = session.usage; + app.session_id = resume_name.clone(); + if let Ok(config) = routecode_sdk::utils::storage::load_session_config(&resume_name) { + app.orchestrator.allow_session_commands.store(config.allow_all_commands, std::sync::atomic::Ordering::SeqCst); + app.orchestrator.allow_session_outside_access.store(config.allow_all_outside_access, std::sync::atomic::Ordering::SeqCst); + } } - Err(e) => eprintln!("Failed to resume session: {}", e), + Err(e) => app.history.push(routecode_sdk::core::Message::system(format!("Failed to resume session '{}': {}", resume_name, e))), + } + } + + if let Ok(workspace_config) = routecode_sdk::utils::storage::load_workspace_config() { + if workspace_config.allow_all_outside_access { + app.orchestrator.allow_session_outside_access.store(true, std::sync::atomic::Ordering::SeqCst); } } @@ -189,7 +220,8 @@ async fn main() -> anyhow::Result<()> { execute!( terminal.backend_mut(), LeaveAlternateScreen, - DisableMouseCapture + Print("\x1b[?1003l\x1b[?1006l"), + DisableBracketedPaste )?; terminal.show_cursor()?; diff --git a/apps/cli/src/ui/components.rs b/apps/cli/src/ui/components.rs new file mode 100644 index 0000000..3fc29a4 --- /dev/null +++ b/apps/cli/src/ui/components.rs @@ -0,0 +1,76 @@ +use ratatui::layout::Rect; +use ratatui::style::{Color, Modifier, Style}; +use ratatui::text::{Line, Span}; +use ratatui::widgets::{Block, Clear}; +use ratatui::Frame; + +// --- Theme --- +pub const COLOR_PRIMARY: Color = Color::Rgb(0, 150, 255); // Ocean Blue +pub const COLOR_BG: Color = Color::Rgb(25, 25, 25); // Midnight Charcoal +pub const COLOR_INPUT_BG: Color = Color::Rgb(35, 35, 35);// Soft Obsidian +pub const COLOR_SECONDARY: Color = Color::DarkGray; // Slate Gray +pub const COLOR_SYSTEM: Color = Color::Yellow; // Amber Yellow +pub const COLOR_SUCCESS: Color = Color::Green; // Emerald Green +pub const COLOR_TEXT: Color = Color::White; // Primary Text +pub const COLOR_DIM: Color = Color::Rgb(50, 50, 50); // Very Dim Text/Lines + +pub fn clean_model_name(name: &str, provider_id: &str) -> String { + if provider_id.starts_with("cloudflare") && name.starts_with("@cf/") { + name.split('/').last().unwrap_or(name).to_string() + } else if (provider_id == "openrouter" || provider_id == "nvidia") && name.contains('/') { + name.split('/').last().unwrap_or(name).to_string() + } else { + name.to_string() + } +} + +pub fn draw_modal(f: &mut Frame, title: &str, width: u16, height: u16, mouse_col: Option, mouse_row: Option, footer: Vec) -> Rect { + let area = f.size(); + let modal_area = Rect::new( + (area.width.saturating_sub(width)) / 2, + (area.height.saturating_sub(height)) / 2, + width, + height, + ); + f.render_widget(Clear, modal_area); + f.render_widget(Block::default().style(Style::default().bg(COLOR_BG)), modal_area); + + let main_layout = ratatui::layout::Layout::default() + .direction(ratatui::layout::Direction::Vertical) + .constraints([ + ratatui::layout::Constraint::Length(1), // Header + ratatui::layout::Constraint::Min(0), // Content + ratatui::layout::Constraint::Length(1), // Footer Spacer + ratatui::layout::Constraint::Length(1), // Footer + ]) + .margin(1) + .split(modal_area); + + let header_layout = ratatui::layout::Layout::default() + .direction(ratatui::layout::Direction::Horizontal) + .constraints([ + ratatui::layout::Constraint::Min(0), + ratatui::layout::Constraint::Length(5), // "esc" + ]) + .split(main_layout[0]); + + f.render_widget( + ratatui::widgets::Paragraph::new(Span::styled(title, Style::default().add_modifier(Modifier::BOLD))), + header_layout[0], + ); + let mut esc_style = Style::default().fg(COLOR_SECONDARY); + if let (Some(col), Some(row)) = (mouse_col, mouse_row) { + if row <= modal_area.y + 2 && col >= modal_area.x + width.saturating_sub(10) && col <= modal_area.x + width { + esc_style = Style::default().fg(Color::Red).add_modifier(Modifier::BOLD); + } + } + + f.render_widget( + ratatui::widgets::Paragraph::new(Span::styled("esc", esc_style)), + header_layout[1], + ); + + f.render_widget(ratatui::widgets::Paragraph::new(Line::from(footer)), main_layout[3]); + + main_layout[1] +} diff --git a/apps/cli/src/ui/logic.rs b/apps/cli/src/ui/logic.rs new file mode 100644 index 0000000..ddd9cb4 --- /dev/null +++ b/apps/cli/src/ui/logic.rs @@ -0,0 +1,169 @@ +use ratatui::style::Style; +use tui_textarea::TextArea; +use routecode_sdk::agents::StreamChunk; +use routecode_sdk::core::{DynamicModelInfo, Message}; +use crate::ui::{App, PROVIDERS, ModelMenuItem, Screen, COLOR_SECONDARY}; + +pub async fn handle_model_search(app: &mut App, search: &str, force_reset: bool) { + let mut sections: Vec = Vec::new(); + let config = app.orchestrator.config.lock().await.clone(); + + let recent: Vec = config.recent_models.iter() + .filter(|m| m.name.to_lowercase().contains(search) || m.provider_id.to_lowercase().contains(search)) + .cloned() + .collect(); + if !recent.is_empty() { + sections.push(ModelMenuItem::Header("Recently Used".to_string())); + for m in recent { sections.push(ModelMenuItem::Model(m)); } + } + + let favorites: Vec = config.favorites.iter() + .filter(|m| m.name.to_lowercase().contains(search) || m.provider_id.to_lowercase().contains(search)) + .cloned() + .collect(); + if !favorites.is_empty() { + sections.push(ModelMenuItem::Header("Favorite Models".to_string())); + for m in favorites { sections.push(ModelMenuItem::Model(m)); } + } + + let mut by_provider: std::collections::HashMap> = std::collections::HashMap::new(); + for m in &app.all_available_models { + if m.name.to_lowercase().contains(search) || m.provider_id.to_lowercase().contains(search) { + by_provider.entry(m.provider_id.clone()).or_default().push(m.clone()); + } + } + + let mut provider_ids: Vec = by_provider.keys().cloned().collect(); + provider_ids.sort(); + + for p_id in provider_ids { + if let Some(models) = by_provider.get(&p_id) { + let p_name = PROVIDERS.iter().find(|p| p.id == p_id).map(|p| p.name).unwrap_or(&p_id); + sections.push(ModelMenuItem::Header(p_name.to_string())); + for m in models { sections.push(ModelMenuItem::Model(m.clone())); } + } + } + + app.filtered_models = sections; + + if force_reset { + if !app.filtered_models.is_empty() { + let mut first_model = None; + for (i, item) in app.filtered_models.iter().enumerate() { + if let ModelMenuItem::Model(_) = item { first_model = Some(i); break; } + } + app.menu_state.select(first_model); + } else { app.menu_state.select(None); } + } +} + +pub async fn handle_command(app: &mut App, input: &str) { + let parts: Vec<&str> = input.split_whitespace().collect(); + if parts.is_empty() { return; } + let command = parts[0]; + let args = &parts[1..]; + + match command { + "/model" => { + app.show_model_menu = true; + app.is_fetching_models = true; + app.all_available_models.clear(); + app.model_search_input = TextArea::default(); + app.model_search_input.set_cursor_line_style(Style::default()); + app.model_search_input.set_placeholder_text(" Search models..."); + app.model_search_input.set_placeholder_style(Style::default().fg(COLOR_SECONDARY)); + handle_model_search(app, "", true).await; + let config_mutex = app.orchestrator.config.clone(); + let tx = app.tx.clone(); + tokio::spawn(async move { + let config = config_mutex.lock().await.clone(); + let mut set = tokio::task::JoinSet::new(); + for p_info in PROVIDERS { + let env_key = format!("{}_API_KEY", p_info.id.to_uppercase().replace("-", "_")); + let mut api_key = std::env::var(env_key).ok().or_else(|| config.api_keys.get(p_info.id).cloned()); + if api_key.is_none() && p_info.id.starts_with("cloudflare") { api_key = std::env::var("CLOUDFLARE_API_KEY").ok(); } + if let Some(key) = api_key { + let provider_id = p_info.id.to_string(); + let provider = routecode_sdk::agents::resolve_provider(&provider_id, key); + set.spawn(async move { + match provider.list_models().await { + Ok(models) => { + let dynamic_models: Vec = models.into_iter() + .map(|m| DynamicModelInfo { name: m, provider_id: provider_id.clone() }) + .collect(); + Ok(dynamic_models) + } + Err(e) => Err(e), + } + }); + } + } + while let Some(res) = set.join_next().await { if let Ok(Ok(models)) = res { let _ = tx.send(StreamChunk::Models { models }); } } + let _ = tx.send(StreamChunk::ModelsDone); + }); + } + "/resume" => { + if let Some(name) = args.first() { + if let Ok(session) = routecode_sdk::utils::storage::load_session(name) { + app.history = session.messages; + app.current_model = session.model; + let mut u = app.orchestrator.usage.lock().await; + *u = session.usage; + app.session_id = name.to_string(); + if let Ok(config) = routecode_sdk::utils::storage::load_session_config(name) { + app.orchestrator.allow_session_commands.store(config.allow_all_commands, std::sync::atomic::Ordering::SeqCst); + app.orchestrator.allow_session_outside_access.store(config.allow_all_outside_access, std::sync::atomic::Ordering::SeqCst); + } + if let Ok(workspace_config) = routecode_sdk::utils::storage::load_workspace_config() { + if workspace_config.allow_all_outside_access { + app.orchestrator.allow_session_outside_access.store(true, std::sync::atomic::Ordering::SeqCst); + } + } + app.history.push(Message::system(format!("Session resumed: {}", name))); + app.screen = Screen::Session; + } + } + } + "/sessions" => { + if let Ok(sessions) = routecode_sdk::utils::storage::list_sessions() { + if sessions.is_empty() { app.history.push(Message::system("No saved sessions found.")); } + else { app.history.push(Message::system(format!("Saved sessions:\n {}", sessions.join("\n ")))); } + } + } + "/clear" => { + app.pending_clear = true; + } + "/stop" => { + if app.is_generating { + if let Some(handle) = app.current_task.take() { handle.abort(); } + app.is_generating = false; + app.active_tool = None; + app.history.push(Message::system("Generation cancelled.")); + } + } + "/help" => { + app.history.push(Message::system("Available commands:\n /model - Select model\n /thinking - Set level (low/max)\n /provider - Manage connections\n /settings - Manage settings\n /resume - Resume session\n /sessions - List sessions\n /clear - Clear history\n /help - Show help\n /exit - Use Esc to exit")); + } + "/thinking" => { + if let Some(level) = args.first() { + let level = level.to_lowercase(); + let valid = ["default", "low", "medium", "high", "max"]; + if valid.contains(&level.as_str()) { + let mut config = app.orchestrator.config.lock().await; + config.thinking_level = level.clone(); + let _ = routecode_sdk::utils::storage::save_config(&config); + app.history.push(Message::system(format!("Thinking level set to: {}", level))); + } else { app.history.push(Message::system(format!("Invalid level. Valid: {}", valid.join(", ")))); } + } else { + let config = app.orchestrator.config.lock().await; + app.history.push(Message::system(format!("Current thinking level: {}", config.thinking_level))); + } + } + "/provider" => { app.show_provider_menu = true; app.menu_state.select(Some(0)); } + "/settings" => { app.populate_settings().await; app.show_settings_menu = true; app.menu_state.select(Some(1)); } + "/exit" => { + app.pending_exit = true; + } + _ => { app.history.push(Message::system(format!("Unknown command: {}", command))); } + } +} diff --git a/apps/cli/src/ui/menus.rs b/apps/cli/src/ui/menus.rs new file mode 100644 index 0000000..caab8f0 --- /dev/null +++ b/apps/cli/src/ui/menus.rs @@ -0,0 +1,286 @@ +use ratatui::layout::{Constraint, Layout, Rect}; +use ratatui::style::{Modifier, Style, Color}; +use ratatui::text::{Line, Span}; +use ratatui::widgets::{Block, Borders, List, ListItem, Paragraph}; +use ratatui::Frame; +use crate::ui::{App, PROVIDERS, ModelMenuItem, ApiKeyInputStage}; +use crate::ui::components::{COLOR_PRIMARY, COLOR_SECONDARY, COLOR_TEXT, COLOR_SUCCESS, draw_modal, clean_model_name}; + +pub fn render_menu(f: &mut Frame, app: &mut App, _input_area: Rect) { + let height = (app.filtered_commands.len() + 6).min(15) as u16; + let body_area = draw_modal(f, "Commands", 60, height, app.mouse_col, app.mouse_row, vec![ + Span::styled("Enter", Style::default().add_modifier(Modifier::BOLD)), + Span::raw(" select command") + ]); + + let items: Vec = app.filtered_commands.iter().map(|cmd| { + let total_width = body_area.width.saturating_sub(4); + let left = cmd.name.to_string(); + let right = cmd.description.to_string(); + let padding = total_width.saturating_sub(left.len() as u16).saturating_sub(right.len() as u16); + let spaces = " ".repeat(padding as usize); + ListItem::new(Line::from(vec![ + Span::raw(format!(" {}", left)), + Span::raw(spaces), + Span::styled(right, Style::default().fg(COLOR_SECONDARY)), + Span::raw(" ") + ])) + }).collect(); + + let list = List::new(items) + .highlight_style(Style::default().bg(COLOR_PRIMARY).fg(Color::Black)) + .highlight_symbol(""); + + let items_len = app.filtered_commands.len(); + if app.mouse_moved { + if let (Some(col), Some(row)) = (app.mouse_col, app.mouse_row) { + if col >= body_area.x && col < body_area.x + body_area.width && row >= body_area.y && row < body_area.y + body_area.height { + let idx = (row - body_area.y) as usize + app.menu_state.offset(); + if idx < items_len { + app.menu_state.select(Some(idx)); + } + } + } + app.mouse_moved = false; + } + + f.render_stateful_widget(list, body_area, &mut app.menu_state); +} + +pub fn render_api_key_dialog(f: &mut Frame, app: &mut App) { + let provider_id = app.pending_provider_id.as_deref().unwrap_or("provider"); + let p_info = PROVIDERS.iter().find(|p| p.id == provider_id); + let provider_name = p_info.map(|p| p.name).unwrap_or(provider_id); + + let title = format!("Connect {}", provider_name); + let body_area = draw_modal(f, &title, 60, 10, app.mouse_col, app.mouse_row, vec![ + Span::styled("Press Enter to save", Style::default().add_modifier(Modifier::BOLD)) + ]); + + let layout = Layout::default() + .direction(ratatui::layout::Direction::Vertical) + .constraints([ + Constraint::Length(1), + Constraint::Length(1), + Constraint::Length(3), + ]) + .split(body_area); + + let (prompt, placeholder) = match app.api_key_input_stage { + ApiKeyInputStage::CloudflareAccountId => (format!("Enter Cloudflare Account ID:"), " Account ID..."), + ApiKeyInputStage::CloudflareGatewayId => (format!("Enter Cloudflare Gateway ID:"), " Gateway ID..."), + ApiKeyInputStage::CloudflareApiKey => (format!("Enter Cloudflare API Token:"), " API Token..."), + _ => (format!("Enter API key for {}:", provider_name), " Paste your API key here..."), + }; + + f.render_widget(Paragraph::new(prompt), layout[0]); + + app.api_key_input.set_placeholder_text(placeholder); + app.api_key_input.set_block(Block::default().borders(Borders::ALL).border_style(Style::default().fg(COLOR_SECONDARY))); + f.render_widget(app.api_key_input.widget(), layout[2]); + + let (row, col) = app.api_key_input.cursor(); + f.set_cursor(layout[2].x + 1 + col as u16, layout[2].y + 1 + row as u16); +} + +pub fn render_provider_menu(f: &mut Frame, app: &mut App, _input_area: Rect) { + let height = (PROVIDERS.len() + 6).min(15) as u16; + let body_area = draw_modal(f, "AI Providers", 60, height, app.mouse_col, app.mouse_row, vec![ + Span::styled("Enter", Style::default().add_modifier(Modifier::BOLD)), + Span::raw(" configure API key") + ]); + + let config_guard = app.orchestrator.config.try_lock(); + if config_guard.is_err() { return; } + let config = config_guard.unwrap(); + + let items: Vec = PROVIDERS.iter().map(|p| { + let env_key = format!("{}_API_KEY", p.id.to_uppercase().replace("-", "_")); + let is_connected = config.api_keys.contains_key(p.id) || std::env::var(env_key).is_ok(); + + let status = if is_connected { + Span::styled(" ✔ connected", Style::default().fg(COLOR_SUCCESS)) + } else { + Span::styled(" ✖ disconnected", Style::default().fg(COLOR_SECONDARY)) + }; + + let total_width = body_area.width.saturating_sub(4); + let left = p.name.to_string(); + let status_str = if is_connected { "✔ connected" } else { "✖ disconnected" }; + let padding = total_width.saturating_sub(left.len() as u16).saturating_sub(status_str.len() as u16); + let spaces = " ".repeat(padding as usize); + + ListItem::new(Line::from(vec![ + Span::raw(format!(" {}", left)), + Span::raw(spaces), + status, + Span::raw(" ") + ])) + }).collect(); + + let list = List::new(items) + .highlight_style(Style::default().bg(COLOR_PRIMARY).fg(Color::Black)) + .highlight_symbol(""); + + let items_len = PROVIDERS.len(); + if app.mouse_moved { + if let (Some(col), Some(row)) = (app.mouse_col, app.mouse_row) { + if col >= body_area.x && col < body_area.x + body_area.width && row >= body_area.y && row < body_area.y + body_area.height { + let idx = (row - body_area.y) as usize + app.menu_state.offset(); + if idx < items_len { + app.menu_state.select(Some(idx)); + } + } + } + app.mouse_moved = false; + } + + f.render_stateful_widget(list, body_area, &mut app.menu_state); +} + +pub fn render_model_menu(f: &mut Frame, app: &mut App, _input_area: Rect) { + let height = (app.filtered_models.len() + 7).min(18) as u16; + let mut footer = vec![ + Span::styled("Connect provider ", Style::default().add_modifier(Modifier::BOLD)), + Span::styled("ctrl+a", Style::default().fg(COLOR_SECONDARY)), + Span::raw(" "), + Span::styled("Favorite ", Style::default().add_modifier(Modifier::BOLD)), + Span::styled("ctrl+f", Style::default().fg(COLOR_SECONDARY)), + ]; + + if app.is_fetching_models { + let spinner = ["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"]; + let frame = spinner[(app.tick_count % spinner.len() as u64) as usize]; + footer.push(Span::raw(" ")); + footer.push(Span::styled(format!("{} Fetching models...", frame), Style::default().fg(COLOR_PRIMARY).add_modifier(Modifier::BOLD))); + } + + let body_area = draw_modal(f, "Select model", 70, height, app.mouse_col, app.mouse_row, footer); + + let layout = Layout::default() + .direction(ratatui::layout::Direction::Vertical) + .constraints([ + Constraint::Length(2), + Constraint::Min(0), + ]) + .split(body_area); + + let search_text = app.model_search_input.lines()[0].clone(); + let search_para = if search_text.is_empty() { + Paragraph::new(Span::styled("search models...", Style::default().fg(COLOR_SECONDARY))) + } else { + Paragraph::new(Span::styled(&search_text, Style::default().fg(COLOR_TEXT))) + }; + f.render_widget(search_para, layout[0]); + + if app.show_model_menu && !app.is_inputting_api_key { + let (row, col) = app.model_search_input.cursor(); + f.set_cursor(layout[0].x + col as u16, layout[0].y + row as u16); + } + + let config_guard = app.orchestrator.config.try_lock(); + if config_guard.is_err() { return; } + let config = config_guard.unwrap(); + + let items: Vec = app.filtered_models.iter().map(|item| { + match item { + ModelMenuItem::Header(title) => { + ListItem::new(Line::from(vec![ + Span::styled(format!(" {}", title), Style::default().fg(COLOR_SECONDARY).add_modifier(Modifier::DIM)) + ])) + } + ModelMenuItem::Model(m) => { + let is_fav = config.favorites.iter().any(|fav| fav.name == m.name && fav.provider_id == m.provider_id); + let fav_star = if is_fav { " ★" } else { "" }; + let display_name = clean_model_name(&m.name, &m.provider_id).replace(":free", " Free"); + let p_name = PROVIDERS.iter().find(|p| p.id == m.provider_id).map(|p| p.name).unwrap_or(&m.provider_id); + + let left = format!("{}{}", display_name, fav_star); + let right = p_name.to_string(); + let total_width = layout[1].width.saturating_sub(4); + let padding = total_width.saturating_sub(left.len() as u16).saturating_sub(right.len() as u16); + let spaces = " ".repeat(padding as usize); + + ListItem::new(Line::from(vec![ + Span::raw(format!(" {}", left)), + Span::raw(spaces), + Span::raw(right), + Span::raw(" ") + ])) + } + } + }).collect(); + + let list = List::new(items) + .highlight_style(Style::default().bg(COLOR_PRIMARY).fg(Color::Black)) + .highlight_symbol(""); + + let items_len = app.filtered_models.len(); + if app.mouse_moved { + if let (Some(col), Some(row)) = (app.mouse_col, app.mouse_row) { + if col >= layout[1].x && col < layout[1].x + layout[1].width && row >= layout[1].y && row < layout[1].y + layout[1].height { + let idx = (row - layout[1].y) as usize + app.menu_state.offset(); + if idx < items_len { + app.menu_state.select(Some(idx)); + } + } + } + app.mouse_moved = false; + } + + f.render_stateful_widget(list, layout[1], &mut app.menu_state); +} + +use crate::ui::SettingsMenuItem; + +pub fn render_settings_menu(f: &mut Frame, app: &mut App, _input_area: Rect) { + let height = (app.settings_items.len() + 6).min(15) as u16; + let body_area = draw_modal(f, "Settings", 60, height, app.mouse_col, app.mouse_row, vec![ + Span::styled("Enter", Style::default().add_modifier(Modifier::BOLD)), + Span::raw(" toggle setting") + ]); + + let items: Vec = app.settings_items.iter().map(|item| { + match item { + SettingsMenuItem::Header(title) => { + ListItem::new(Line::from(vec![ + Span::styled(format!("[{}]", title), Style::default().fg(COLOR_SECONDARY).add_modifier(Modifier::BOLD)) + ])) + } + SettingsMenuItem::Option { name, val, .. } => { + let total_width = body_area.width.saturating_sub(4); + let left = format!(" {}", name); + let right = val.to_string(); + let padding = total_width.saturating_sub(left.len() as u16).saturating_sub(right.len() as u16); + let spaces = " ".repeat(padding as usize); + ListItem::new(Line::from(vec![ + Span::raw(left), + Span::raw(spaces), + Span::styled(right, Style::default().fg(COLOR_PRIMARY).add_modifier(Modifier::BOLD)), + Span::raw(" ") + ])) + } + } + }).collect(); + + let list = List::new(items) + .highlight_style(Style::default().bg(COLOR_PRIMARY).fg(Color::Black)) + .highlight_symbol(""); + + let items_len = app.settings_items.len(); + if app.mouse_moved { + if let (Some(col), Some(row)) = (app.mouse_col, app.mouse_row) { + if col >= body_area.x && col < body_area.x + body_area.width && row >= body_area.y && row < body_area.y + body_area.height { + let idx = (row - body_area.y) as usize + app.menu_state.offset(); + if idx < items_len { + if !matches!(app.settings_items.get(idx), Some(SettingsMenuItem::Header(_))) { + app.menu_state.select(Some(idx)); + } + } + } + } + app.mouse_moved = false; + } + + f.render_stateful_widget(list, body_area, &mut app.menu_state); +} diff --git a/apps/cli/src/ui/mod.rs b/apps/cli/src/ui/mod.rs index 24ebf9b..f34f8dc 100644 --- a/apps/cli/src/ui/mod.rs +++ b/apps/cli/src/ui/mod.rs @@ -1,9 +1,9 @@ -use crossterm::event::{self, Event, KeyCode, KeyEventKind, MouseEventKind}; +use crossterm::event::{self, Event, KeyCode, KeyEventKind, MouseEventKind, MouseButton}; use ratatui::{ - layout::{Constraint, Direction, Layout, Rect}, - style::{Color, Modifier, Style}, - text::{Line, Span, Text}, - widgets::{Block, Borders, Clear, List, ListItem, ListState, Paragraph, Wrap}, + layout::{Constraint, Direction, Layout}, + style::{Modifier, Style}, + text::Span, + widgets::{Block, ListState, Paragraph}, Frame, Terminal, }; use routecode_sdk::agents::StreamChunk; @@ -11,25 +11,26 @@ use routecode_sdk::core::{AgentOrchestrator, Message, Role, DynamicModelInfo}; use routecode_sdk::utils::costs::Usage; use std::io; use std::sync::Arc; -use tokio::sync::Mutex; use tui_textarea::TextArea; -// --- Theme --- -const COLOR_PRIMARY: Color = Color::Rgb(0, 150, 255); // Ocean Blue -const COLOR_BG: Color = Color::Rgb(25, 25, 25); // Midnight Charcoal -const COLOR_INPUT_BG: Color = Color::Rgb(35, 35, 35);// Soft Obsidian -const COLOR_SECONDARY: Color = Color::DarkGray; // Slate Gray -const COLOR_SYSTEM: Color = Color::Yellow; // Amber Yellow -const COLOR_SUCCESS: Color = Color::Green; // Emerald Green -const COLOR_TEXT: Color = Color::White; // Primary Text -const COLOR_DIM: Color = Color::Rgb(50, 50, 50); // Very Dim Text/Lines +pub mod components; +pub mod welcome; +pub mod session; +pub mod menus; +pub mod logic; + +pub use components::*; +pub use logic::*; +pub use menus::*; +pub use session::*; +pub use welcome::*; pub struct ProviderInfo { pub id: &'static str, pub name: &'static str, } -const PROVIDERS: &[ProviderInfo] = &[ +pub const PROVIDERS: &[ProviderInfo] = &[ ProviderInfo { id: "openrouter", name: "OpenRouter" }, ProviderInfo { id: "nvidia", name: "NVIDIA" }, ProviderInfo { id: "opencode-zen", name: "OpenCode Zen" }, @@ -53,14 +54,16 @@ pub struct Command { pub description: &'static str, } -const COMMANDS: &[Command] = &[ +pub const COMMANDS: &[Command] = &[ Command { name: "/model", description: "Switch model" }, Command { name: "/resume", description: "Resume a session" }, Command { name: "/sessions", description: "List saved sessions" }, Command { name: "/clear", description: "Clear history" }, + Command { name: "/thinking", description: "Set thinking level (low/max)" }, Command { name: "/help", description: "Show help" }, Command { name: "/stop", description: "Stop AI generation" }, Command { name: "/provider", description: "Manage providers" }, + Command { name: "/settings", description: "Manage settings" }, Command { name: "/exit", description: "Exit application" }, ]; @@ -79,6 +82,12 @@ pub enum ApiKeyInputStage { CloudflareApiKey, } +#[derive(Clone, Debug, PartialEq)] +pub enum SettingsMenuItem { + Header(String), + Option { name: String, val: String, key: String }, +} + pub struct App { pub screen: Screen, pub input: TextArea<'static>, @@ -90,11 +99,14 @@ pub struct App { pub show_menu: bool, pub show_provider_menu: bool, pub show_model_menu: bool, + pub show_settings_menu: bool, pub menu_state: ListState, pub filtered_commands: Vec<&'static Command>, pub filtered_models: Vec, pub all_available_models: Vec, pub history_scroll: u16, + pub max_scroll: u16, + pub auto_scroll: bool, pub is_generating: bool, pub tick_count: u64, pub active_tool: Option, @@ -108,8 +120,34 @@ pub struct App { pub api_key_input_stage: ApiKeyInputStage, pub pending_account_id: Option, pub pending_gateway_id: Option, + pub pending_clear: bool, + pub pending_exit: bool, + pub is_fetching_models: bool, + pub collapse_thinking: bool, + pub mouse_row: Option, + pub mouse_col: Option, + pub mouse_moved: bool, + pub mouse_events_count: u64, + pub logo_anim_frames: u16, pub rx: tokio::sync::mpsc::UnboundedReceiver, pub tx: tokio::sync::mpsc::UnboundedSender, + pub settings_items: Vec, + pub last_click_up: Option<(std::time::Instant, u16, u16)>, + pub mouse_down_start: Option<(std::time::Instant, u16, u16)>, + pub temp_expand_thinking: bool, + pub last_toggle_time: Option, + pub thinking_hover_rendered: bool, + pub usage: Usage, + pub cached_history_len: usize, + pub cached_last_msg_len: usize, + pub cached_width: u16, + pub cached_is_collapsed: bool, + pub cached_thinking_hovered: bool, + pub cached_total_height: usize, + pub cached_text: Option>, + pub pending_command_confirmation: Option<(String, String, std::sync::Arc>>>)>, + pub inputting_command_feedback: bool, + pub session_id: String, } impl App { @@ -141,11 +179,15 @@ impl App { show_menu: false, show_provider_menu: false, show_model_menu: false, + show_settings_menu: false, menu_state: ListState::default(), filtered_commands: Vec::new(), filtered_models: Vec::new(), all_available_models: Vec::new(), + settings_items: Vec::new(), history_scroll: 0, + max_scroll: 0, + auto_scroll: true, is_generating: false, tick_count: 0, active_tool: None, @@ -159,11 +201,53 @@ impl App { api_key_input_stage: ApiKeyInputStage::None, pending_account_id: None, pending_gateway_id: None, + pending_clear: false, + pending_exit: false, + is_fetching_models: false, + collapse_thinking: false, + mouse_row: None, + mouse_col: None, + mouse_moved: false, + mouse_events_count: 0, + logo_anim_frames: 0, rx, tx, + usage: Usage::default(), + last_click_up: None, + mouse_down_start: None, + temp_expand_thinking: false, + last_toggle_time: None, + thinking_hover_rendered: false, + cached_history_len: 0, + cached_last_msg_len: 0, + cached_width: 0, + cached_is_collapsed: false, + cached_thinking_hovered: false, + cached_total_height: 0, + cached_text: None, + pending_command_confirmation: None, + inputting_command_feedback: false, + session_id: format!("session_{}", chrono::Utc::now().format("%Y%m%d_%H%M%S")), } } + pub async fn populate_settings(&mut self) { + let config = self.orchestrator.config.lock().await; + self.settings_items = vec![ + SettingsMenuItem::Header("Appearance".to_string()), + SettingsMenuItem::Option { + name: "Logo Animation".to_string(), + val: config.logo_animation.clone(), + key: "logo_animation".to_string(), + }, + SettingsMenuItem::Option { + name: "Animation Theme".to_string(), + val: config.logo_animation_color.clone(), + key: "logo_animation_color".to_string(), + }, + ]; + } + pub fn update_filtered_commands(&mut self) { let input_line = self.input.lines()[0].to_lowercase(); if input_line.starts_with('/') { @@ -181,982 +265,1002 @@ impl App { } } -pub async fn run_app( - terminal: &mut Terminal, - mut app: App, -) -> io::Result<()> { - let mut last_tick = std::time::Instant::now(); - let tick_rate = std::time::Duration::from_millis(100); +/// Compute whether the mouse is hovering over a thinking block, accounting for text wrapping. +/// Uses the same wrapping calculation as the auto-scroll logic in ui_session. +pub fn compute_thinking_hover(app: &App, size: ratatui::layout::Rect) -> bool { + let mouse_row = match app.mouse_row { + Some(r) => r, + None => return false, + }; + if app.screen != Screen::Session { + return false; + } + let has_thinking = app.history.iter().any(|m| m.thought.is_some()); + if !has_thinking { + return false; + } - loop { - let usage = app.orchestrator.usage.lock().await.clone(); - terminal.draw(|f| ui(f, &mut app, &usage))?; + // Compute layout: header=1 row, then history area, then input, then status bar + let input_height = (app.input.lines().len() as u16 + 2).min(12); + // area starts at row 1 (after header). History is area minus input and status. + let area_height = size.height.saturating_sub(1); // main area below header + let history_height = area_height.saturating_sub(input_height).saturating_sub(1); - let timeout = tick_rate - .checked_sub(last_tick.elapsed()) - .unwrap_or_else(|| std::time::Duration::from_secs(0)); + // Check mouse is in history area (row 1 to 1+history_height exclusive) + if mouse_row < 1 || mouse_row >= 1 + history_height { + return false; + } - if event::poll(timeout)? { - match event::read()? { - Event::Key(key) => { - if key.kind == KeyEventKind::Press { - match key.code { - KeyCode::Char('p') if key.modifiers.contains(event::KeyModifiers::CONTROL) => { - app.show_menu = true; - app.menu_state.select(Some(0)); - app.update_filtered_commands(); - } - KeyCode::Char('a') if key.modifiers.contains(event::KeyModifiers::CONTROL) => { - if app.show_model_menu { app.show_model_menu = false; } - app.show_provider_menu = true; - app.menu_state.select(Some(0)); + // The visual row within the history viewport (0-indexed from top of visible area) + let viewport_row = mouse_row - 1; + // The absolute visual row including scroll + let target_visual_row = viewport_row as usize + app.history_scroll as usize; + + // Build the history text and compute wrapping to find which logical line the target row maps to + let is_collapsed = app.collapse_thinking && !app.temp_expand_thinking; + let history_text = render_history(&app.history, is_collapsed, app.thinking_hover_rendered); + let available_width = size.width.max(1) as usize; + let calc_width = (available_width as f32 * 0.95).floor().max(1.0) as usize; + + let mut cumulative_visual_row: usize = 0; + for line in &history_text.lines { + let line_width: usize = line.spans.iter().map(|s| unicode_width::UnicodeWidthStr::width(s.content.as_ref())).sum(); + let wrapped_height = if line_width == 0 { 1 } else { + (line_width + calc_width - 1) / calc_width + }; + + // Check if target_visual_row falls within this logical line's visual rows + if target_visual_row >= cumulative_visual_row && target_visual_row < cumulative_visual_row + wrapped_height { + // Found the line - check if it's a thinking line + return line.spans.iter().any(|span| { + span.content.contains('\u{2502}') || span.content.contains('\u{2503}') || span.content.contains("Thinking...") + }); + } + cumulative_visual_row += wrapped_height; + } + false +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum KeyEventResult { + Continue, + Exit, +} + +async fn handle_key_event( + app: &mut App, + key: event::KeyEvent, + is_burst: bool, +) -> io::Result { + if app.pending_command_confirmation.is_some() { + if app.inputting_command_feedback { + match key.code { + KeyCode::Esc => { + app.inputting_command_feedback = false; + app.input.delete_line_by_head(); + while app.input.cursor() != (0, 0) { + app.input.move_cursor(tui_textarea::CursorMove::Head); + app.input.delete_line_by_head(); + } + app.input.set_placeholder_text(" Ask anything... \"How do I use this?\""); + } + KeyCode::Enter => { + if let Some((_, _, tx_mutex)) = app.pending_command_confirmation.take() { + let lines = app.input.lines().to_vec(); + app.input.delete_line_by_head(); + while app.input.cursor() != (0, 0) { + app.input.move_cursor(tui_textarea::CursorMove::Head); + app.input.delete_line_by_head(); + } + app.input.set_placeholder_text(" Ask anything... \"How do I use this?\""); + + let msg = lines.join("\n").trim().to_string(); + let feedback = if msg.is_empty() { "Command cancelled.".to_string() } else { msg }; + + tokio::spawn(async move { + let mut tx_opt = tx_mutex.lock().await; + if let Some(tx) = tx_opt.take() { + let _ = tx.send(routecode_sdk::agents::types::ConfirmationResponse::Feedback(feedback)); } - KeyCode::Char('c') if key.modifiers.contains(event::KeyModifiers::CONTROL) => { - if app.is_generating { - if let Some(handle) = app.current_task.take() { handle.abort(); } - app.is_generating = false; - app.active_tool = None; - } + }); + } + app.inputting_command_feedback = false; + } + _ => { + app.input.input(key); + } + } + } else { + match key.code { + KeyCode::Char('y') | KeyCode::Char('Y') => { + if let Some((_, _, tx_mutex)) = app.pending_command_confirmation.take() { + tokio::spawn(async move { + let mut tx_opt = tx_mutex.lock().await; + if let Some(tx) = tx_opt.take() { + let _ = tx.send(routecode_sdk::agents::types::ConfirmationResponse::AllowOnce); } - KeyCode::Char('l') if key.modifiers.contains(event::KeyModifiers::CONTROL) => { - app.history.clear(); - app.screen = Screen::Welcome; - app.history_scroll = 0; + }); + } + } + KeyCode::Char('s') | KeyCode::Char('S') => { + let mut config = routecode_sdk::utils::storage::load_session_config(&app.session_id).unwrap_or_default(); + config.allow_all_commands = true; + let _ = routecode_sdk::utils::storage::save_session_config(&app.session_id, &config); + + if let Some((_, _, tx_mutex)) = app.pending_command_confirmation.take() { + tokio::spawn(async move { + let mut tx_opt = tx_mutex.lock().await; + if let Some(tx) = tx_opt.take() { + let _ = tx.send(routecode_sdk::agents::types::ConfirmationResponse::AllowSession); } - KeyCode::Enter if key.modifiers.contains(event::KeyModifiers::SHIFT) => { - app.input.insert_newline(); + }); + } + } + KeyCode::Char('w') | KeyCode::Char('W') => { + let mut config = routecode_sdk::utils::storage::load_workspace_config().unwrap_or_default(); + config.allow_all_outside_access = true; + let _ = routecode_sdk::utils::storage::save_workspace_config(&config); + + if let Some((_, _, tx_mutex)) = app.pending_command_confirmation.take() { + tokio::spawn(async move { + let mut tx_opt = tx_mutex.lock().await; + if let Some(tx) = tx_opt.take() { + let _ = tx.send(routecode_sdk::agents::types::ConfirmationResponse::AllowWorkspace); } - KeyCode::Enter => { - if app.show_menu { - if let Some(selected) = app.menu_state.selected() { - if let Some(cmd) = app.filtered_commands.get(selected) { - let name = cmd.name.to_string(); - app.show_menu = false; - app.input = TextArea::default(); - handle_command(&mut app, &name).await; - } - } - } else if app.show_provider_menu { - if let Some(selected) = app.menu_state.selected() { - if let Some(p) = PROVIDERS.get(selected) { - app.pending_provider_id = Some(p.id.to_string()); - app.is_inputting_api_key = true; - app.api_key_input = TextArea::default(); - app.show_provider_menu = false; - if p.id == "cloudflare-workers" || p.id == "cloudflare-gateway" { - app.api_key_input_stage = ApiKeyInputStage::CloudflareAccountId; - } else { - app.api_key_input_stage = ApiKeyInputStage::ApiKey; - } - } - } - } else if app.show_model_menu { - if let Some(selected) = app.menu_state.selected() { - if let Some(ModelMenuItem::Model(model_info)) = app.filtered_models.get(selected).cloned() { - let provider_id = &model_info.provider_id; - let model_name = &model_info.name; - let mut config = app.orchestrator.config.lock().await; - let env_key = format!("{}_API_KEY", provider_id.to_uppercase().replace("-", "_")); - let api_key = std::env::var(env_key).ok().or_else(|| config.api_keys.get(provider_id).cloned()); - if let Some(key) = api_key { - config.model = model_name.clone(); - config.provider = provider_id.clone(); - config.recent_models.retain(|m| m.name != *model_name || m.provider_id != *provider_id); - config.recent_models.insert(0, model_info.clone()); - config.recent_models.truncate(3); - let _ = routecode_sdk::utils::storage::save_config(&config); - if app.provider_name.to_lowercase() != *provider_id { - let provider = routecode_sdk::agents::resolve_provider(provider_id, key); - app.provider_name = provider.name().to_string(); - app.current_provider_id = provider_id.clone(); - drop(config); - app.orchestrator.change_provider(provider).await; - } else { drop(config); } - app.current_model = model_name.clone(); - app.history.push(Message::system(format!("Switched to {} on {}", model_name, app.provider_name))); - app.show_model_menu = false; - } else { - app.history.push(Message::system(format!("Error: No API key for {}", provider_id))); - } - } - } - } else if app.is_inputting_api_key { - let input_value = app.api_key_input.lines().join("\n").trim().to_string(); - if !input_value.is_empty() { - match app.api_key_input_stage { - ApiKeyInputStage::ApiKey => { - if let Some(provider_id) = app.pending_provider_id.take() { - let mut config = app.orchestrator.config.lock().await; - config.api_keys.insert(provider_id.clone(), input_value); - let _ = routecode_sdk::utils::storage::save_config(&config); - app.history.push(Message::system(format!("API Key saved for {}", provider_id))); - } - app.is_inputting_api_key = false; - app.api_key_input_stage = ApiKeyInputStage::None; - } - ApiKeyInputStage::CloudflareAccountId => { - app.pending_account_id = Some(input_value); - app.api_key_input = TextArea::default(); - if app.pending_provider_id.as_deref() == Some("cloudflare-gateway") { - app.api_key_input_stage = ApiKeyInputStage::CloudflareGatewayId; - } else { app.api_key_input_stage = ApiKeyInputStage::CloudflareApiKey; } - } - ApiKeyInputStage::CloudflareGatewayId => { - app.pending_gateway_id = Some(input_value); - app.api_key_input = TextArea::default(); - app.api_key_input_stage = ApiKeyInputStage::CloudflareApiKey; - } - ApiKeyInputStage::CloudflareApiKey => { - if let Some(provider_id) = app.pending_provider_id.take() { - let account_id = app.pending_account_id.take().unwrap_or_default(); - let final_key = if provider_id == "cloudflare-gateway" { - let gateway_id = app.pending_gateway_id.take().unwrap_or_default(); - format!("{}:{}:{}", account_id, gateway_id, input_value) - } else { format!("{}:{}", account_id, input_value) }; - let mut config = app.orchestrator.config.lock().await; - config.api_keys.insert(provider_id.clone(), final_key); - let _ = routecode_sdk::utils::storage::save_config(&config); - app.history.push(Message::system(format!("Credentials saved for {}", provider_id))); - } - app.is_inputting_api_key = false; - app.api_key_input_stage = ApiKeyInputStage::None; - } - _ => { app.is_inputting_api_key = false; } - } - } else { - app.is_inputting_api_key = false; - app.api_key_input_stage = ApiKeyInputStage::None; - } - } else { - let input_text = app.input.lines().join("\n"); - if !input_text.trim().is_empty() { - if input_text.starts_with('/') { - handle_command(&mut app, &input_text).await; - } else { - app.history.push(Message::user(input_text.clone())); - app.prompt_history.push(input_text.clone()); - app.prompt_history_index = None; - app.input = TextArea::default(); - app.screen = Screen::Session; - app.is_generating = true; - let orchestrator = app.orchestrator.clone(); - let mut history = app.history.clone(); - let model = app.current_model.clone(); - let tx = app.tx.clone(); - let task = tokio::spawn(async move { - let _ = orchestrator.run(&mut history, &model, Some(tx)).await; - }); - app.current_task = Some(task); - } - app.input = TextArea::default(); - } - } + }); + } + } + KeyCode::Char('d') | KeyCode::Char('D') | KeyCode::Esc => { + if let Some((_, _, tx_mutex)) = app.pending_command_confirmation.take() { + tokio::spawn(async move { + let mut tx_opt = tx_mutex.lock().await; + if let Some(tx) = tx_opt.take() { + let _ = tx.send(routecode_sdk::agents::types::ConfirmationResponse::Deny); } - KeyCode::Esc => { - if app.show_menu { app.show_menu = false; } - else if app.show_provider_menu { app.show_provider_menu = false; } - else if app.show_model_menu { app.show_model_menu = false; } - else if app.is_inputting_api_key { - app.is_inputting_api_key = false; - app.api_key_input_stage = ApiKeyInputStage::None; - app.pending_account_id = None; - app.pending_gateway_id = None; - } else if app.is_generating { - if let Some(handle) = app.current_task.take() { handle.abort(); } - app.is_generating = false; - app.active_tool = None; - } else { - if !app.history.is_empty() { - let session = routecode_sdk::utils::storage::Session { - messages: app.history.clone(), - model: app.current_model.clone(), - usage: app.orchestrator.usage.lock().await.clone(), - timestamp: chrono::Utc::now().timestamp(), - }; - let _ = routecode_sdk::utils::storage::save_session("last_session", &session); - } - return Ok(()); - } + }); + } + } + KeyCode::Char('f') | KeyCode::Char('F') => { + app.inputting_command_feedback = true; + app.input.set_placeholder_text(" Tell agent (e.g. 'don't run without backup')..."); + } + _ => {} + } + } + return Ok(KeyEventResult::Continue); + } + + if app.pending_clear { + match key.code { + KeyCode::Char('y') | KeyCode::Char('Y') | KeyCode::Enter => { + app.history.clear(); + app.screen = Screen::Welcome; + app.history_scroll = 0; + app.pending_clear = false; + } + KeyCode::Char('n') | KeyCode::Char('N') | KeyCode::Esc => { + app.pending_clear = false; + } + _ => {} + } + return Ok(KeyEventResult::Continue); + } + if app.pending_exit { + match key.code { + KeyCode::Char('y') | KeyCode::Char('Y') | KeyCode::Enter => { + if !app.history.is_empty() { + let session = routecode_sdk::utils::storage::Session { + messages: app.history.clone(), + model: app.current_model.clone(), + usage: app.orchestrator.usage.lock().await.clone(), + timestamp: chrono::Utc::now().timestamp(), + }; + let _ = routecode_sdk::utils::storage::save_session(&app.session_id, &session); + } + return Ok(KeyEventResult::Exit); + } + KeyCode::Char('n') | KeyCode::Char('N') | KeyCode::Esc => { + app.pending_exit = false; + } + _ => {} + } + return Ok(KeyEventResult::Continue); + } + match key.code { + KeyCode::Char('p') if key.modifiers.contains(event::KeyModifiers::CONTROL) => { + app.show_menu = true; + app.menu_state.select(Some(0)); + app.update_filtered_commands(); + } + KeyCode::Char('a') if key.modifiers.contains(event::KeyModifiers::CONTROL) => { + if app.show_model_menu { app.show_model_menu = false; } + app.show_provider_menu = true; + app.menu_state.select(Some(0)); + } + KeyCode::Char('c') if key.modifiers.contains(event::KeyModifiers::CONTROL) => { + if app.is_generating { + if let Some(handle) = app.current_task.take() { handle.abort(); } + app.is_generating = false; + app.active_tool = None; + } + } + KeyCode::Char('l') if key.modifiers.contains(event::KeyModifiers::CONTROL) => { + app.history.clear(); + app.screen = Screen::Welcome; + app.history_scroll = 0; + } + KeyCode::Enter if key.modifiers.contains(event::KeyModifiers::SHIFT) || key.modifiers.contains(event::KeyModifiers::ALT) => { + app.input.insert_newline(); + } + KeyCode::Enter => { + let mut should_send = !is_burst; + if should_send { + let lines = app.input.lines(); + if let Some(last_line) = lines.last() { + if last_line.ends_with('\\') { + app.input.delete_char(); + app.input.insert_newline(); + should_send = false; + } + } + } + + if !should_send { + app.input.insert_newline(); + } else { + if app.show_menu { + if let Some(selected) = app.menu_state.selected() { + if let Some(cmd) = app.filtered_commands.get(selected) { + let name = cmd.name.to_string(); + app.show_menu = false; + app.input = TextArea::default(); + handle_command(app, &name).await; + } + } + } else if app.show_provider_menu { + if let Some(selected) = app.menu_state.selected() { + if let Some(p) = PROVIDERS.get(selected) { + app.pending_provider_id = Some(p.id.to_string()); + app.is_inputting_api_key = true; + app.api_key_input = TextArea::default(); + app.show_provider_menu = false; + if p.id == "cloudflare-workers" || p.id == "cloudflare-gateway" { + app.api_key_input_stage = ApiKeyInputStage::CloudflareAccountId; + } else { + app.api_key_input_stage = ApiKeyInputStage::ApiKey; } - KeyCode::Up => { - if app.show_menu || app.show_provider_menu || app.show_model_menu { - let items_len = if app.show_menu { app.filtered_commands.len() } - else if app.show_provider_menu { PROVIDERS.len() } - else { app.filtered_models.len() }; - if items_len > 0 { - let selected = app.menu_state.selected().unwrap_or(0); - let mut new_selected = if selected == 0 { items_len - 1 } else { selected - 1 }; - if app.show_model_menu { - while let Some(ModelMenuItem::Header(_)) = app.filtered_models.get(new_selected) { - new_selected = if new_selected == 0 { items_len - 1 } else { new_selected - 1 }; - if new_selected == selected { break; } - } - } - app.menu_state.select(Some(new_selected)); - } - } else if key.modifiers.contains(event::KeyModifiers::SHIFT) { - app.history_scroll = app.history_scroll.saturating_sub(1); - } else { - let (row, _) = app.input.cursor(); - if row == 0 && !app.prompt_history.is_empty() { - let idx = match app.prompt_history_index { - Some(i) => if i == 0 { 0 } else { i - 1 }, - None => app.prompt_history.len() - 1, - }; - app.prompt_history_index = Some(idx); - let prev = app.prompt_history[idx].clone(); - app.input = TextArea::from(prev.lines().map(|s| s.to_string())); - app.input.move_cursor(tui_textarea::CursorMove::End); - } else { app.input.input(key); } - } + } + } + } else if app.show_settings_menu { + if let Some(selected) = app.menu_state.selected() { + if let Some(SettingsMenuItem::Option { key, val, .. }) = app.settings_items.get(selected) { + if key == "logo_animation" { + let next_val = match val.as_str() { + "always" => "hover", + "hover" => "click", + _ => "always", + }; + let mut config = app.orchestrator.config.lock().await; + config.logo_animation = next_val.to_string(); + let _ = routecode_sdk::utils::storage::save_config(&config); + drop(config); + app.populate_settings().await; + } else if key == "logo_animation_color" { + let next_val = match val.as_str() { + "rainbow" => "neon", + "neon" => "cyberpunk", + "cyberpunk" => "sunset", + "sunset" => "mono", + _ => "rainbow", + }; + let mut config = app.orchestrator.config.lock().await; + config.logo_animation_color = next_val.to_string(); + let _ = routecode_sdk::utils::storage::save_config(&config); + drop(config); + app.populate_settings().await; } - KeyCode::Down => { - if app.show_menu || app.show_provider_menu || app.show_model_menu { - let items_len = if app.show_menu { app.filtered_commands.len() } - else if app.show_provider_menu { PROVIDERS.len() } - else { app.filtered_models.len() }; - if items_len > 0 { - let selected = app.menu_state.selected().unwrap_or(0); - let mut new_selected = if selected >= items_len - 1 { 0 } else { selected + 1 }; - if app.show_model_menu { - while let Some(ModelMenuItem::Header(_)) = app.filtered_models.get(new_selected) { - new_selected = if new_selected >= items_len - 1 { 0 } else { new_selected + 1 }; - if new_selected == selected { break; } - } - } - app.menu_state.select(Some(new_selected)); - } - } else if key.modifiers.contains(event::KeyModifiers::SHIFT) { - app.history_scroll = app.history_scroll.saturating_add(1); + } + } + } else if app.show_model_menu { + if let Some(selected) = app.menu_state.selected() { + match app.filtered_models.get(selected) { + Some(ModelMenuItem::Model(model_info)) => { + let model_info = model_info.clone(); + let provider_id = &model_info.provider_id; + let model_name = &model_info.name; + let mut config = app.orchestrator.config.lock().await; + let env_key = format!("{}_API_KEY", provider_id.to_uppercase().replace("-", "_")); + let api_key = std::env::var(env_key).ok().or_else(|| config.api_keys.get(provider_id).cloned()); + if let Some(key) = api_key { + config.model = model_name.clone(); + config.provider = provider_id.clone(); + config.recent_models.retain(|m| m.name != *model_name || m.provider_id != *provider_id); + config.recent_models.insert(0, model_info.clone()); + config.recent_models.truncate(3); + let _ = routecode_sdk::utils::storage::save_config(&config); + if app.provider_name.to_lowercase() != *provider_id { + let provider = routecode_sdk::agents::resolve_provider(provider_id, key); + app.provider_name = provider.name().to_string(); + app.current_provider_id = provider_id.clone(); + drop(config); + app.orchestrator.change_provider(provider).await; + } else { drop(config); } + app.current_model = model_name.clone(); + app.history.push(Message::system(format!("Switched to {} on {}", model_name, app.provider_name))); + app.show_model_menu = false; } else { - let (row, _) = app.input.cursor(); - let lines_len = app.input.lines().len(); - if row >= lines_len - 1 && app.prompt_history_index.is_some() { - let idx = app.prompt_history_index.unwrap(); - if idx >= app.prompt_history.len() - 1 { - app.prompt_history_index = None; - app.input = TextArea::default(); - } else { - let new_idx = idx + 1; - app.prompt_history_index = Some(new_idx); - let next = app.prompt_history[new_idx].clone(); - app.input = TextArea::from(next.lines().map(|s| s.to_string())); - app.input.move_cursor(tui_textarea::CursorMove::End); - } - } else { app.input.input(key); } + app.history.push(Message::system(format!("Error: No API key for {}", provider_id))); } } - KeyCode::Right if app.show_model_menu => { - let len = app.filtered_models.len(); - if len > 0 { - let current = app.menu_state.selected().unwrap_or(0); - let mut next_header_idx = None; - for i in (current + 1)..len { - if let Some(ModelMenuItem::Header(_)) = app.filtered_models.get(i) { - next_header_idx = Some(i); break; - } - } - if next_header_idx.is_none() { - for i in 0..current { - if let Some(ModelMenuItem::Header(_)) = app.filtered_models.get(i) { - next_header_idx = Some(i); break; - } - } - } - if let Some(h_idx) = next_header_idx { - let mut target = (h_idx + 1) % len; - while let Some(ModelMenuItem::Header(_)) = app.filtered_models.get(target) { - target = (target + 1) % len; - if target == h_idx { break; } - } - app.menu_state.select(Some(target)); - } + _ => {} + } + } + } else if app.is_inputting_api_key { + let input_value = app.api_key_input.lines().join("\n").trim().to_string(); + if !input_value.is_empty() { + match app.api_key_input_stage { + ApiKeyInputStage::ApiKey => { + if let Some(provider_id) = app.pending_provider_id.take() { + let mut config = app.orchestrator.config.lock().await; + config.api_keys.insert(provider_id.clone(), input_value); + let _ = routecode_sdk::utils::storage::save_config(&config); + app.history.push(Message::system(format!("API Key saved for {}", provider_id))); } + app.is_inputting_api_key = false; + app.api_key_input_stage = ApiKeyInputStage::None; } - KeyCode::Left if app.show_model_menu => { - let len = app.filtered_models.len(); - if len > 0 { - let current = app.menu_state.selected().unwrap_or(0); - let mut headers = Vec::new(); - for (i, item) in app.filtered_models.iter().enumerate() { - if let ModelMenuItem::Header(_) = item { headers.push(i); } - } - if !headers.is_empty() { - let current_header_idx_in_headers = headers.iter().enumerate().rev().find(|(_, &h_idx)| h_idx < current).map(|(i, _)| i); - let target_header_idx = match current_header_idx_in_headers { - Some(i) => if i == 0 { *headers.last().unwrap() } else { headers[i - 1] }, - None => *headers.last().unwrap() - }; - let mut target = (target_header_idx + 1) % len; - while let Some(ModelMenuItem::Header(_)) = app.filtered_models.get(target) { - target = (target + 1) % len; - if target == target_header_idx { break; } - } - app.menu_state.select(Some(target)); - } - } + ApiKeyInputStage::CloudflareAccountId => { + app.pending_account_id = Some(input_value); + app.api_key_input = TextArea::default(); + if app.pending_provider_id.as_deref() == Some("cloudflare-gateway") { + app.api_key_input_stage = ApiKeyInputStage::CloudflareGatewayId; + } else { app.api_key_input_stage = ApiKeyInputStage::CloudflareApiKey; } } - KeyCode::Char('f') if key.modifiers.contains(event::KeyModifiers::CONTROL) && app.show_model_menu => { - if let Some(selected) = app.menu_state.selected() { - if let Some(ModelMenuItem::Model(model_info)) = app.filtered_models.get(selected).cloned() { - let mut config = app.orchestrator.config.lock().await; - if config.favorites.iter().any(|m| m.name == model_info.name && m.provider_id == model_info.provider_id) { - config.favorites.retain(|m| m.name != model_info.name || m.provider_id != model_info.provider_id); - app.history.push(Message::system(format!("Removed {} from favorites", model_info.name))); - } else { - config.favorites.push(model_info.clone()); - app.history.push(Message::system(format!("Added {} to favorites", model_info.name))); - } - let _ = routecode_sdk::utils::storage::save_config(&config); - } - } + ApiKeyInputStage::CloudflareGatewayId => { + app.pending_gateway_id = Some(input_value); + app.api_key_input = TextArea::default(); + app.api_key_input_stage = ApiKeyInputStage::CloudflareApiKey; } - _ => { - let event = event::Event::Key(key); - if app.is_inputting_api_key { - app.api_key_input.input(event); - } else if app.show_model_menu { - if app.model_search_input.input(event) { - let search = app.model_search_input.lines()[0].to_lowercase().trim().to_string(); - handle_model_search(&mut app, &search, true).await; - } - } else { - app.input.input(event); - app.update_filtered_commands(); + ApiKeyInputStage::CloudflareApiKey => { + if let Some(provider_id) = app.pending_provider_id.take() { + let account_id = app.pending_account_id.take().unwrap_or_default(); + let final_key = if provider_id == "cloudflare-gateway" { + let gateway_id = app.pending_gateway_id.take().unwrap_or_default(); + format!("{}:{}:{}", account_id, gateway_id, input_value) + } else { format!("{}:{}", account_id, input_value) }; + let mut config = app.orchestrator.config.lock().await; + config.api_keys.insert(provider_id.clone(), final_key); + let _ = routecode_sdk::utils::storage::save_config(&config); + app.history.push(Message::system(format!("Credentials saved for {}", provider_id))); } + app.is_inputting_api_key = false; + app.api_key_input_stage = ApiKeyInputStage::None; } + _ => { app.is_inputting_api_key = false; } } + } else { + app.is_inputting_api_key = false; + app.api_key_input_stage = ApiKeyInputStage::None; } - } - Event::Mouse(mouse) => { - match mouse.kind { - MouseEventKind::ScrollUp => { app.history_scroll = app.history_scroll.saturating_sub(2); } - MouseEventKind::ScrollDown => { app.history_scroll = app.history_scroll.saturating_add(2); } - _ => {} + } else { + let input_text = app.input.lines().join("\n"); + if !input_text.trim().is_empty() { + if input_text.starts_with('/') { + handle_command(app, &input_text).await; + } else { + app.history.push(Message::user(input_text.clone())); + app.prompt_history.push(input_text.clone()); + app.prompt_history_index = None; + app.input = TextArea::default(); + app.screen = Screen::Session; + app.is_generating = true; + app.auto_scroll = true; + let orchestrator = app.orchestrator.clone(); + let mut history = app.history.clone(); + let model = app.current_model.clone(); + let tx = app.tx.clone(); + let task = tokio::spawn(async move { + let _ = orchestrator.run(&mut history, &model, Some(tx)).await; + }); + app.current_task = Some(task); + } + app.input = TextArea::default(); } } - _ => {} } } - - if last_tick.elapsed() >= tick_rate { - app.tick_count += 1; - last_tick = std::time::Instant::now(); + KeyCode::Esc => { + if app.show_menu { app.show_menu = false; } + else if app.show_provider_menu { app.show_provider_menu = false; } + else if app.show_model_menu { app.show_model_menu = false; } + else if app.show_settings_menu { app.show_settings_menu = false; } + else if app.is_inputting_api_key { + app.is_inputting_api_key = false; + app.api_key_input_stage = ApiKeyInputStage::None; + app.pending_account_id = None; + app.pending_gateway_id = None; + } else if app.is_generating { + if let Some(handle) = app.current_task.take() { handle.abort(); } + app.is_generating = false; + app.active_tool = None; + } else { + app.pending_exit = true; + } } - - while let Ok(chunk) = app.rx.try_recv() { - match chunk { - StreamChunk::Text { content } => { - if let Some(last) = app.history.last_mut() { - if last.role == Role::Assistant { - let mut current = last.content.clone().unwrap_or_default(); - current.push_str(&content); - last.content = Some(current); - } else { - app.history.push(Message::assistant(Some(content), None, None)); + KeyCode::Char('t') if key.modifiers.contains(event::KeyModifiers::CONTROL) => { + app.auto_scroll = !app.auto_scroll; + app.history.push(Message::system(format!("Auto-scroll {}", if app.auto_scroll { "enabled" } else { "disabled" }))); + } + KeyCode::Char('o') if key.modifiers.contains(event::KeyModifiers::CONTROL) => { + app.collapse_thinking = !app.collapse_thinking; + } + KeyCode::End => { + app.auto_scroll = true; + app.history_scroll = app.max_scroll; + } + KeyCode::Up if key.modifiers.contains(event::KeyModifiers::CONTROL) => { + let (row, _) = app.input.cursor(); + if row == 0 && app.input.lines().len() == 1 && app.input.lines()[0].is_empty() && !app.prompt_history.is_empty() { + let idx = match app.prompt_history_index { + Some(i) => if i == 0 { 0 } else { i - 1 }, + None => app.prompt_history.len() - 1, + }; + app.prompt_history_index = Some(idx); + let prev = app.prompt_history[idx].clone(); + app.input = TextArea::from(prev.lines().map(|s| s.to_string())); + app.input.move_cursor(tui_textarea::CursorMove::End); + } + } + KeyCode::Down if key.modifiers.contains(event::KeyModifiers::CONTROL) => { + let (row, _) = app.input.cursor(); + let lines_len = app.input.lines().len(); + if row >= lines_len - 1 && app.prompt_history_index.is_some() { + let idx = app.prompt_history_index.unwrap(); + if idx >= app.prompt_history.len() - 1 { + app.prompt_history_index = None; + app.input = TextArea::default(); + } else { + let new_idx = idx + 1; + app.prompt_history_index = Some(new_idx); + let next = app.prompt_history[new_idx].clone(); + app.input = TextArea::from(next.lines().map(|s| s.to_string())); + app.input.move_cursor(tui_textarea::CursorMove::End); + } + } + } + KeyCode::Up => { + if app.show_menu || app.show_provider_menu || app.show_model_menu || app.show_settings_menu { + let items_len = if app.show_menu { app.filtered_commands.len() } + else if app.show_provider_menu { PROVIDERS.len() } + else if app.show_settings_menu { app.settings_items.len() } + else { app.filtered_models.len() }; + if items_len > 0 { + let selected = app.menu_state.selected().unwrap_or(0); + let mut new_selected = if selected == 0 { items_len - 1 } else { selected - 1 }; + if app.show_model_menu { + while let Some(ModelMenuItem::Header(_)) = app.filtered_models.get(new_selected) { + new_selected = if new_selected == 0 { items_len - 1 } else { new_selected - 1 }; + if new_selected == selected { break; } + } + } else if app.show_settings_menu { + while let Some(SettingsMenuItem::Header(_)) = app.settings_items.get(new_selected) { + new_selected = if new_selected == 0 { items_len - 1 } else { new_selected - 1 }; + if new_selected == selected { break; } } - } else { - app.history.push(Message::assistant(Some(content), None, None)); } + app.menu_state.select(Some(new_selected)); } - StreamChunk::Thought { content } => { - if let Some(last) = app.history.last_mut() { - if last.role == Role::Assistant { - let mut current = last.thought.clone().unwrap_or_default(); - current.push_str(&content); - last.thought = Some(current); - } else { - app.history.push(Message::assistant(None, Some(content), None)); + } else { + if app.input.lines().len() == 1 && app.input.lines()[0].is_empty() || app.history_scroll > 0 || app.is_generating || key.modifiers.contains(event::KeyModifiers::SHIFT) { + app.history_scroll = app.history_scroll.saturating_sub(15); + app.auto_scroll = false; + } else { + app.input.input(Event::Key(key)); + } + } + } + KeyCode::Down => { + if app.show_menu || app.show_provider_menu || app.show_model_menu || app.show_settings_menu { + let items_len = if app.show_menu { app.filtered_commands.len() } + else if app.show_provider_menu { PROVIDERS.len() } + else if app.show_settings_menu { app.settings_items.len() } + else { app.filtered_models.len() }; + if items_len > 0 { + let selected = app.menu_state.selected().unwrap_or(0); + let mut new_selected = if selected >= items_len - 1 { 0 } else { selected + 1 }; + if app.show_model_menu { + while let Some(ModelMenuItem::Header(_)) = app.filtered_models.get(new_selected) { + new_selected = if new_selected >= items_len - 1 { 0 } else { new_selected + 1 }; + if new_selected == selected { break; } + } + } else if app.show_settings_menu { + while let Some(SettingsMenuItem::Header(_)) = app.settings_items.get(new_selected) { + new_selected = if new_selected >= items_len - 1 { 0 } else { new_selected + 1 }; + if new_selected == selected { break; } } + } + app.menu_state.select(Some(new_selected)); + } + } else { + if app.input.lines().len() == 1 && app.input.lines()[0].is_empty() || app.history_scroll < app.max_scroll || app.is_generating || key.modifiers.contains(event::KeyModifiers::SHIFT) { + app.history_scroll = app.history_scroll.saturating_add(15); + if app.history_scroll >= app.max_scroll { app.auto_scroll = true; } + } else { + app.input.input(Event::Key(key)); + } + } + } + KeyCode::Right if app.show_model_menu => { + let len = app.filtered_models.len(); + if len > 0 { + let current = app.menu_state.selected().unwrap_or(0); + let mut next_header_idx = None; + for i in (current + 1)..len { if let Some(ModelMenuItem::Header(_)) = app.filtered_models.get(i) { next_header_idx = Some(i); break; } } + if next_header_idx.is_none() { for i in 0..current { if let Some(ModelMenuItem::Header(_)) = app.filtered_models.get(i) { next_header_idx = Some(i); break; } } } + if let Some(h_idx) = next_header_idx { + let mut target = (h_idx + 1) % len; + while let Some(ModelMenuItem::Header(_)) = app.filtered_models.get(target) { target = (target + 1) % len; if target == h_idx { break; } } + app.menu_state.select(Some(target)); + } + } + } + KeyCode::Left if app.show_model_menu => { + let len = app.filtered_models.len(); + if len > 0 { + let current = app.menu_state.selected().unwrap_or(0); + let mut headers = Vec::new(); + for (i, item) in app.filtered_models.iter().enumerate() { if let ModelMenuItem::Header(_) = item { headers.push(i); } } + if !headers.is_empty() { + let current_header_idx_in_headers = headers.iter().enumerate().rev().find(|(_, &h_idx)| h_idx < current).map(|(i, _)| i); + let target_header_idx = match current_header_idx_in_headers { Some(i) => if i == 0 { *headers.last().unwrap() } else { headers[i - 1] }, None => *headers.last().unwrap() }; + let mut target = (target_header_idx + 1) % len; + while let Some(ModelMenuItem::Header(_)) = app.filtered_models.get(target) { target = (target + 1) % len; if target == target_header_idx { break; } } + app.menu_state.select(Some(target)); + } + } + } + KeyCode::Char('f') if key.modifiers.contains(event::KeyModifiers::CONTROL) && app.show_model_menu => { + if let Some(selected) = app.menu_state.selected() { + match app.filtered_models.get(selected) { + Some(ModelMenuItem::Model(model_info)) => { + let model_info = model_info.clone(); + let mut config = app.orchestrator.config.lock().await; + if config.favorites.iter().any(|m| m.name == model_info.name && m.provider_id == model_info.provider_id) { config.favorites.retain(|m| m.name != model_info.name || m.provider_id != model_info.provider_id); app.history.push(Message::system(format!("Removed {} from favorites", model_info.name))); } + else { config.favorites.push(model_info.clone()); app.history.push(Message::system(format!("Added {} to favorites", model_info.name))); } + let _ = routecode_sdk::utils::storage::save_config(&config); + } + _ => {} + } + } + } + _ => { + let event = Event::Key(key); + if app.is_inputting_api_key { app.api_key_input.input(event); } + else if app.show_model_menu { if app.model_search_input.input(event) { let search = app.model_search_input.lines()[0].to_lowercase().trim().to_string(); handle_model_search(app, &search, true).await; } } + else { app.input.input(event); app.update_filtered_commands(); } + } + } + Ok(KeyEventResult::Continue) +} + +async fn handle_mouse_event( + app: &mut App, + mouse: event::MouseEvent, + terminal: &mut Terminal, +) -> io::Result<()> { + app.mouse_events_count += 1; + // Always store current mouse position for render-time hover detection + app.mouse_row = Some(mouse.row); + app.mouse_col = Some(mouse.column); + match mouse.kind { + MouseEventKind::Moved => { + app.mouse_moved = true; + } + MouseEventKind::ScrollUp => { + if app.show_menu || app.show_provider_menu || app.show_model_menu || app.show_settings_menu { + let mut current = app.menu_state.selected().unwrap_or(0); + current = current.saturating_sub(3); + if app.show_model_menu { + while current > 0 && matches!(app.filtered_models.get(current), Some(crate::ui::ModelMenuItem::Header(_))) { + current -= 1; + } + } else if app.show_settings_menu { + while current > 0 && matches!(app.settings_items.get(current), Some(crate::ui::SettingsMenuItem::Header(_))) { + current -= 1; + } + } + app.menu_state.select(Some(current)); + } else { + app.history_scroll = app.history_scroll.saturating_sub(15); + app.auto_scroll = false; + } + } + MouseEventKind::ScrollDown => { + if app.show_menu || app.show_provider_menu || app.show_model_menu || app.show_settings_menu { + let current = app.menu_state.selected().unwrap_or(0); + let max = if app.show_menu { + app.filtered_commands.len() + } else if app.show_provider_menu { + crate::ui::PROVIDERS.len() + } else if app.show_settings_menu { + app.settings_items.len() + } else { + app.filtered_models.len() + }; + let mut next = current.saturating_add(3).min(max.saturating_sub(1)); + if app.show_model_menu { + while next < max - 1 && matches!(app.filtered_models.get(next), Some(crate::ui::ModelMenuItem::Header(_))) { + next += 1; + } + } else if app.show_settings_menu { + while next < max - 1 && matches!(app.settings_items.get(next), Some(crate::ui::SettingsMenuItem::Header(_))) { + next += 1; + } + } + app.menu_state.select(Some(next)); + } else { + app.history_scroll = app.history_scroll.saturating_add(15); + if app.history_scroll >= app.max_scroll { app.auto_scroll = true; } + } + } + MouseEventKind::Down(MouseButton::Left) | MouseEventKind::Up(MouseButton::Left) => { + if app.show_menu || app.show_provider_menu || app.show_model_menu || app.show_settings_menu { + if let Ok(size) = terminal.size() { + let (width, height) = if app.show_menu { + (60, (app.filtered_commands.len() + 6).min(15) as u16) + } else if app.show_provider_menu { + (60, (crate::ui::PROVIDERS.len() + 6).min(15) as u16) + } else if app.show_settings_menu { + (60, (app.settings_items.len() + 6).min(15) as u16) } else { - app.history.push(Message::assistant(None, Some(content), None)); + (70, (app.filtered_models.len() + 7).min(18) as u16) + }; + let modal_x = (size.width.saturating_sub(width)) / 2; + let modal_y = (size.height.saturating_sub(height)) / 2; + + let is_outside = mouse.column < modal_x || mouse.column >= modal_x + width || mouse.row < modal_y || mouse.row >= modal_y + height; + let is_esc = mouse.row <= modal_y + 2 && mouse.column >= modal_x + width.saturating_sub(10) && mouse.column <= modal_x + width; + let is_inside_list = mouse.row >= modal_y + 2 && mouse.row < modal_y + height - 1 && mouse.column >= modal_x + 1 && mouse.column < modal_x + width - 1; + + if is_outside || is_esc { + app.show_menu = false; + app.show_provider_menu = false; + app.show_model_menu = false; + app.show_settings_menu = false; + } else if is_inside_list && matches!(mouse.kind, MouseEventKind::Up(MouseButton::Left)) { + if app.show_settings_menu { + let idx = (mouse.row - (modal_y + 2)) as usize + app.menu_state.offset(); + if idx < app.settings_items.len() { + if let Some(SettingsMenuItem::Option { key, val, .. }) = app.settings_items.get(idx) { + if key == "logo_animation" { + let next_val = match val.as_str() { + "always" => "hover", + "hover" => "click", + _ => "always", + }; + let mut config = app.orchestrator.config.lock().await; + config.logo_animation = next_val.to_string(); + let _ = routecode_sdk::utils::storage::save_config(&config); + drop(config); + app.populate_settings().await; + } else if key == "logo_animation_color" { + let next_val = match val.as_str() { + "rainbow" => "neon", + "neon" => "cyberpunk", + "cyberpunk" => "sunset", + "sunset" => "mono", + _ => "rainbow", + }; + let mut config = app.orchestrator.config.lock().await; + config.logo_animation_color = next_val.to_string(); + let _ = routecode_sdk::utils::storage::save_config(&config); + drop(config); + app.populate_settings().await; + } + } + } + } } } - StreamChunk::ToolCall { tool_call } => { - app.active_tool = Some(tool_call.function.name.clone()); - if let Some(last) = app.history.last_mut() { - if last.role == Role::Assistant { - let mut calls = last.tool_calls.clone().unwrap_or_default(); - calls.push(tool_call); - last.tool_calls = Some(calls); + } else if app.screen == Screen::Session { + let has_thinking = app.history.iter().any(|m| m.thought.is_some()); + if matches!(mouse.kind, MouseEventKind::Down(MouseButton::Left)) { + let in_cooldown = app.last_toggle_time.map_or(false, |t| t.elapsed() < std::time::Duration::from_millis(400)); + + if !in_cooldown && has_thinking { + let is_double_click = if let Some((last_time, col, row)) = app.last_click_up { + let col_diff = (col as i32 - mouse.column as i32).abs(); + let row_diff = (row as i32 - mouse.row as i32).abs(); + last_time.elapsed() < std::time::Duration::from_millis(600) && col_diff <= 4 && row_diff <= 3 } else { - app.history.push(Message::assistant(None, None, Some(vec![tool_call]))); + false + }; + + if is_double_click { + app.collapse_thinking = !app.collapse_thinking; + app.last_click_up = None; + app.mouse_down_start = None; + app.last_toggle_time = Some(std::time::Instant::now()); + } else if let Ok(size) = terminal.size() { + // Compute hover FRESH with current mouse position + let hover = compute_thinking_hover(app, size); + if hover { + app.last_click_up = Some((std::time::Instant::now(), mouse.column, mouse.row)); + app.mouse_down_start = Some((std::time::Instant::now(), mouse.column, mouse.row)); + } else { + app.last_click_up = None; + } } - } else { - app.history.push(Message::assistant(None, None, Some(vec![tool_call]))); } } - StreamChunk::ToolResult { name, content, tool_call_id } => { - app.active_tool = None; - app.history.push(Message::tool(tool_call_id, name, content)); - } - StreamChunk::Done => { - app.is_generating = false; - app.active_tool = None; + if matches!(mouse.kind, MouseEventKind::Up(MouseButton::Left)) { + app.mouse_down_start = None; + app.temp_expand_thinking = false; } - StreamChunk::Error { content } => { - app.history.push(Message::system(format!("Error: {}", content))); - app.is_generating = false; - app.active_tool = None; + } else if app.screen == Screen::Welcome && matches!(mouse.kind, MouseEventKind::Down(MouseButton::Left)) { + if let Ok(size) = terminal.size() { + let logo_height = if size.height < 20 { 0 } else { 6 }; + let spacer_height = if size.height < 15 { 0 } else { size.height / 3 }; + if logo_height > 0 && mouse.row >= spacer_height && mouse.row < spacer_height + logo_height { + app.logo_anim_frames = 20; // 2 seconds at 100ms tick + } } - _ => {} } } + _ => {} } + Ok(()) } -async fn handle_model_search(app: &mut App, search: &str, force_reset: bool) { - let mut sections: Vec = Vec::new(); - let config = app.orchestrator.config.lock().await.clone(); - - let recent: Vec = config.recent_models.iter() - .filter(|m| m.name.to_lowercase().contains(search) || m.provider_id.to_lowercase().contains(search)) - .cloned() - .collect(); - if !recent.is_empty() { - sections.push(ModelMenuItem::Header("Recently Used".to_string())); - for m in recent { sections.push(ModelMenuItem::Model(m)); } - } - - let favorites: Vec = config.favorites.iter() - .filter(|m| m.name.to_lowercase().contains(search) || m.provider_id.to_lowercase().contains(search)) - .cloned() - .collect(); - if !favorites.is_empty() { - sections.push(ModelMenuItem::Header("Favorite Models".to_string())); - for m in favorites { sections.push(ModelMenuItem::Model(m)); } - } - - let mut by_provider: std::collections::HashMap> = std::collections::HashMap::new(); - for m in &app.all_available_models { - if m.name.to_lowercase().contains(search) || m.provider_id.to_lowercase().contains(search) { - by_provider.entry(m.provider_id.clone()).or_default().push(m.clone()); +async fn handle_stream_chunks(app: &mut App) { + while let Ok(chunk) = app.rx.try_recv() { + match chunk { + StreamChunk::Text { content } => { + if let Some(last) = app.history.last_mut() { + if last.role == Role::Assistant { + let mut current = last.content.as_ref().map(|s| s.to_string()).unwrap_or_default(); + current.push_str(&content); + last.content = Some(std::sync::Arc::from(current)); + } else { app.history.push(Message::assistant(Some(std::sync::Arc::from(content)), None, None)); } + } else { app.history.push(Message::assistant(Some(std::sync::Arc::from(content)), None, None)); } + } + StreamChunk::Thought { content } => { + if let Some(last) = app.history.last_mut() { + if last.role == Role::Assistant { + let mut current = last.thought.as_ref().map(|s| s.to_string()).unwrap_or_default(); + current.push_str(&content); + last.thought = Some(std::sync::Arc::from(current)); + } else { app.history.push(Message::assistant(None, Some(std::sync::Arc::from(content)), None)); } + } else { app.history.push(Message::assistant(None, Some(std::sync::Arc::from(content)), None)); } + } + StreamChunk::ToolCall { tool_call } => { + app.active_tool = Some(tool_call.function.name.clone()); + if let Some(last) = app.history.last_mut() { + if last.role == Role::Assistant { + let mut calls = last.tool_calls.clone().unwrap_or_default(); + if let Some(idx) = tool_call.index { if let Some(existing) = calls.iter_mut().find(|tc| tc.index == Some(idx)) { *existing = tool_call; } else { calls.push(tool_call); } } + else { if !calls.iter().any(|tc| tc.id == tool_call.id && !tc.id.is_empty()) { calls.push(tool_call); } } + last.tool_calls = Some(calls); + } else { app.history.push(Message::assistant(None, None, Some(vec![tool_call]))); } + } else { app.history.push(Message::assistant(None, None, Some(vec![tool_call]))); } + } + StreamChunk::ToolResult { name, content, tool_call_id } => { app.active_tool = None; app.history.push(Message::tool(tool_call_id, name, content)); } + StreamChunk::Done => { app.is_generating = false; app.active_tool = None; } + StreamChunk::Error { content } => { + let mut display_error = content.clone(); + let json_part = if let Some(idx) = content.find('{') { &content[idx..] } else { &content }; + if let Ok(val) = serde_json::from_str::(json_part) { + if let Some(msg) = val["error"]["message"].as_str() { display_error = msg.to_string(); } + else if let Some(error_obj) = val["error"].as_object() { if let Some(msg) = error_obj["message"].as_str() { display_error = msg.to_string(); } } + else if let Some(msg) = val["message"].as_str() { display_error = msg.to_string(); } + else if let Some(errors) = val["errors"].as_array() { if let Some(msg) = errors.get(0).and_then(|e| e["message"].as_str()) { display_error = msg.to_string(); } } + } + app.history.push(Message::system(format!("Error: {}", display_error))); + app.is_generating = false; + app.active_tool = None; + } + StreamChunk::Models { models } => { + app.all_available_models.extend(models); + let search = app.model_search_input.lines()[0].to_lowercase().trim().to_string(); + handle_model_search(app, &search, false).await; + } + StreamChunk::ModelsDone => { app.is_fetching_models = false; } + StreamChunk::FinalHistory { history } => { app.history = history; } + StreamChunk::RequestConfirmation { message, target, tx } => { + app.pending_command_confirmation = Some((message, target, tx.unwrap())); + } + _ => {} } } +} - let mut provider_ids: Vec = by_provider.keys().cloned().collect(); - provider_ids.sort(); +pub async fn run_app( + terminal: &mut Terminal, + mut app: App, +) -> io::Result<()> { + let mut last_tick = std::time::Instant::now(); + let tick_rate = std::time::Duration::from_millis(100); + let render_rate = std::time::Duration::from_millis(16); // ~60 FPS for smooth rendering - for p_id in provider_ids { - if let Some(models) = by_provider.get(&p_id) { - let p_name = PROVIDERS.iter().find(|p| p.id == p_id).map(|p| p.name).unwrap_or(&p_id); - sections.push(ModelMenuItem::Header(p_name.to_string())); - for m in models { sections.push(ModelMenuItem::Model(m.clone())); } - } - } + loop { + terminal.draw(|f| ui(f, &mut app))?; - app.filtered_models = sections; - - if force_reset { - if !app.filtered_models.is_empty() { - let mut first_model = None; - for (i, item) in app.filtered_models.iter().enumerate() { - if let ModelMenuItem::Model(_) = item { first_model = Some(i); break; } - } - app.menu_state.select(first_model); - } else { app.menu_state.select(None); } - } -} + let timeout = render_rate; -async fn handle_command(app: &mut App, input: &str) { - let parts: Vec<&str> = input.split_whitespace().collect(); - if parts.is_empty() { return; } - let command = parts[0]; - let args = &parts[1..]; + if event::poll(timeout)? { + let mut events = Vec::new(); + while event::poll(std::time::Duration::from_millis(0))? { + events.push(event::read()?); + } - match command { - "/model" => { - app.history.push(Message::system("Fetching available models...")); - app.all_available_models.clear(); - app.model_search_input = TextArea::default(); - app.model_search_input.set_cursor_line_style(Style::default()); - app.model_search_input.set_placeholder_text(" Search models..."); - app.model_search_input.set_placeholder_style(Style::default().fg(COLOR_SECONDARY)); + let is_burst = events.len() > 1; - let config = app.orchestrator.config.lock().await.clone(); - for p_info in PROVIDERS { - let env_key = format!("{}_API_KEY", p_info.id.to_uppercase().replace("-", "_")); - let mut api_key = std::env::var(env_key).ok().or_else(|| config.api_keys.get(p_info.id).cloned()); - if api_key.is_none() && p_info.id.starts_with("cloudflare") { - api_key = std::env::var("CLOUDFLARE_API_KEY").ok(); - } - if let Some(key) = api_key { - let provider = routecode_sdk::agents::resolve_provider(p_info.id, key); - match provider.list_models().await { - Ok(models) => { - for m_name in models { - app.all_available_models.push(DynamicModelInfo { name: m_name, provider_id: p_info.id.to_string() }); + for event in events { + match event { + Event::Key(key) => { + if key.kind == KeyEventKind::Press { + match handle_key_event(&mut app, key, is_burst).await? { + KeyEventResult::Exit => return Ok(()), + KeyEventResult::Continue => {} } } - Err(e) => { log::error!("Failed to list models for {}: {}", p_info.id, e); } } + Event::Paste(text) => { app.input.insert_str(&text); } + Event::Mouse(mouse) => { + handle_mouse_event(&mut app, mouse, terminal).await?; + } + _ => {} } } - handle_model_search(app, "", true).await; - if app.filtered_models.is_empty() { - app.history.push(Message::system("No models found. Ensure providers are connected.")); - } else { app.show_model_menu = true; } } - "/resume" => { - if let Some(name) = args.first() { - if let Ok(session) = routecode_sdk::utils::storage::load_session(name) { - app.history = session.messages; - app.current_model = session.model; - let mut u = app.orchestrator.usage.lock().await; - *u = session.usage; - app.history.push(Message::system(format!("Session resumed: {}", name))); - app.screen = Screen::Session; + + if last_tick.elapsed() >= tick_rate { + app.tick_count += 1; + app.logo_anim_frames = app.logo_anim_frames.saturating_sub(1); + + if app.screen == Screen::Session { + if let Some((start_time, _, _)) = app.mouse_down_start { + if start_time.elapsed() >= std::time::Duration::from_millis(400) { + if app.thinking_hover_rendered { + app.temp_expand_thinking = true; + } + } } } + + last_tick = std::time::Instant::now(); } - "/sessions" => { - if let Ok(sessions) = routecode_sdk::utils::storage::list_sessions() { - if sessions.is_empty() { app.history.push(Message::system("No saved sessions found.")); } - else { app.history.push(Message::system(format!("Saved sessions:\n {}", sessions.join("\n ")))); } - } - } - "/clear" => { - app.history.clear(); - app.screen = Screen::Welcome; - } - "/stop" => { - if app.is_generating { - if let Some(handle) = app.current_task.take() { handle.abort(); } - app.is_generating = false; - app.active_tool = None; - app.history.push(Message::system("Generation cancelled.")); - } - } - "/help" => { - app.history.push(Message::system("Available commands:\n /model - Select model\n /provider - Manage connections\n /resume - Resume session\n /sessions - List sessions\n /clear - Clear history\n /help - Show help\n /exit - Use Esc to exit")); - } - "/provider" => { app.show_provider_menu = true; app.menu_state.select(Some(0)); } - _ => { app.history.push(Message::system(format!("Unknown command: {}", command))); } + + handle_stream_chunks(&mut app).await; } } -fn ui(f: &mut Frame, app: &mut App, usage: &Usage) { +fn ui(f: &mut Frame, app: &mut App) { let area = f.size(); f.render_widget(Block::default().style(Style::default().bg(COLOR_BG)), area); - - let main_layout = Layout::default() - .direction(Direction::Vertical) - .constraints([ - Constraint::Length(1), // Header - Constraint::Min(0), // Content - ]) - .split(area); - - let current_dir = std::env::current_dir() - .map(|p| p.file_name().unwrap_or_default().to_string_lossy().to_string()) - .unwrap_or_else(|_| "workspace".to_string()); - - let header_layout = Layout::default() - .direction(Direction::Horizontal) - .constraints([ - Constraint::Min(0), - Constraint::Length(25), // " RouteCode v0.1.1 " - ]) - .split(main_layout[0]); - + let main_layout = Layout::default().direction(Direction::Vertical).constraints([Constraint::Length(1), Constraint::Min(0)]).split(area); + let current_dir = std::env::current_dir().map(|p| p.file_name().unwrap_or_default().to_string_lossy().to_string()).unwrap_or_else(|_| "workspace".to_string()); + let header_layout = Layout::default().direction(Direction::Horizontal).constraints([Constraint::Min(0), Constraint::Length(25)]).split(main_layout[0]); let version = env!("CARGO_PKG_VERSION"); let header_title = format!(" RouteCode v{} ", version); - f.render_widget(Paragraph::new(Span::styled(format!(" {} ", current_dir), Style::default().fg(COLOR_SECONDARY))), header_layout[0]); f.render_widget(Paragraph::new(Span::styled(header_title, Style::default().fg(COLOR_PRIMARY).add_modifier(Modifier::BOLD))).alignment(ratatui::layout::Alignment::Right), header_layout[1]); - let input_area = match app.screen { Screen::Welcome => ui_welcome(f, app, main_layout[1]), - Screen::Session => ui_session(f, app, usage, main_layout[1]), + Screen::Session => ui_session(f, app, main_layout[1]), }; - - if app.show_menu { - render_menu(f, app, input_area); - } else if app.show_provider_menu { - render_provider_menu(f, app, input_area); - } else if app.show_model_menu { - render_model_menu(f, app, input_area); - } else if app.is_inputting_api_key { - render_api_key_dialog(f, app); - } -} - -pub fn clean_model_name(name: &str, provider_id: &str) -> String { - if provider_id.starts_with("cloudflare") && name.starts_with("@cf/") { - name.split('/').last().unwrap_or(name).to_string() - } else if (provider_id == "openrouter" || provider_id == "nvidia") && name.contains('/') { - name.split('/').last().unwrap_or(name).to_string() - } else { - name.to_string() - } + if app.show_menu { render_menu(f, app, input_area); } + else if app.show_provider_menu { render_provider_menu(f, app, input_area); } + else if app.show_model_menu { render_model_menu(f, app, input_area); } + else if app.show_settings_menu { render_settings_menu(f, app, input_area); } + else if app.is_inputting_api_key { render_api_key_dialog(f, app); } + else if app.pending_clear { render_confirmation_dialog(f, "Are you sure you want to clear all history? (y/n)"); } + else if app.pending_exit { render_confirmation_dialog(f, "Are you sure you want to exit RouteCode? (y/n)"); } + else if app.pending_command_confirmation.is_some() { render_command_confirmation_dialog(f, app); } } -fn draw_modal(f: &mut Frame, title: &str, width: u16, height: u16, footer: Vec) -> Rect { +fn render_command_confirmation_dialog(f: &mut Frame, app: &mut App) { let area = f.size(); - let modal_area = Rect::new( - (area.width.saturating_sub(width)) / 2, - (area.height.saturating_sub(height)) / 2, - width, - height, - ); - - f.render_widget(Clear, modal_area); - f.render_widget(Block::default().style(Style::default().bg(COLOR_BG)), modal_area); - - let main_layout = Layout::default() + let popup_layout = Layout::default() .direction(Direction::Vertical) .constraints([ - Constraint::Length(1), // Header - Constraint::Min(0), // Body - Constraint::Length(1), // Footer Spacer - Constraint::Length(1), // Footer + Constraint::Percentage(30), + Constraint::Length(10), + Constraint::Percentage(30), ]) - .margin(1) - .split(modal_area); + .split(area); - let header_layout = Layout::default() + let popup_horiz = Layout::default() .direction(Direction::Horizontal) - .constraints([Constraint::Min(0), Constraint::Length(5)]) - .split(main_layout[0]); - - f.render_widget(Paragraph::new(Span::styled(title, Style::default().add_modifier(Modifier::BOLD))), header_layout[0]); - f.render_widget(Paragraph::new(Span::styled("esc", Style::default().fg(COLOR_SECONDARY))), header_layout[1]); - f.render_widget(Paragraph::new(Line::from(footer)), main_layout[3]); - - main_layout[1] -} - -fn render_api_key_dialog(f: &mut Frame, app: &mut App) { - let provider_id = app.pending_provider_id.as_deref().unwrap_or("provider"); - let p_info = PROVIDERS.iter().find(|p| p.id == provider_id); - let provider_name = p_info.map(|p| p.name).unwrap_or(provider_id); - - let title = format!("Connect {}", provider_name); - let body_area = draw_modal(f, &title, 60, 10, vec![ - Span::styled("Press Enter to save", Style::default().add_modifier(Modifier::BOLD)), - ]); - - let layout = Layout::default() - .direction(Direction::Vertical) .constraints([ - Constraint::Length(1), // Prompt - Constraint::Length(1), // Spacer - Constraint::Length(3), // Input + Constraint::Percentage(15), + Constraint::Percentage(70), + Constraint::Percentage(15), ]) - .split(body_area); - - let (prompt, placeholder) = match app.api_key_input_stage { - ApiKeyInputStage::CloudflareAccountId => (format!("Enter Cloudflare Account ID:"), " Account ID..."), - ApiKeyInputStage::CloudflareGatewayId => (format!("Enter Cloudflare Gateway ID:"), " Gateway ID..."), - ApiKeyInputStage::CloudflareApiKey => (format!("Enter Cloudflare API Token:"), " API Token..."), - _ => (format!("Enter API key for {}:", provider_name), " Paste your API key here..."), - }; + .split(popup_layout[1]); - f.render_widget(Paragraph::new(prompt), layout[0]); + let inner_area = popup_horiz[1]; - app.api_key_input.set_placeholder_text(placeholder); - app.api_key_input.set_block(Block::default().borders(Borders::ALL).border_style(Style::default().fg(COLOR_SECONDARY))); - f.render_widget(app.api_key_input.widget(), layout[2]); + let block = Block::default() + .title(" Command Confirmation Required ") + .borders(ratatui::widgets::Borders::ALL) + .border_style(Style::default().fg(COLOR_PRIMARY)) + .style(Style::default().bg(COLOR_BG)); - let (row, col) = app.api_key_input.cursor(); - f.set_cursor(layout[2].x + 1 + col as u16, layout[2].y + 1 + row as u16); -} - -fn render_provider_menu(f: &mut Frame, app: &mut App, _input_area: Rect) { - let height = (PROVIDERS.len() + 6).min(15) as u16; - let body_area = draw_modal(f, "AI Providers", 60, height, vec![ - Span::styled("Space", Style::default().add_modifier(Modifier::BOLD)), - Span::raw(" toggle connection"), - ]); - - let config = futures::executor::block_on(app.orchestrator.config.lock()); - let items: Vec = PROVIDERS.iter().map(|p| { - let env_key = format!("{}_API_KEY", p.id.to_uppercase().replace("-", "_")); - let is_connected = config.api_keys.contains_key(p.id) || std::env::var(env_key).is_ok(); - let status = if is_connected { Span::styled(" ✔ connected", Style::default().fg(COLOR_SUCCESS)) } - else { Span::styled(" ✖ disconnected", Style::default().fg(COLOR_SECONDARY)) }; - - let total_width = body_area.width.saturating_sub(4); - let left = p.name.to_string(); - let status_str = if is_connected { "✔ connected" } else { "✖ disconnected" }; - let padding = total_width.saturating_sub(left.len() as u16).saturating_sub(status_str.len() as u16); - let spaces = " ".repeat(padding as usize); - - ListItem::new(Line::from(vec![ - Span::raw(format!(" {}", left)), - Span::raw(spaces), - status, - Span::raw(" "), - ])) - }).collect(); + let (message, target, _) = app.pending_command_confirmation.as_ref().unwrap(); - let list = List::new(items) - .highlight_style(Style::default().bg(COLOR_PRIMARY).fg(Color::Black)) - .highlight_symbol(""); - - f.render_stateful_widget(list, body_area, &mut app.menu_state); -} - -fn render_model_menu(f: &mut Frame, app: &mut App, _input_area: Rect) { - let height = (app.filtered_models.len() + 7).min(18) as u16; - let body_area = draw_modal(f, "Select model", 70, height, vec![ - Span::styled("Connect provider ", Style::default().add_modifier(Modifier::BOLD)), - Span::styled("ctrl+a", Style::default().fg(COLOR_SECONDARY)), - Span::raw(" "), - Span::styled("Favorite ", Style::default().add_modifier(Modifier::BOLD)), - Span::styled("ctrl+f", Style::default().fg(COLOR_SECONDARY)), - ]); - - let layout = Layout::default() - .direction(Direction::Vertical) - .constraints([ - Constraint::Length(2), // Search area - Constraint::Min(0), // List area - ]) - .split(body_area); + let mut lines = vec![ + ratatui::text::Line::from(vec![Span::styled(message, Style::default().fg(COLOR_TEXT))]), + ratatui::text::Line::from(vec![Span::styled(format!("> {}", target), Style::default().fg(COLOR_SYSTEM).add_modifier(Modifier::BOLD))]), + ratatui::text::Line::from(""), + ]; - let search_text = app.model_search_input.lines()[0].clone(); - let search_para = if search_text.is_empty() { - Paragraph::new(Span::styled("search models...", Style::default().fg(COLOR_SECONDARY))) + if app.inputting_command_feedback { + lines.push(ratatui::text::Line::from(vec![Span::styled("Please type your feedback below and press Enter (Esc to cancel):", Style::default().fg(COLOR_SECONDARY))])); } else { - Paragraph::new(Span::styled(&search_text, Style::default().fg(COLOR_TEXT))) - }; - f.render_widget(search_para, layout[0]); - - if app.show_model_menu && !app.is_inputting_api_key { - let (row, col) = app.model_search_input.cursor(); - f.set_cursor(layout[0].x + col as u16, layout[0].y + row as u16); + lines.push(ratatui::text::Line::from(vec![ + Span::styled("[Y]", Style::default().fg(COLOR_SUCCESS).add_modifier(Modifier::BOLD)), + Span::raw(" Allow once "), + Span::styled("[S]", Style::default().fg(COLOR_PRIMARY).add_modifier(Modifier::BOLD)), + Span::raw(" Allow for session "), + Span::styled("[W]", Style::default().fg(COLOR_PRIMARY).add_modifier(Modifier::BOLD)), + Span::raw(" Allow for Workspace "), + Span::styled("[F]", Style::default().fg(COLOR_SECONDARY).add_modifier(Modifier::BOLD)), + Span::raw(" Tell Agent something else "), + Span::styled("[D] or [Esc]", Style::default().fg(ratatui::style::Color::Red).add_modifier(Modifier::BOLD)), + Span::raw(" Deny"), + ])); } - let config = futures::executor::block_on(app.orchestrator.config.lock()); - let items: Vec = app.filtered_models.iter().map(|item| { - match item { - ModelMenuItem::Header(title) => { - ListItem::new(Line::from(vec![ - Span::styled(format!(" {}", title), Style::default().fg(COLOR_SECONDARY).add_modifier(Modifier::DIM)) - ])) - } - ModelMenuItem::Model(m) => { - let is_fav = config.favorites.iter().any(|fav| fav.name == m.name && fav.provider_id == m.provider_id); - let fav_star = if is_fav { " ★" } else { "" }; - let display_name = clean_model_name(&m.name, &m.provider_id).replace(":free", " Free"); - let p_name = PROVIDERS.iter().find(|p| p.id == m.provider_id).map(|p| p.name).unwrap_or(&m.provider_id); - let left = format!("{}{}", display_name, fav_star); - let right = p_name.to_string(); - let total_width = layout[1].width.saturating_sub(4); - let padding = total_width.saturating_sub(left.len() as u16).saturating_sub(right.len() as u16); - let spaces = " ".repeat(padding as usize); - ListItem::new(Line::from(vec![Span::raw(format!(" {}", left)), Span::raw(spaces), Span::raw(right), Span::raw(" ")])) - } - } - }).collect(); - - let list = List::new(items) - .highlight_style(Style::default().bg(COLOR_PRIMARY).fg(Color::Black)) - .highlight_symbol(""); - - f.render_stateful_widget(list, layout[1], &mut app.menu_state); -} + let paragraph = Paragraph::new(lines).block(block).wrap(ratatui::widgets::Wrap { trim: false }); + f.render_widget(ratatui::widgets::Clear, inner_area); + f.render_widget(paragraph, inner_area); -fn ui_welcome(f: &mut Frame, app: &mut App, area: Rect) -> Rect { - let logo_height = if area.height < 20 { 0 } else { 6 }; - let spacer_height = if area.height < 15 { 0 } else { area.height / 3 }; - let input_lines = app.input.lines().len() as u16; - let input_height = (input_lines + 2).min(12); - - let chunks = Layout::default() - .direction(Direction::Vertical) - .constraints([ - Constraint::Length(spacer_height), - Constraint::Length(logo_height), - Constraint::Length(input_height), // Dynamic Input - Constraint::Length(1), // Spacer - Constraint::Length(1), // Info & Tips - Constraint::Min(0), - Constraint::Length(1), // Footer - ]) - .split(area); - - if logo_height > 0 { - let logo_text = if area.width < 60 { - vec![ - Line::from(Span::styled(" __ _ ", Style::default().fg(COLOR_TEXT).add_modifier(Modifier::BOLD))), - Line::from(Span::styled(" |__) _|_ _ _/ _ _| _ ", Style::default().fg(COLOR_TEXT).add_modifier(Modifier::BOLD))), - Line::from(Span::styled(" | \\(_|(_(- \\__(_)(_|(/_ ", Style::default().fg(COLOR_TEXT).add_modifier(Modifier::BOLD))), - ] - } else { - vec![ - Line::from(Span::styled(" ____ _ ____ _ ", Style::default().fg(COLOR_TEXT).add_modifier(Modifier::BOLD))), - Line::from(Span::styled(" | _ \\ ___ _ _| |_ ___ / ___|___ __| | ___ ", Style::default().fg(COLOR_TEXT).add_modifier(Modifier::BOLD))), - Line::from(Span::styled(" | |_) / _ \\| | | | __/ _ \\ | / _ \\ / _` |/ _ \\", Style::default().fg(COLOR_TEXT).add_modifier(Modifier::BOLD))), - Line::from(Span::styled(" | _ < (_) | |_| | || __/ |__| (_) | (_| | __/", Style::default().fg(COLOR_TEXT).add_modifier(Modifier::BOLD))), - Line::from(Span::styled(" |_| \\_\\___/ \\__,_|\\__\\___|\\____\\___/ \\__,_|\\___|", Style::default().fg(COLOR_TEXT).add_modifier(Modifier::BOLD))), - ] + if app.inputting_command_feedback { + let input_rect = ratatui::layout::Rect { + x: inner_area.x + 2, + y: inner_area.y + 5, + width: inner_area.width.saturating_sub(4), + height: 3, }; - f.render_widget(Paragraph::new(logo_text).alignment(ratatui::layout::Alignment::Center), chunks[1]); + let input_block = Block::default().borders(ratatui::widgets::Borders::ALL).border_style(Style::default().fg(COLOR_PRIMARY)); + app.input.set_block(input_block); + f.render_widget(app.input.widget(), input_rect); + f.set_cursor(input_rect.x + app.input.cursor().1 as u16 + 1, input_rect.y + app.input.cursor().0 as u16 + 1); } - - let input_width_percent = if area.width < 50 { 0.95 } else if area.width < 100 { 0.8 } else { 0.6 }; - let input_width = (area.width as f32 * input_width_percent) as u16; - let input_area = Rect::new((area.width - input_width) / 2, chunks[2].y, input_width, input_height); - f.render_widget(Block::default().style(Style::default().bg(COLOR_INPUT_BG)), input_area); - let inner_input_area = Rect::new(input_area.x + 1, input_area.y + 1, input_area.width.saturating_sub(2), input_area.height.saturating_sub(2)); - app.input.set_block(Block::default().borders(Borders::NONE)); - f.render_widget(app.input.widget(), inner_input_area); - if !app.is_generating { f.set_cursor(inner_input_area.x + app.input.cursor().1 as u16, inner_input_area.y + app.input.cursor().0 as u16); } - - let cleaned_model = clean_model_name(&app.current_model, &app.current_provider_id); - let provider_info = vec![ - Span::styled("Model ", Style::default().fg(COLOR_SECONDARY)), - Span::styled(cleaned_model, Style::default().fg(COLOR_PRIMARY).add_modifier(Modifier::BOLD)), - Span::styled(" • Provider ", Style::default().fg(COLOR_SECONDARY)), - Span::styled(&app.provider_name, Style::default().fg(COLOR_TEXT)), - ]; - f.render_widget(Paragraph::new(Line::from(provider_info)).alignment(ratatui::layout::Alignment::Center), chunks[4]); - - let spinner = ["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"]; - let frame = spinner[(app.tick_count % spinner.len() as u64) as usize]; - let tip_text = if app.is_generating { format!(" {} AI is working... ", frame) } else { "ctrl+p help | esc exit".to_string() }; - f.render_widget(Paragraph::new(tip_text).alignment(ratatui::layout::Alignment::Center).style(Style::default().fg(COLOR_SECONDARY).add_modifier(Modifier::DIM)), chunks[6]); - input_area } -fn ui_session(f: &mut Frame, app: &mut App, usage: &Usage, area: Rect) -> Rect { - let input_height = (app.input.lines().len() as u16 + 2).min(12); - let chunks = Layout::default() +fn render_confirmation_dialog(f: &mut Frame, message: &str) { + let area = f.size(); + let popup_layout = Layout::default() .direction(Direction::Vertical) - .constraints([Constraint::Min(1), Constraint::Length(input_height), Constraint::Length(1)]) + .constraints([ + Constraint::Percentage(40), + Constraint::Length(5), + Constraint::Percentage(40), + ]) .split(area); - - let history = render_history(&app.history); - - // 1. Auto-scroll logic - let mut total_height = 0; - let available_width = chunks[0].width.saturating_sub(4).max(1); // Account for some margin - for line in &history.lines { - let line_width: usize = line.spans.iter().map(|s| s.content.len()).sum(); - let wrapped_height = (line_width as u16 / available_width) + 1; - total_height += wrapped_height; - } - - // Pin to bottom if generating - if app.is_generating { - app.history_scroll = total_height.saturating_sub(chunks[0].height); - } - - f.render_widget(Paragraph::new(history).wrap(Wrap { trim: false }).scroll((app.history_scroll, 0)), chunks[0]); - f.render_widget(Block::default().style(Style::default().bg(COLOR_INPUT_BG)), chunks[1]); - let inner_input_area = Rect::new(chunks[1].x + 1, chunks[1].y + 1, chunks[1].width.saturating_sub(2), chunks[1].height.saturating_sub(2)); - app.input.set_block(Block::default().borders(Borders::NONE)); - f.render_widget(app.input.widget(), inner_input_area); - if !app.is_generating { f.set_cursor(inner_input_area.x + app.input.cursor().1 as u16, inner_input_area.y + app.input.cursor().0 as u16); } - let spinner = ["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"]; - let frame = spinner[(app.tick_count % spinner.len() as u64) as usize]; - let generating_text = if app.is_generating { - if let Some(tool) = &app.active_tool { format!(" {} [Running {}...] ", frame, tool) } - else { format!(" {} [Thinking...] ", frame) } - } else { "".to_string() }; + let popup_horiz = Layout::default() + .direction(Direction::Horizontal) + .constraints([ + Constraint::Percentage(25), + Constraint::Percentage(50), + Constraint::Percentage(25), + ]) + .split(popup_layout[1]); - let cleaned_model = clean_model_name(&app.current_model, &app.current_provider_id); - let status_bar = Line::from(vec![ - Span::styled(format!(" {} ", cleaned_model), Style::default().fg(COLOR_PRIMARY).add_modifier(Modifier::BOLD)), - Span::styled(format!(" • Tokens: {} • Cost: ${:.4} ", usage.total_tokens, usage.total_cost), Style::default().fg(COLOR_SECONDARY)), - Span::styled(generating_text, Style::default().fg(COLOR_SYSTEM)), - Span::styled(" • ctrl+p help ", Style::default().fg(COLOR_SECONDARY).add_modifier(Modifier::DIM)), - ]); - f.render_widget(Paragraph::new(status_bar), chunks[2]); - chunks[1] -} + let block = Block::default() + .title(" Confirmation ") + .borders(ratatui::widgets::Borders::ALL) + .border_style(Style::default().fg(COLOR_PRIMARY)); -fn render_history(history: &[Message]) -> Text<'_> { - let mut lines = Vec::new(); - for m in history { - match m.role { - Role::User => { - lines.push(Line::from(vec![Span::styled(" ● User", Style::default().fg(COLOR_PRIMARY).add_modifier(Modifier::BOLD))])); - if let Some(content) = &m.content { for line in content.lines() { lines.push(Line::from(vec![Span::raw(" "), Span::raw(line)])); } } - } - Role::Assistant => { - lines.push(Line::from(vec![Span::styled(" ● RouteCode", Style::default().fg(COLOR_TEXT).add_modifier(Modifier::BOLD))])); - if let Some(thought) = &m.thought { - for line in thought.lines() { - lines.push(Line::from(vec![Span::styled(" │ ", Style::default().fg(COLOR_DIM)), Span::styled(line, Style::default().fg(COLOR_SECONDARY).add_modifier(Modifier::ITALIC))])); - } - } - if let Some(tool_calls) = &m.tool_calls { - for tc in tool_calls { - let args: serde_json::Value = serde_json::from_str(&tc.function.arguments).unwrap_or(serde_json::json!({})); - let arg_preview = if let Some(path) = args["path"].as_str() { - format!("({})", path) - } else { - format!("({})", tc.function.name) - }; - - lines.push(Line::from(vec![ - Span::styled(" 🛠 ", Style::default().fg(COLOR_PRIMARY)), - Span::styled(format!("Using {} ", tc.function.name), Style::default().fg(COLOR_TEXT)), - Span::styled(arg_preview, Style::default().fg(COLOR_SECONDARY).add_modifier(Modifier::DIM)), - ])); - } - } - if let Some(content) = &m.content { - for line in content.lines() { - if line.trim().starts_with("```") { lines.push(Line::from(vec![Span::raw(" "), Span::styled(line, Style::default().fg(COLOR_PRIMARY))])); } - else { lines.push(Line::from(vec![Span::raw(" "), Span::raw(line)])); } - } - } - } - Role::Tool => { - lines.push(Line::from(vec![Span::styled(format!(" ✓ Tool ({})", m.name.as_deref().unwrap_or("result")), Style::default().fg(COLOR_SECONDARY))])); - if let Some(content) = &m.content { - if let Ok(res) = serde_json::from_str::(content) { - if let Some(diff) = res.diff { - for line in diff.lines() { - let style = if line.starts_with('+') { - Style::default().fg(COLOR_SUCCESS) - } else if line.starts_with('-') { - Style::default().fg(Color::Red) - } else { - Style::default().fg(COLOR_DIM) - }; - lines.push(Line::from(vec![Span::raw(" "), Span::styled(line.to_string(), style)])); - } - } else if let Some(out) = res.content { - let preview = if out.len() > 100 { format!("{}...", &out[..100]) } else { out }; - lines.push(Line::from(vec![Span::styled(format!(" {}", preview), Style::default().fg(COLOR_DIM).add_modifier(Modifier::DIM))])); - } else if let Some(err) = res.error { - lines.push(Line::from(vec![Span::styled(format!(" Error: {}", err), Style::default().fg(Color::Red))])); - } - } else { - let preview = if content.len() > 100 { format!("{}...", &content[..100]) } else { content.clone() }; - lines.push(Line::from(vec![Span::styled(format!(" {}", preview), Style::default().fg(COLOR_DIM).add_modifier(Modifier::DIM))])); - } - } - } - Role::System => { - lines.push(Line::from(vec![Span::styled(" ● System", Style::default().fg(COLOR_SYSTEM).add_modifier(Modifier::DIM))])); - if let Some(content) = &m.content { lines.push(Line::from(vec![Span::styled(format!(" {}", content), Style::default().fg(COLOR_SYSTEM).add_modifier(Modifier::DIM))])); } - } - } - lines.push(Line::from("")); - } - Text::from(lines) -} + let p = Paragraph::new(Span::styled(message, Style::default().fg(COLOR_TEXT).add_modifier(Modifier::BOLD))) + .alignment(ratatui::layout::Alignment::Center) + .block(block); -fn render_menu(f: &mut Frame, app: &mut App, _input_area: Rect) { - let height = (app.filtered_commands.len() + 6).min(15) as u16; - let body_area = draw_modal(f, "Commands", 60, height, vec![Span::styled("Enter", Style::default().add_modifier(Modifier::BOLD)), Span::raw(" select command")]); - let items: Vec = app.filtered_commands.iter().map(|cmd| { - let total_width = body_area.width.saturating_sub(4); - let left = cmd.name.to_string(); - let right = cmd.description.to_string(); - let padding = total_width.saturating_sub(left.len() as u16).saturating_sub(right.len() as u16); - let spaces = " ".repeat(padding as usize); - ListItem::new(Line::from(vec![Span::raw(format!(" {}", left)), Span::raw(spaces), Span::styled(right, Style::default().fg(COLOR_SECONDARY)), Span::raw(" ")])) - }).collect(); - let list = List::new(items).highlight_style(Style::default().bg(COLOR_PRIMARY).fg(Color::Black)).highlight_symbol(""); - f.render_stateful_widget(list, body_area, &mut app.menu_state); + f.render_widget(ratatui::widgets::Clear, popup_horiz[1]); + f.render_widget(p, popup_horiz[1]); } #[cfg(test)] @@ -1173,7 +1277,7 @@ mod tests { impl AIProvider for MockProvider { fn name(&self) -> &str { "Mock" } async fn list_models(&self) -> Result, anyhow::Error> { Ok(vec![]) } - async fn ask(&self, _: Vec, _: &str, _: Option>) -> Result { + async fn ask(&self, _: Vec, _: &str, _: Option>, _: Option<&str>) -> Result { Err(anyhow::anyhow!("Not implemented")) } } diff --git a/apps/cli/src/ui/session.rs b/apps/cli/src/ui/session.rs new file mode 100644 index 0000000..328800e --- /dev/null +++ b/apps/cli/src/ui/session.rs @@ -0,0 +1,382 @@ +use ratatui::layout::{Constraint, Direction, Layout, Rect}; +use ratatui::style::{Modifier, Style, Color}; +use ratatui::text::{Line, Span, Text}; +use ratatui::widgets::{Block, Borders, Paragraph, Wrap}; +use ratatui::Frame; +use routecode_sdk::core::{Message, Role}; +use unicode_width::UnicodeWidthStr; +use crate::ui::App; +use crate::ui::components::{COLOR_INPUT_BG, COLOR_PRIMARY, COLOR_SECONDARY, COLOR_TEXT, COLOR_DIM, COLOR_SYSTEM, COLOR_SUCCESS, clean_model_name}; + +pub fn ui_session(f: &mut Frame, app: &mut App, area: Rect) -> Rect { + let input_height = (app.input.lines().len() as u16 + 2).min(12); + let chunks = Layout::default() + .direction(Direction::Vertical) + .constraints([Constraint::Min(1), Constraint::Length(input_height), Constraint::Length(1)]) + .split(area); + + // Compute thinking hover at render time using actual frame dimensions + let thinking_hovered = crate::ui::compute_thinking_hover(app, f.size()); + app.thinking_hover_rendered = thinking_hovered; + let is_collapsed = app.collapse_thinking && !app.temp_expand_thinking; + + let last_msg_len = app.history.last().map(|m| { + m.content.as_ref().map(|s| s.len()).unwrap_or(0) + + m.thought.as_ref().map(|s| s.len()).unwrap_or(0) + + m.tool_calls.as_ref().map(|tc| tc.len()).unwrap_or(0) + }).unwrap_or(0); + + let cache_valid = app.cached_text.is_some() + && app.history.len() == app.cached_history_len + && last_msg_len == app.cached_last_msg_len + && chunks[0].width == app.cached_width + && is_collapsed == app.cached_is_collapsed + && thinking_hovered == app.cached_thinking_hovered; + + if !cache_valid { + let history = render_history(&app.history, is_collapsed, thinking_hovered); + + // 1. Auto-scroll logic + let mut total_height: usize = 0; + let available_width = chunks[0].width.max(1) as usize; + for line in &history.lines { + let line_width: usize = line.spans.iter().map(|s| s.content.width()).sum(); + let wrapped_height = if line_width == 0 { 1 } else { + // Use a slightly smaller width for calculation to account for word wrapping + let calc_width = (available_width as f32 * 0.95).floor() as usize; + (line_width + calc_width - 1) / calc_width.max(1) + }; + total_height += wrapped_height; + } + // Safety buffer + total_height += 2; + + app.cached_history_len = app.history.len(); + app.cached_last_msg_len = last_msg_len; + app.cached_width = chunks[0].width; + app.cached_is_collapsed = is_collapsed; + app.cached_thinking_hovered = thinking_hovered; + app.cached_total_height = total_height; + app.cached_text = Some(history); + } + + let history_text = app.cached_text.as_ref().unwrap().clone(); + let total_height = app.cached_total_height; + + let max_scroll = total_height.saturating_sub(chunks[0].height as usize).min(u16::MAX as usize) as u16; + app.max_scroll = max_scroll; + + if app.auto_scroll { + app.history_scroll = max_scroll; + } else { + // Only re-enable auto-scroll if the user manually scrolls to the bottom of long content + if app.history_scroll >= max_scroll && max_scroll > 0 { + app.auto_scroll = true; + app.history_scroll = max_scroll; + } + } + + f.render_widget(Paragraph::new(history_text).wrap(Wrap { trim: false }).scroll((app.history_scroll, 0)), chunks[0]); + + f.render_widget(Block::default().style(Style::default().bg(COLOR_INPUT_BG)), chunks[1]); + + let inner_input_area = Rect::new( + chunks[1].x + 1, + chunks[1].y + 1, + chunks[1].width.saturating_sub(2), + chunks[1].height.saturating_sub(2) + ); + app.input.set_block(Block::default().borders(Borders::NONE)); + f.render_widget(app.input.widget(), inner_input_area); + + f.set_cursor(inner_input_area.x + app.input.cursor().1 as u16, inner_input_area.y + app.input.cursor().0 as u16); + + let spinner = ["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"]; + let frame = spinner[(app.tick_count % spinner.len() as u64) as usize]; + + let generating_text = if app.is_generating { + if let Some(tool) = &app.active_tool { + format!(" {} [Running {}...] ", frame, tool) + } else { + format!(" {} [Thinking...] ", frame) + } + } else { + "".to_string() + }; + + let cleaned_model = clean_model_name(&app.current_model, &app.current_provider_id); + + let config_thinking = app.orchestrator.config.try_lock() + .map(|c| c.thinking_level.clone()) + .unwrap_or("default".to_string()); + + let thinking_tag = if config_thinking != "default" { + format!(" • [{}] ", config_thinking) + } else { + "".to_string() + }; + + let status_layout = Layout::default() + .direction(Direction::Horizontal) + .constraints([ + Constraint::Min(0), + Constraint::Length(app.provider_name.len() as u16 + 2), + ]) + .split(chunks[2]); + + let left_status = Line::from(vec![ + Span::styled(format!(" {} ", cleaned_model), Style::default().fg(COLOR_PRIMARY).add_modifier(Modifier::BOLD)), + Span::styled(thinking_tag, Style::default().fg(COLOR_SYSTEM).add_modifier(Modifier::BOLD)), + Span::styled(format!(" • Tokens: {} • Cost: ${:.4} ", app.usage.total_tokens, app.usage.total_cost), Style::default().fg(COLOR_SECONDARY)), + Span::styled(format!(" • Scroll: {}/{} ", app.history_scroll, app.max_scroll), Style::default().fg(COLOR_SECONDARY).add_modifier(Modifier::DIM)), + Span::styled(generating_text, Style::default().fg(COLOR_SYSTEM)), + Span::styled(" • ctrl+o toggle thinking • ctrl+p help ", Style::default().fg(COLOR_SECONDARY).add_modifier(Modifier::DIM)), + ]); + + let right_status = Paragraph::new(Span::styled( + format!(" {} ", app.provider_name), + Style::default().fg(COLOR_SECONDARY).add_modifier(Modifier::BOLD) + )).alignment(ratatui::layout::Alignment::Right); + + f.render_widget(Paragraph::new(left_status), status_layout[0]); + f.render_widget(right_status, status_layout[1]); + + chunks[1] +} + +pub fn render_history(history: &[Message], collapse_thinking: bool, thinking_hovered: bool) -> Text<'static> { + let mut lines = Vec::new(); + for m in history { + match m.role { + Role::User => { + lines.push(Line::from(vec![ + Span::styled(" ● User", Style::default().fg(COLOR_PRIMARY).add_modifier(Modifier::BOLD)), + ])); + if let Some(content) = &m.content { + for line in content.lines() { + lines.push(Line::from(vec![Span::raw(" "), Span::raw(line.to_string())])); + } + } + } + Role::Assistant => { + lines.push(Line::from(vec![ + Span::styled(" ● RouteCode", Style::default().fg(COLOR_TEXT).add_modifier(Modifier::BOLD)), + ])); + + if let Some(thought) = &m.thought { + if collapse_thinking { + if thinking_hovered { + lines.push(Line::from(vec![ + Span::styled(" ┃ ", Style::default().fg(COLOR_PRIMARY).add_modifier(Modifier::BOLD)), + Span::styled("▶ Thinking... (Double click to expand, one hold click to see thought)", Style::default().fg(COLOR_PRIMARY).add_modifier(Modifier::BOLD)), + ])); + } else { + lines.push(Line::from(vec![ + Span::styled(" │ ", Style::default().fg(COLOR_DIM)), + Span::styled("▶ Thinking... (ctrl+o to expand)", Style::default().fg(COLOR_SECONDARY).add_modifier(Modifier::ITALIC)), + ])); + } + } else { + if thinking_hovered { + lines.push(Line::from(vec![ + Span::styled(" ┃ ", Style::default().fg(COLOR_PRIMARY).add_modifier(Modifier::BOLD)), + Span::styled("▼ Thinking... (Double click to collapse)", Style::default().fg(COLOR_PRIMARY).add_modifier(Modifier::BOLD)), + ])); + } else { + lines.push(Line::from(vec![ + Span::styled(" │ ", Style::default().fg(COLOR_DIM)), + Span::styled("▼ Thinking... (ctrl+o to collapse)", Style::default().fg(COLOR_SECONDARY).add_modifier(Modifier::ITALIC)), + ])); + } + + let guide = if thinking_hovered { + Span::styled(" ┃ ", Style::default().fg(COLOR_PRIMARY).add_modifier(Modifier::BOLD)) + } else { + Span::styled(" │ ", Style::default().fg(COLOR_DIM)) + }; + + for line in thought.lines() { + let text = if thinking_hovered { + Span::styled(line.to_string(), Style::default().fg(COLOR_PRIMARY).add_modifier(Modifier::ITALIC)) + } else { + Span::styled(line.to_string(), Style::default().fg(COLOR_SECONDARY).add_modifier(Modifier::ITALIC)) + }; + lines.push(Line::from(vec![guide.clone(), text])); + } + } + } + + if let Some(tool_calls) = &m.tool_calls { + for tc in tool_calls { + let args: serde_json::Value = serde_json::from_str(&tc.function.arguments).unwrap_or(serde_json::json!({})); + let arg_preview = if let Some(path) = args["path"].as_str() { + format!("({})", path) + } else { + format!("({})", tc.function.name) + }; + + lines.push(Line::from(vec![ + Span::styled(" 🛠 ", Style::default().fg(COLOR_PRIMARY)), + Span::styled(format!("Using {} ", tc.function.name), Style::default().fg(COLOR_TEXT)), + Span::styled(arg_preview, Style::default().fg(COLOR_SECONDARY).add_modifier(Modifier::DIM)), + ])); + } + } + + if let Some(content) = &m.content { + let mut in_code_block = false; + for line in content.lines() { + if line.trim().starts_with("```") { + in_code_block = !in_code_block; + lines.push(Line::from(vec![Span::raw(" "), Span::styled(line.to_string(), Style::default().fg(COLOR_PRIMARY).add_modifier(Modifier::BOLD))])); + } else { + let mut line_spans = vec![Span::raw(" ")]; + line_spans.extend(parse_markdown_line(line, in_code_block)); + lines.push(Line::from(line_spans)); + } + } + } + } + Role::Tool => { + lines.push(Line::from(vec![ + Span::styled(format!(" ✓ Tool ({})", m.name.as_deref().unwrap_or("result")), Style::default().fg(COLOR_SECONDARY)), + ])); + if let Some(content) = &m.content { + if let Ok(res) = serde_json::from_str::(content) { + if let Some(diff) = res.diff { + for line in diff.lines() { + let style = if line.starts_with('+') { + Style::default().fg(COLOR_SUCCESS) + } else if line.starts_with('-') { + Style::default().fg(Color::Red) + } else { + Style::default().fg(COLOR_DIM) + }; + lines.push(Line::from(vec![Span::raw(" "), Span::styled(line.to_string(), style)])); + } + } else if let Some(out) = res.content { + let preview = if out.len() > 100 { + let end = out.char_indices().map(|(i, _)| i).take_while(|&i| i <= 100).last().unwrap_or(0); + format!("{}...", &out[..end]) + } else { out }; + lines.push(Line::from(vec![Span::styled(format!(" {}", preview), Style::default().fg(COLOR_DIM).add_modifier(Modifier::DIM))])); + } else if let Some(err) = res.error { + lines.push(Line::from(vec![Span::styled(format!(" Error: {}", err), Style::default().fg(Color::Red))])); + } + } else { + let preview = if content.len() > 100 { + let end = content.char_indices().map(|(i, _)| i).take_while(|&i| i <= 100).last().unwrap_or(0); + format!("{}...", &content[..end]) + } else { content.to_string() }; + lines.push(Line::from(vec![Span::styled(format!(" {}", preview), Style::default().fg(COLOR_DIM).add_modifier(Modifier::DIM))])); + } + } + } + Role::System => { + lines.push(Line::from(vec![ + Span::styled(" ● System", Style::default().fg(COLOR_SYSTEM).add_modifier(Modifier::DIM)), + ])); + if let Some(content) = &m.content { + lines.push(Line::from(vec![Span::styled(format!(" {}", content), Style::default().fg(COLOR_SYSTEM).add_modifier(Modifier::DIM))])); + } + } + } + lines.push(Line::from("")); + } + Text::from(lines) +} + +fn parse_markdown_line<'a>(line: &'a str, in_code_block: bool) -> Vec> { + if in_code_block { + return vec![Span::styled(line.to_string(), Style::default().fg(COLOR_SECONDARY))]; + } + + let trimmed = line.trim_start(); + let indent = &line[..line.len() - trimmed.len()]; + + if trimmed.starts_with("# ") || trimmed.starts_with("## ") || trimmed.starts_with("### ") || trimmed.starts_with("#### ") { + let mut spans = vec![Span::raw(indent.to_string())]; + spans.push(Span::styled(trimmed.to_string(), Style::default().fg(COLOR_PRIMARY).add_modifier(Modifier::BOLD))); + return spans; + } else if trimmed.starts_with("- ") || trimmed.starts_with("* ") { + let mut spans = vec![ + Span::raw(indent.to_string()), + Span::styled(trimmed[..2].to_string(), Style::default().fg(COLOR_PRIMARY).add_modifier(Modifier::BOLD)), + ]; + spans.extend(parse_inline_markdown(&trimmed[2..])); + return spans; + } + + let mut spans = vec![Span::raw(indent.to_string())]; + spans.extend(parse_inline_markdown(trimmed)); + spans +} + +fn parse_inline_markdown(text: &str) -> Vec> { + let mut spans = Vec::new(); + let mut current = String::new(); + let mut chars = text.chars().peekable(); + + let mut bold = false; + let mut italic = false; + let mut code = false; + + while let Some(c) = chars.next() { + if c == '`' { + if !current.is_empty() { + let mut style = Style::default(); + if code { + style = style.fg(COLOR_PRIMARY); + } else { + if bold { style = style.add_modifier(Modifier::BOLD); } + if italic { style = style.add_modifier(Modifier::ITALIC); } + } + spans.push(Span::styled(current.clone(), style)); + current.clear(); + } + code = !code; + continue; + } + + if !code && c == '*' { + if let Some(&next_c) = chars.peek() { + if next_c == '*' { + chars.next(); // Consume second '*' + if !current.is_empty() { + let mut style = Style::default(); + if bold { style = style.add_modifier(Modifier::BOLD); } + if italic { style = style.add_modifier(Modifier::ITALIC); } + spans.push(Span::styled(current.clone(), style)); + current.clear(); + } + bold = !bold; + continue; + } + } + if !current.is_empty() { + let mut style = Style::default(); + if bold { style = style.add_modifier(Modifier::BOLD); } + if italic { style = style.add_modifier(Modifier::ITALIC); } + spans.push(Span::styled(current.clone(), style)); + current.clear(); + } + italic = !italic; + continue; + } + + current.push(c); + } + + if !current.is_empty() { + let mut style = Style::default(); + if code { + style = style.fg(COLOR_PRIMARY); + } else { + if bold { style = style.add_modifier(Modifier::BOLD); } + if italic { style = style.add_modifier(Modifier::ITALIC); } + } + spans.push(Span::styled(current, style)); + } + + spans +} diff --git a/apps/cli/src/ui/welcome.rs b/apps/cli/src/ui/welcome.rs new file mode 100644 index 0000000..eb15161 --- /dev/null +++ b/apps/cli/src/ui/welcome.rs @@ -0,0 +1,151 @@ +use ratatui::layout::Rect; +use ratatui::style::{Modifier, Style}; +use ratatui::text::{Line, Span}; +use ratatui::widgets::{Block, Borders, Paragraph}; +use ratatui::Frame; +use crate::ui::App; +use crate::ui::components::{COLOR_INPUT_BG, COLOR_PRIMARY, COLOR_SECONDARY, COLOR_TEXT, clean_model_name}; + +pub fn ui_welcome(f: &mut Frame, app: &mut App, area: Rect) -> Rect { + let logo_height = if area.height < 20 { 0 } else { 6 }; + let spacer_height = if area.height < 15 { 0 } else { area.height / 3 }; + let input_lines = app.input.lines().len() as u16; + let input_height = (input_lines + 2).min(12); + + let chunks = ratatui::layout::Layout::default() + .direction(ratatui::layout::Direction::Vertical) + .constraints([ + ratatui::layout::Constraint::Length(spacer_height), + ratatui::layout::Constraint::Length(logo_height), + ratatui::layout::Constraint::Length(input_height), + ratatui::layout::Constraint::Length(1), + ratatui::layout::Constraint::Length(1), + ratatui::layout::Constraint::Min(0), + ratatui::layout::Constraint::Length(1) + ]) + .split(area); + + if logo_height > 0 { + let config_guard = app.orchestrator.config.try_lock(); + let (animation_mode, animation_color) = if let Ok(ref config) = config_guard { + (config.logo_animation.clone(), config.logo_animation_color.clone()) + } else { + ("always".to_string(), "rainbow".to_string()) + }; + + let colors = match animation_color.as_str() { + "neon" => vec![ + ratatui::style::Color::Rgb(0, 255, 127), + ratatui::style::Color::Rgb(0, 255, 255), + ratatui::style::Color::Rgb(57, 255, 20), + ratatui::style::Color::Rgb(0, 191, 255), + ], + "cyberpunk" => vec![ + ratatui::style::Color::Rgb(255, 0, 127), + ratatui::style::Color::Rgb(255, 0, 255), + ratatui::style::Color::Rgb(138, 43, 226), + ratatui::style::Color::Rgb(0, 255, 255), + ], + "sunset" => vec![ + ratatui::style::Color::Rgb(255, 69, 0), + ratatui::style::Color::Rgb(255, 140, 0), + ratatui::style::Color::Rgb(255, 215, 0), + ratatui::style::Color::Rgb(255, 105, 180), + ], + "mono" => vec![ + ratatui::style::Color::Rgb(240, 240, 240), + ratatui::style::Color::Rgb(180, 180, 180), + ratatui::style::Color::Rgb(120, 120, 120), + ratatui::style::Color::Rgb(80, 80, 80), + ], + _ => vec![ + ratatui::style::Color::Rgb(255, 50, 50), + ratatui::style::Color::Rgb(255, 150, 50), + ratatui::style::Color::Rgb(255, 255, 50), + ratatui::style::Color::Rgb(50, 255, 50), + ratatui::style::Color::Rgb(50, 150, 255), + ratatui::style::Color::Rgb(150, 50, 255), + ratatui::style::Color::Rgb(255, 50, 150), + ], + }; + + let small_logo = [ + " __ _ ", + " |__) _|_ _ _/ _ _| _ ", + " | \\(_|(_(- \\__(_)(_|(/_ " + ]; + + let large_logo = [ + " ____ _ ____ _ ", + " | _ \\ ___ _ _| |_ ___ / ___|___ __| | ___ ", + " | |_) / _ \\| | | | __/ _ \\ | / _ \\ / _` |/ _ \\", + " | _ < (_) | |_| | || __/ |__| (_) | (_| | __/", + " |_| \\_\\___/ \\__,_|\\__\\___|\\____\\___/ \\__,_|\\___|" + ]; + + let logo_lines = if area.width < 60 { &small_logo[..] } else { &large_logo[..] }; + let logo_width = logo_lines[0].len() as u16; + let start_x = area.x + (area.width.saturating_sub(logo_width)) / 2; + let end_x = start_x + logo_width; + let start_y = chunks[1].y; + let end_y = start_y + logo_height; + + let is_hovering = if let (Some(col), Some(row)) = (app.mouse_col, app.mouse_row) { + col >= start_x && col < end_x && row >= start_y && row < end_y + } else { + false + }; + + let is_animating = match animation_mode.as_str() { + "always" => true, + "hover" => is_hovering, + "click" => app.logo_anim_frames > 0, + _ => true, + }; + + let mut logo_text = Vec::new(); + + for (i, line) in logo_lines.iter().enumerate() { + let style = if is_animating { + let color_idx = (app.tick_count as usize + i * 2) % colors.len(); + Style::default().fg(colors[color_idx]).add_modifier(Modifier::BOLD) + } else { + Style::default().fg(COLOR_TEXT).add_modifier(Modifier::BOLD) + }; + logo_text.push(Line::from(Span::styled(*line, style))); + } + f.render_widget(Paragraph::new(logo_text).alignment(ratatui::layout::Alignment::Center), chunks[1]); + } + + let input_width_percent = if area.width < 50 { 0.95 } else if area.width < 100 { 0.8 } else { 0.6 }; + let input_width = (area.width as f32 * input_width_percent) as u16; + let input_area = Rect::new((area.width - input_width) / 2, chunks[2].y, input_width, input_height); + + f.render_widget(Block::default().style(Style::default().bg(COLOR_INPUT_BG)), input_area); + + let inner_input_area = Rect::new(input_area.x + 1, input_area.y + 1, input_area.width.saturating_sub(2), input_area.height.saturating_sub(2)); + app.input.set_block(Block::default().borders(Borders::NONE)); + f.render_widget(app.input.widget(), inner_input_area); + + f.set_cursor(inner_input_area.x + app.input.cursor().1 as u16, inner_input_area.y + app.input.cursor().0 as u16); + + let cleaned_model = clean_model_name(&app.current_model, &app.current_provider_id); + let provider_info = vec![ + Span::styled("Model ", Style::default().fg(COLOR_SECONDARY)), + Span::styled(cleaned_model, Style::default().fg(COLOR_PRIMARY).add_modifier(Modifier::BOLD)), + Span::styled(" • Provider ", Style::default().fg(COLOR_SECONDARY)), + Span::styled(&app.provider_name, Style::default().fg(COLOR_TEXT)), + ]; + f.render_widget(Paragraph::new(Line::from(provider_info)).alignment(ratatui::layout::Alignment::Center), chunks[4]); + + let spinner = ["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"]; + let frame = spinner[(app.tick_count % spinner.len() as u64) as usize]; + let tip_text = if app.is_generating { + format!(" {} AI is working... ", frame) + } else { + "ctrl+p help | esc exit".to_string() + }; + f.render_widget(Paragraph::new(tip_text).alignment(ratatui::layout::Alignment::Center).style(Style::default().fg(COLOR_SECONDARY).add_modifier(Modifier::DIM)), chunks[6]); + + input_area +} diff --git a/cli_build_debug.bat b/cli_build_debug.bat index 702eed5..f8261c5 100644 --- a/cli_build_debug.bat +++ b/cli_build_debug.bat @@ -8,5 +8,6 @@ if %ERRORLEVEL% EQU 0 ( ) else ( echo. echo Build failed! - exit /b %ERRORLEVEL% + pause + ) diff --git a/libs/sdk/Cargo.toml b/libs/sdk/Cargo.toml index f9143c4..de98884 100644 --- a/libs/sdk/Cargo.toml +++ b/libs/sdk/Cargo.toml @@ -7,7 +7,7 @@ description = "Core logic for RouteCode" [dependencies] tokio = { version = "1.37", features = ["full"] } -serde = { version = "1.0", features = ["derive"] } +serde = { version = "1.0", features = ["derive", "rc"] } serde_json = "1.0" reqwest = { version = "0.12", features = ["json", "stream"] } anyhow = "1.0" @@ -23,6 +23,9 @@ uuid = { version = "1.8", features = ["v4", "serde"] } log = "0.4" env_logger = "0.11" similar = "2.5" +once_cell = "1.19" +glob = "0.3" +regex = "1.10" [dev-dependencies] tempfile = "3.10" diff --git a/libs/sdk/src/agents/anthropic.rs b/libs/sdk/src/agents/anthropic.rs index c17fee6..dc6decc 100644 --- a/libs/sdk/src/agents/anthropic.rs +++ b/libs/sdk/src/agents/anthropic.rs @@ -1,6 +1,7 @@ use crate::agents::traits::{AIProvider, StreamResponse}; -use crate::agents::types::{StreamChunk, Usage}; -use crate::core::{Message, Role, ToolCall, FunctionCall}; +use crate::agents::types::StreamChunk; +use crate::agents::utils::parse_anthropic_sse; +use crate::core::{Message, Role, ToolCall}; use async_stream::stream; use async_trait::async_trait; use futures::StreamExt; @@ -17,7 +18,10 @@ impl AnthropicProvider { pub fn new(api_key: String) -> Self { Self { api_key, - client: Client::new(), + client: Client::builder() + .timeout(std::time::Duration::from_secs(60)) + .build() + .unwrap_or_else(|_| Client::new()), } } } @@ -32,6 +36,9 @@ impl AIProvider for AnthropicProvider { // Anthropic doesn't have a public models endpoint in the same way OpenAI does that's easily accessible without specific permissions // Returning a common set of models Ok(vec![ + "claude-sonnet-4-5".to_string(), + "claude-sonnet-4-20250514".to_string(), + "claude-opus-4-20250514".to_string(), "claude-3-5-sonnet-20240620".to_string(), "claude-3-opus-20240229".to_string(), "claude-3-sonnet-20240229".to_string(), @@ -44,6 +51,7 @@ impl AIProvider for AnthropicProvider { messages: Vec, model: &str, tools: Option>, + _thinking_level: Option<&str>, ) -> Result { let mut anthropic_messages = Vec::new(); let mut system_prompt = String::new(); @@ -55,17 +63,46 @@ impl AIProvider for AnthropicProvider { system_prompt.push_str(content); } } - _ => { - let role_str = match msg.role { - Role::User => "user", - Role::Assistant => "assistant", - _ => "user", - }; + Role::User => { anthropic_messages.push(json!({ - "role": role_str, + "role": "user", "content": msg.content.unwrap_or_default(), })); } + Role::Assistant => { + let mut content = Vec::new(); + if let Some(t) = &msg.thought { + content.push(json!({ "type": "thought", "thought": t })); + } + if let Some(c) = &msg.content { + content.push(json!({ "type": "text", "text": c })); + } + if let Some(calls) = &msg.tool_calls { + for tc in calls { + let input: Value = serde_json::from_str(&tc.function.arguments).unwrap_or(json!({})); + content.push(json!({ + "type": "tool_use", + "id": tc.id, + "name": tc.function.name, + "input": input, + })); + } + } + anthropic_messages.push(json!({ + "role": "assistant", + "content": content, + })); + } + Role::Tool => { + anthropic_messages.push(json!({ + "role": "user", + "content": [{ + "type": "tool_result", + "tool_use_id": msg.tool_call_id.unwrap_or_default(), + "content": msg.content.unwrap_or_default(), + }], + })); + } } } @@ -73,7 +110,7 @@ impl AIProvider for AnthropicProvider { "model": model, "messages": anthropic_messages, "stream": true, - "max_tokens": 4096, + "max_tokens": 16384, }); if !system_prompt.is_empty() { @@ -110,68 +147,15 @@ impl AIProvider for AnthropicProvider { let mut bytes_stream = response.bytes_stream(); let mut buffer = String::new(); - let mut active_tool_calls: HashMap = HashMap::new(); + let mut active_tool_calls: HashMap = HashMap::new(); let s = stream! { while let Some(item) = bytes_stream.next().await { match item { Ok(bytes) => { - buffer.push_str(&String::from_utf8_lossy(&bytes)); - while let Some(line_end) = buffer.find('\n') { - let line = buffer[..line_end].to_string(); - buffer.drain(..=line_end); - let line = line.trim(); - if line.is_empty() { continue; } - - if let Some(data) = line.strip_prefix("data: ") { - if let Ok(val) = serde_json::from_str::(data) { - let event_type = val["type"].as_str().unwrap_or(""); - match event_type { - "content_block_delta" => { - if let Some(delta) = val.get("delta") { - if let Some(text) = delta["text"].as_str() { - yield Ok(StreamChunk::Text { content: text.to_string() }); - } - if let Some(_partial_json) = delta["partial_json"].as_str() { - // Handle partial tool call JSON - // In Anthropic, we get tool_use blocks - } - } - } - "content_block_start" => { - if let Some(block) = val.get("content_block") { - if block["type"] == "tool_use" { - let id = block["id"].as_str().unwrap_or("").to_string(); - let name = block["name"].as_str().unwrap_or("").to_string(); - active_tool_calls.insert(id.clone(), ToolCall { - id, - r#type: "function".to_string(), - index: None, - function: FunctionCall { - name, - arguments: String::new(), - }, - }); - } - } - } - "message_delta" => { - if let Some(usage) = val.get("usage") { - let prompt = usage["input_tokens"].as_u64().unwrap_or(0) as u32; - let completion = usage["output_tokens"].as_u64().unwrap_or(0) as u32; - yield Ok(StreamChunk::Usage { - usage: Usage { - prompt_tokens: prompt, - completion_tokens: completion, - total_tokens: prompt + completion, - } - }); - } - } - _ => {} - } - } - } + let chunks = parse_anthropic_sse(&mut buffer, &mut active_tool_calls, &String::from_utf8_lossy(&bytes)); + for chunk in chunks { + yield Ok(chunk); } } Err(e) => yield Err(anyhow::Error::from(e)), diff --git a/libs/sdk/src/agents/cloudflare.rs b/libs/sdk/src/agents/cloudflare.rs index 757d9c0..33c9375 100644 --- a/libs/sdk/src/agents/cloudflare.rs +++ b/libs/sdk/src/agents/cloudflare.rs @@ -20,7 +20,10 @@ impl CloudflareWorkersAI { Self { account_id, api_token, - client: Client::new(), + client: Client::builder() + .timeout(std::time::Duration::from_secs(60)) + .build() + .unwrap_or_else(|_| Client::new()), } } } @@ -94,6 +97,7 @@ impl AIProvider for CloudflareWorkersAI { messages: Vec, model: &str, _tools: Option>, + _thinking_level: Option<&str>, ) -> Result { // Workers AI has an OpenAI-compatible endpoint now, which is easier to use. let url = format!( @@ -101,12 +105,18 @@ impl AIProvider for CloudflareWorkersAI { self.account_id ); - let body = json!({ + let mut body = json!({ "model": model, "messages": messages, "stream": true, }); + if let Some(t) = _tools { + if !t.is_empty() { + body["tools"] = json!(t); + } + } + let response = self .client .post(&url) @@ -129,9 +139,7 @@ impl AIProvider for CloudflareWorkersAI { match item { Ok(bytes) => { let chunks = parse_sse_buffer(&mut buffer, &mut active_tool_calls, &String::from_utf8_lossy(&bytes)); - for chunk in chunks { - yield Ok(chunk); - } + for chunk in chunks { yield Ok(chunk); } } Err(e) => yield Err(anyhow::Error::from(e)), } @@ -156,7 +164,10 @@ impl CloudflareAIGateway { account_id, gateway_id, api_token, - client: Client::new(), + client: Client::builder() + .timeout(std::time::Duration::from_secs(60)) + .build() + .unwrap_or_else(|_| Client::new()), } } } @@ -176,6 +187,7 @@ impl AIProvider for CloudflareAIGateway { messages: Vec, model: &str, tools: Option>, + _thinking_level: Option<&str>, ) -> Result { // AI Gateway works as a proxy. // If the model is in "provider/model" format, we use the /compat endpoint. @@ -216,9 +228,7 @@ impl AIProvider for CloudflareAIGateway { match item { Ok(bytes) => { let chunks = parse_sse_buffer(&mut buffer, &mut active_tool_calls, &String::from_utf8_lossy(&bytes)); - for chunk in chunks { - yield Ok(chunk); - } + for chunk in chunks { yield Ok(chunk); } } Err(e) => yield Err(anyhow::Error::from(e)), } diff --git a/libs/sdk/src/agents/gemini.rs b/libs/sdk/src/agents/gemini.rs index a9a0fa6..1bef924 100644 --- a/libs/sdk/src/agents/gemini.rs +++ b/libs/sdk/src/agents/gemini.rs @@ -1,11 +1,12 @@ use crate::agents::traits::{AIProvider, StreamResponse}; use crate::agents::types::StreamChunk; -use crate::core::{Message, Role}; +use crate::core::{FunctionCall, Message, Role, ToolCall}; use async_stream::stream; use async_trait::async_trait; use futures::StreamExt; use reqwest::Client; use serde_json::{json, Value}; +use uuid::Uuid; pub struct GeminiProvider { api_key: String, @@ -16,7 +17,10 @@ impl GeminiProvider { pub fn new(api_key: String) -> Self { Self { api_key, - client: Client::new(), + client: Client::builder() + .timeout(std::time::Duration::from_secs(60)) + .build() + .unwrap_or_else(|_| Client::new()), } } } @@ -28,11 +32,34 @@ impl AIProvider for GeminiProvider { } async fn list_models(&self) -> Result, anyhow::Error> { - Ok(vec![ - "gemini-1.5-pro".to_string(), - "gemini-1.5-flash".to_string(), - "gemini-1.0-pro".to_string(), - ]) + let url = format!( + "https://generativelanguage.googleapis.com/v1beta/models?key={}", + self.api_key + ); + + let response = self.client.get(&url).send().await?; + + if !response.status().is_success() { + return Ok(vec![ + "gemini-1.5-pro".to_string(), + "gemini-1.5-flash".to_string(), + "gemini-2.0-flash-exp".to_string(), + ]); + } + + let val: Value = response.json().await?; + let mut models = Vec::new(); + + if let Some(models_arr) = val["models"].as_array() { + for m in models_arr { + if let Some(name) = m["name"].as_str() { + let id = name.strip_prefix("models/").unwrap_or(name); + models.push(id.to_string()); + } + } + } + + Ok(models) } async fn ask( @@ -40,30 +67,105 @@ impl AIProvider for GeminiProvider { messages: Vec, model: &str, _tools: Option>, + _thinking_level: Option<&str>, ) -> Result { - let mut contents = Vec::new(); - for msg in messages { - let role = match msg.role { - Role::User => "user", - Role::Assistant => "model", - _ => "user", // Default to user for others - }; - contents.push(json!({ - "role": role, - "parts": [{"text": msg.content.unwrap_or_default()}] - })); - } - let url = format!( "https://generativelanguage.googleapis.com/v1beta/models/{}:streamGenerateContent?key={}", model, self.api_key ); - let body = json!({ + let mut contents = Vec::new(); + let mut system_instruction = String::new(); + + for msg in messages { + match msg.role { + Role::System => { + if let Some(c) = &msg.content { + system_instruction.push_str(c); + } + } + Role::User => { + contents.push(json!({ + "role": "user", + "parts": [{ "text": msg.content.unwrap_or_default() }] + })); + } + Role::Assistant => { + let mut parts = Vec::new(); + if let Some(c) = &msg.content { + parts.push(json!({ "text": c })); + } + if let Some(calls) = &msg.tool_calls { + for tc in calls { + let args: Value = serde_json::from_str(&tc.function.arguments).unwrap_or(json!({})); + parts.push(json!({ + "functionCall": { + "name": tc.function.name, + "args": args + } + })); + } + } + if !parts.is_empty() { + contents.push(json!({ + "role": "model", + "parts": parts + })); + } + } + Role::Tool => { + let fn_name = msg.name.clone().unwrap_or_else(|| "tool".to_string()); + contents.push(json!({ + "role": "function", + "parts": [{ + "functionResponse": { + "name": fn_name, + "response": { "result": msg.content.unwrap_or_default() } + } + }] + })); + } + } + } + + let mut body = json!({ "contents": contents, }); - let response = self.client.post(&url).json(&body).send().await?; + if !system_instruction.is_empty() { + body["systemInstruction"] = json!({ + "parts": [{ "text": system_instruction }] + }); + } + + if let Some(level) = _thinking_level { + if level != "default" { + body["thinking_level"] = json!(level); + } + } + + if let Some(t) = _tools { + let mut gemini_tools = Vec::new(); + for tool in t { + if let Some(f) = tool.get("function") { + gemini_tools.push(json!({ + "name": f["name"], + "description": f["description"], + "parameters": f["parameters"], + })); + } + } + if !gemini_tools.is_empty() { + body["tools"] = json!([{ "function_declarations": gemini_tools }]); + } + } + + let response = self + .client + .post(&url) + .json(&body) + .send() + .await?; if !response.status().is_success() { let err_text = response.text().await?; @@ -78,16 +180,53 @@ impl AIProvider for GeminiProvider { match item { Ok(bytes) => { buffer.push_str(&String::from_utf8_lossy(&bytes)); - // Gemini returns a JSON array of objects over the stream, but not standard SSE - // It's actually a bit tricky to parse manually without a proper stream decoder if it's large - // For now, let's assume it's small enough or comes in chunks of JSON objects - if let Ok(val) = serde_json::from_str::(&buffer) { - if let Some(candidates) = val[0]["candidates"].as_array() { - if let Some(text) = candidates[0]["content"]["parts"][0]["text"].as_str() { - yield Ok(StreamChunk::Text { content: text.to_string() }); + loop { + let trimmed = buffer.trim_start_matches(|c: char| c == '[' || c == ']' || c == ',' || c.is_whitespace()); + if trimmed.len() < buffer.len() { + let cut = buffer.len() - trimmed.len(); + buffer.drain(..cut); + } + + if buffer.is_empty() { + break; + } + + let mut stream = serde_json::Deserializer::from_str(&buffer).into_iter::(); + match stream.next() { + Some(Ok(val)) => { + let offset = stream.byte_offset(); + if let Some(candidates) = val["candidates"].as_array() { + if let Some(content) = candidates[0].get("content") { + if let Some(parts) = content["parts"].as_array() { + for part in parts { + if let Some(text) = part["text"].as_str() { + yield Ok(StreamChunk::Text { content: text.to_string() }); + } + if let Some(fn_call) = part.get("functionCall") { + let name = fn_call["name"].as_str().unwrap_or("").to_string(); + let args = fn_call["args"].clone(); + let arguments = serde_json::to_string(&args).unwrap_or_default(); + let tool_call = ToolCall { + id: format!("call_{}", Uuid::new_v4().simple()), + r#type: "function".to_string(), + index: Some(0), + function: FunctionCall { + name, + arguments, + }, + }; + yield Ok(StreamChunk::ToolCall { tool_call }); + } + } + } + } + } + buffer.drain(..offset); + } + _ => { + break; } - } - buffer.clear(); + } } } Err(e) => yield Err(anyhow::Error::from(e)), diff --git a/libs/sdk/src/agents/mod.rs b/libs/sdk/src/agents/mod.rs index f5bef13..663ea0c 100644 --- a/libs/sdk/src/agents/mod.rs +++ b/libs/sdk/src/agents/mod.rs @@ -55,7 +55,7 @@ pub fn resolve_provider(provider_name: &str, api_key: String) -> std::sync::Arc< )), "opencode-zen" | "opencode_zen" => std::sync::Arc::new(OpenCodeProvider::new( api_key, - "https://opencode.ai/zen/go/v1".to_string(), + "https://opencode.ai/zen/v1".to_string(), "OpenCode Zen".to_string(), true, )), diff --git a/libs/sdk/src/agents/openai.rs b/libs/sdk/src/agents/openai.rs index 0002b3d..af8421d 100644 --- a/libs/sdk/src/agents/openai.rs +++ b/libs/sdk/src/agents/openai.rs @@ -1,4 +1,5 @@ use crate::agents::traits::{AIProvider, StreamResponse}; +use crate::agents::types::StreamChunk; use crate::agents::utils::parse_sse_buffer; use crate::core::{Message, ToolCall}; use async_stream::stream; @@ -21,7 +22,10 @@ impl OpenAIProvider { api_key, base_url, provider_name, - client: Client::new(), + client: Client::builder() + .timeout(std::time::Duration::from_secs(60)) + .build() + .unwrap_or_else(|_| Client::new()), } } } @@ -69,16 +73,24 @@ impl AIProvider for OpenAIProvider { messages: Vec, model: &str, tools: Option>, + thinking_level: Option<&str>, ) -> Result { let mut body = json!({ "model": model, "messages": messages, "stream": true, + "max_tokens": 16384, }); if let Some(t) = tools { body["tools"] = json!(t); } + + if let Some(level) = thinking_level { + if level != "default" { + body["thinking_level"] = json!(level); + } + } let url = if self.base_url.ends_with('/') { format!("{}chat/completions", self.base_url) @@ -115,14 +127,14 @@ impl AIProvider for OpenAIProvider { for chunk in chunks { yield Ok(chunk); } - } - Err(e) => { + } + Err(e) => { yield Err(anyhow::Error::from(e)); - } - } - } - }; - + } + } + } + yield Ok(StreamChunk::Done); + }; Ok(Box::pin(s)) } } diff --git a/libs/sdk/src/agents/opencode.rs b/libs/sdk/src/agents/opencode.rs index a5e34dd..405966c 100644 --- a/libs/sdk/src/agents/opencode.rs +++ b/libs/sdk/src/agents/opencode.rs @@ -1,5 +1,5 @@ use crate::agents::traits::{AIProvider, StreamResponse}; -use crate::agents::utils::parse_sse_buffer; +use crate::agents::utils::{parse_sse_buffer, parse_anthropic_sse}; use crate::core::{Message, ToolCall, Role}; use async_stream::stream; use async_trait::async_trait; @@ -23,17 +23,18 @@ impl OpenCodeProvider { base_url, is_zen, provider_name, - client: Client::new(), + client: Client::builder() + .timeout(std::time::Duration::from_secs(60)) + .build() + .unwrap_or_else(|_| Client::new()), } } fn get_prefixed_model(&self, model: &str) -> String { - let prefix = if self.is_zen { "opencode-zen/" } else { "opencode-go/" }; - if model.starts_with(prefix) { - model.to_string() - } else { - format!("{}{}", prefix, model) - } + // Based on documentation and error reports, the OpenCode API + // expects the raw model ID (e.g. "gpt-5.5") because the + // provider (Zen/Go) is already determined by the base URL. + model.to_string() } } @@ -71,21 +72,64 @@ impl AIProvider for OpenCodeProvider { } } - // Fallback + // Fallback based on documentation screenshots if self.is_zen { Ok(vec![ + "gpt-5.5".to_string(), + "gpt-5.5-pro".to_string(), + "gpt-5.4".to_string(), + "gpt-5.4-pro".to_string(), + "gpt-5.4-mini".to_string(), + "gpt-5.4-nano".to_string(), + "gpt-5.3-codex".to_string(), + "gpt-5.3-codex-spark".to_string(), + "gpt-5.2".to_string(), + "gpt-5.2-codex".to_string(), + "gpt-5.1".to_string(), "gpt-5.1-codex".to_string(), + "gpt-5.1-codex-max".to_string(), + "gpt-5.1-codex-mini".to_string(), + "gpt-5".to_string(), + "gpt-5-codex".to_string(), + "gpt-5-nano".to_string(), "claude-opus-4-7".to_string(), + "claude-opus-4-6".to_string(), + "claude-opus-4-5".to_string(), + "claude-opus-4-1".to_string(), + "claude-sonnet-4-6".to_string(), + "claude-sonnet-4-5".to_string(), + "claude-sonnet-4".to_string(), + "claude-haiku-4-5".to_string(), + "claude-3-5-haiku".to_string(), "gemini-3.1-pro".to_string(), + "gemini-3-flash".to_string(), + "qwen3.6-plus".to_string(), + "qwen3.6-plus-free".to_string(), + "qwen3.5-plus".to_string(), + "minimax-m2.7".to_string(), + "minimax-m2.5".to_string(), + "minimax-m2.5-free".to_string(), + "glm-5.1".to_string(), + "glm-5".to_string(), + "kimi-k2.6".to_string(), + "kimi-k2.5".to_string(), "big-pickle".to_string(), "deepseek-v4-flash-free".to_string(), ]) } else { Ok(vec![ + "glm-5".to_string(), "glm-5.1".to_string(), + "kimi-k2.5".to_string(), "kimi-k2.6".to_string(), + "mimo-v2.5".to_string(), + "mimo-v2.5-pro".to_string(), + "minimax-m2.5".to_string(), "minimax-m2.7".to_string(), + "qwen3.5-plus".to_string(), "qwen3.6-plus".to_string(), + "deepseek-v4-pro".to_string(), + "deepseek-v4-flash".to_string(), ]) } } @@ -95,6 +139,7 @@ impl AIProvider for OpenCodeProvider { messages: Vec, model: &str, _tools: Option>, + thinking_level: Option<&str>, ) -> Result { let prefixed_model = self.get_prefixed_model(model); let model_lower = model.to_lowercase(); @@ -105,7 +150,7 @@ impl AIProvider for OpenCodeProvider { } else if model_lower.starts_with("gpt") { format!("{}/responses", self.base_url) } else if model_lower.starts_with("gemini") { - // Google style endpoint + // Google style endpoint - append streaming suffix format!("{}/models/{}:streamGenerateContent", self.base_url, prefixed_model) } else if !self.is_zen && model_lower.contains("minimax") { // MiniMax in Go uses /messages @@ -119,43 +164,83 @@ impl AIProvider for OpenCodeProvider { if endpoint.ends_with("/messages") { // Anthropic Format let mut anthropic_messages = Vec::new(); - let mut system_prompt = String::new(); + let mut global_system = String::new(); for msg in messages { match msg.role { - Role::System => if let Some(content) = &msg.content { system_prompt.push_str(content); } - _ => { - let role_str = match msg.role { - Role::User => "user", - Role::Assistant => "assistant", - _ => "user", - }; - anthropic_messages.push(json!({ "role": role_str, "content": msg.content.unwrap_or_default() })); + Role::System => { + if let Some(c) = &msg.content { global_system.push_str(c); } + } + Role::User => { + anthropic_messages.push(json!({ "role": "user", "content": msg.content.unwrap_or_default() })); + } + Role::Assistant => { + let mut content = Vec::new(); + if let Some(t) = &msg.thought { + content.push(json!({ "type": "thought", "thought": t })); + } + if let Some(c) = &msg.content { + content.push(json!({ "type": "text", "text": c })); + } + if let Some(calls) = &msg.tool_calls { + for tc in calls { + let input: Value = serde_json::from_str(&tc.function.arguments).unwrap_or(json!({})); + content.push(json!({ + "type": "tool_use", + "id": tc.id, + "name": tc.function.name, + "input": input, + })); + } + } + anthropic_messages.push(json!({ "role": "assistant", "content": content })); + } + Role::Tool => { + // In Anthropic format, tool results are sent as 'user' role with 'tool_result' content + anthropic_messages.push(json!({ + "role": "user", + "content": [{ + "type": "tool_result", + "tool_use_id": msg.tool_call_id.unwrap_or_default(), + "content": msg.content.unwrap_or_default(), + }] + })); + } + } + } + let mut body = json!({ "model": prefixed_model, "messages": anthropic_messages, "stream": true, "max_tokens": 16384 }); + if !global_system.is_empty() { body["system"] = json!(global_system); } + + if let Some(level) = thinking_level { + if level != "default" { body["thinking_level"] = json!(level); } + } + + if let Some(t) = _tools { + let mut anthropic_tools = Vec::new(); + for tool in t { + if let Some(f) = tool.get("function") { + anthropic_tools.push(json!({ + "name": f["name"], + "description": f["description"], + "input_schema": f["parameters"], + })); } } + body["tools"] = json!(anthropic_tools); } - let mut body = json!({ "model": prefixed_model, "messages": anthropic_messages, "stream": true, "max_tokens": 4096 }); - if !system_prompt.is_empty() { body["system"] = json!(system_prompt); } let response = self.client.post(&endpoint).header("Authorization", format!("Bearer {}", self.api_key)).json(&body).send().await?; if !response.status().is_success() { return Err(anyhow::anyhow!("OpenCode error: {}", response.text().await?)); } let mut bytes_stream = response.bytes_stream(); let mut buffer = String::new(); + let mut active_tool_calls: HashMap = HashMap::new(); let s = stream! { while let Some(item) = bytes_stream.next().await { match item { Ok(bytes) => { - buffer.push_str(&String::from_utf8_lossy(&bytes)); - while let Some(line_end) = buffer.find('\n') { - let line = buffer[..line_end].to_string(); - buffer.drain(..=line_end); - if let Some(data) = line.trim().strip_prefix("data: ") { - if let Ok(val) = serde_json::from_str::(data) { - if val["type"] == "content_block_delta" { - if let Some(text) = val["delta"]["text"].as_str() { yield Ok(crate::agents::types::StreamChunk::Text { content: text.to_string() }); } - } - } - } + let chunks = parse_anthropic_sse(&mut buffer, &mut active_tool_calls, &String::from_utf8_lossy(&bytes)); + for chunk in chunks { + yield Ok(chunk); } } Err(e) => yield Err(anyhow::Error::from(e)), @@ -171,7 +256,26 @@ impl AIProvider for OpenCodeProvider { let role = match msg.role { Role::User => "user", Role::Assistant => "model", _ => "user" }; contents.push(json!({ "role": role, "parts": [{"text": msg.content.unwrap_or_default()}] })); } - let body = json!({ "contents": contents }); + let mut body = json!({ "contents": contents }); + + if let Some(level) = thinking_level { + if level != "default" { body["thinking_level"] = json!(level); } + } + + if let Some(t) = _tools { + let mut gemini_tools = Vec::new(); + for tool in t { + if let Some(f) = tool.get("function") { + gemini_tools.push(json!({ + "name": f["name"], + "description": f["description"], + "parameters": f["parameters"], + })); + } + } + body["tools"] = json!([{ "function_declarations": gemini_tools }]); + } + let response = self.client.post(&endpoint).header("Authorization", format!("Bearer {}", self.api_key)).json(&body).send().await?; if !response.status().is_success() { return Err(anyhow::anyhow!("OpenCode error: {}", response.text().await?)); } @@ -182,9 +286,13 @@ impl AIProvider for OpenCodeProvider { match item { Ok(bytes) => { buffer.push_str(&String::from_utf8_lossy(&bytes)); + // Google stream is often a JSON array or multiple objects if let Ok(val) = serde_json::from_str::(&buffer) { if let Some(candidates) = val[0]["candidates"].as_array() { if let Some(text) = candidates[0]["content"]["parts"][0]["text"].as_str() { yield Ok(crate::agents::types::StreamChunk::Text { content: text.to_string() }); } + } else if let Some(candidates) = val["candidates"].as_array() { + // Single object format + if let Some(text) = candidates[0]["content"]["parts"][0]["text"].as_str() { yield Ok(crate::agents::types::StreamChunk::Text { content: text.to_string() }); } } buffer.clear(); } @@ -197,7 +305,13 @@ impl AIProvider for OpenCodeProvider { Ok(Box::pin(s)) } else { // OpenAI Format (Default + GPT /responses) - let body = json!({ "model": prefixed_model, "messages": messages, "stream": true }); + let mut body = json!({ "model": prefixed_model, "messages": messages, "stream": true, "max_tokens": 16384 }); + if let Some(t) = _tools { body["tools"] = json!(t); } + + if let Some(level) = thinking_level { + if level != "default" { body["thinking_level"] = json!(level); } + } + let response = self.client.post(&endpoint).header("Authorization", format!("Bearer {}", self.api_key)).json(&body).send().await?; if !response.status().is_success() { return Err(anyhow::anyhow!("OpenCode error: {}", response.text().await?)); } diff --git a/libs/sdk/src/agents/openrouter.rs b/libs/sdk/src/agents/openrouter.rs index 47c287d..adcf571 100644 --- a/libs/sdk/src/agents/openrouter.rs +++ b/libs/sdk/src/agents/openrouter.rs @@ -1,4 +1,5 @@ use crate::agents::traits::{AIProvider, StreamResponse}; +use crate::agents::types::StreamChunk; use crate::agents::utils::parse_sse_buffer; use crate::core::{Message, ToolCall}; use async_stream::stream; @@ -17,7 +18,10 @@ impl OpenRouter { pub fn new(api_key: String) -> Self { Self { api_key, - client: Client::new(), + client: Client::builder() + .timeout(std::time::Duration::from_secs(60)) + .build() + .unwrap_or_else(|_| Client::new()), } } } @@ -59,6 +63,7 @@ impl AIProvider for OpenRouter { messages: Vec, model: &str, tools: Option>, + thinking_level: Option<&str>, ) -> Result { let mut body = json!({ "model": model, @@ -70,6 +75,14 @@ impl AIProvider for OpenRouter { body["tools"] = json!(t); } + if let Some(level) = thinking_level { + if level != "default" { + // OpenRouter often maps this to specific parameters or suffixes + // For now, we'll try passing it as a provider-specific field + body["provider"] = json!({ "thinking_level": level }); + } + } + let response = self .client .post("https://openrouter.ai/api/v1/chat/completions") @@ -97,62 +110,14 @@ impl AIProvider for OpenRouter { for chunk in chunks { yield Ok(chunk); } - } - Err(e) => { + } + Err(e) => { yield Err(anyhow::Error::from(e)); - } - } - } - }; - + } + } + } + yield Ok(StreamChunk::Done); + }; Ok(Box::pin(s)) } } - -#[cfg(test)] -mod tests { - use super::*; - use crate::agents::types::StreamChunk; - - #[test] - fn test_parse_sse_buffer_text() { - let mut buffer = String::new(); - let mut active_tool_calls = HashMap::new(); - - // Partial data - let data1 = "data: {\"choices\": [{\"delta\": {\"content\": \"Hello\"}}]}\n"; - let chunks = parse_sse_buffer(&mut buffer, &mut active_tool_calls, data1); - assert_eq!(chunks.len(), 1); - if let StreamChunk::Text { content } = &chunks[0] { - assert_eq!(content, "Hello"); - } - - // Split data across chunks - let data2 = "data: {\"choices\": [{\"delta\": {\"content\": \" world\"}}]}"; // No newline - let chunks = parse_sse_buffer(&mut buffer, &mut active_tool_calls, data2); - assert_eq!(chunks.len(), 0); // Should be buffered - - let chunks = parse_sse_buffer(&mut buffer, &mut active_tool_calls, "\n"); - assert_eq!(chunks.len(), 1); - if let StreamChunk::Text { content } = &chunks[0] { - assert_eq!(content, " world"); - } - } - - #[test] - fn test_parse_sse_buffer_tool_calls() { - let mut buffer = String::new(); - let mut active_tool_calls = HashMap::new(); - - let data = "data: {\"choices\": [{\"delta\": {\"tool_calls\": [{\"index\": 0, \"id\": \"call_1\", \"function\": {\"name\": \"ls\"}}]}}]}\ndata: {\"choices\": [{\"delta\": {\"tool_calls\": [{\"index\": 0, \"function\": {\"arguments\": \"{\\\"path\\\": \\\".\\\"}\"}}]}}]}\n"; - - let chunks = parse_sse_buffer(&mut buffer, &mut active_tool_calls, data); - assert_eq!(chunks.len(), 2); - - if let StreamChunk::ToolCall { tool_call } = &chunks[1] { - assert_eq!(tool_call.id, "call_1"); - assert_eq!(tool_call.function.name, "ls"); - assert_eq!(tool_call.function.arguments, "{\"path\": \".\"}"); - } - } -} diff --git a/libs/sdk/src/agents/traits.rs b/libs/sdk/src/agents/traits.rs index 4f52ed5..444db2b 100644 --- a/libs/sdk/src/agents/traits.rs +++ b/libs/sdk/src/agents/traits.rs @@ -15,5 +15,6 @@ pub trait AIProvider: Send + Sync { messages: Vec, model: &str, tools: Option>, + thinking_level: Option<&str>, ) -> Result; } diff --git a/libs/sdk/src/agents/types.rs b/libs/sdk/src/agents/types.rs index e6e6c83..5e8065f 100644 --- a/libs/sdk/src/agents/types.rs +++ b/libs/sdk/src/agents/types.rs @@ -1,5 +1,15 @@ -use crate::core::ToolCall; +use crate::core::{ToolCall, Message}; use serde::{Deserialize, Serialize}; +use std::sync::Arc; + +#[derive(Debug, Clone)] +pub enum ConfirmationResponse { + AllowOnce, + AllowSession, + AllowWorkspace, + Deny, + Feedback(String), +} #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(tag = "type", rename_all = "snake_case")] @@ -24,6 +34,19 @@ pub enum StreamChunk { Error { content: String, }, + FinalHistory { + history: Vec, + }, + Models { + models: Vec, + }, + ModelsDone, + RequestConfirmation { + message: String, + target: String, + #[serde(skip)] + tx: Option>>>>, + }, Done, } diff --git a/libs/sdk/src/agents/utils.rs b/libs/sdk/src/agents/utils.rs index 8b057a1..5ae78c2 100644 --- a/libs/sdk/src/agents/utils.rs +++ b/libs/sdk/src/agents/utils.rs @@ -26,13 +26,14 @@ pub fn parse_sse_buffer( if let Ok(val) = serde_json::from_str::(data) { if let Some(choice) = val["choices"].get(0) { if let Some(delta) = choice.get("delta") { - if let Some(content) = delta["content"].as_str() { + if let Some(content) = delta.get("content").and_then(|v| v.as_str()) { chunks.push(StreamChunk::Text { content: content.to_string(), }); } if let Some(thought) = delta.get("reasoning_content").and_then(|v| v.as_str()) + .or_else(|| delta.get("thought").and_then(|v| v.as_str())) { chunks.push(StreamChunk::Thought { content: thought.to_string(), @@ -41,33 +42,35 @@ pub fn parse_sse_buffer( if let Some(tool_calls) = delta.get("tool_calls").and_then(|v| v.as_array()) { for tc_delta in tool_calls { - let index = tc_delta["index"].as_u64().unwrap_or(0) as usize; - let entry = - active_tool_calls.entry(index).or_insert_with(|| ToolCall { - index: Some(index), - id: String::new(), - r#type: "function".to_string(), - function: FunctionCall { - name: String::new(), - arguments: String::new(), - }, - }); + if let Some(idx_val) = tc_delta.get("index") { + let index = idx_val.as_u64().unwrap_or(0) as usize; + let entry = + active_tool_calls.entry(index).or_insert_with(|| ToolCall { + index: Some(index), + id: String::new(), + r#type: "function".to_string(), + function: FunctionCall { + name: String::new(), + arguments: String::new(), + }, + }); - if let Some(id) = tc_delta["id"].as_str() { - entry.id.push_str(id); - } - if let Some(f) = tc_delta.get("function") { - if let Some(name) = f["name"].as_str() { - entry.function.name.push_str(name); + if let Some(id) = tc_delta.get("id").and_then(|v| v.as_str()) { + entry.id = id.to_string(); } - if let Some(args) = f["arguments"].as_str() { - entry.function.arguments.push_str(args); + if let Some(f) = tc_delta.get("function") { + if let Some(name) = f.get("name").and_then(|v| v.as_str()) { + entry.function.name = name.to_string(); + } + if let Some(args) = f.get("arguments").and_then(|v| v.as_str()) { + entry.function.arguments.push_str(args); + } } - } - chunks.push(StreamChunk::ToolCall { - tool_call: entry.clone(), - }); + chunks.push(StreamChunk::ToolCall { + tool_call: entry.clone(), + }); + } } } } @@ -82,3 +85,93 @@ pub fn parse_sse_buffer( } chunks } + +pub fn parse_anthropic_sse( + buffer: &mut String, + active_tool_calls: &mut HashMap, + new_data: &str, +) -> Vec { + buffer.push_str(new_data); + let mut chunks = Vec::new(); + + while let Some(line_end) = buffer.find('\n') { + let line = buffer[..line_end].to_string(); + buffer.drain(..=line_end); + let line = line.trim(); + if line.is_empty() { + continue; + } + + if let Some(data) = line.strip_prefix("data: ") { + if let Ok(val) = serde_json::from_str::(data) { + let event_type = val["type"].as_str().unwrap_or(""); + match event_type { + "content_block_delta" => { + if let Some(delta) = val.get("delta") { + if let Some(text) = delta["text"].as_str() { + chunks.push(StreamChunk::Text { + content: text.to_string(), + }); + } + if let Some(thought) = delta["thought"].as_str() { + chunks.push(StreamChunk::Thought { + content: thought.to_string(), + }); + } + if let Some(partial_json) = delta["partial_json"].as_str() { + let index = val["index"].as_u64().unwrap_or(0) as usize; + if let Some(tool_call) = active_tool_calls.get_mut(&index) { + tool_call.function.arguments.push_str(partial_json); + chunks.push(StreamChunk::ToolCall { + tool_call: tool_call.clone(), + }); + } + } + } + } + "content_block_start" => { + let index = val["index"].as_u64().unwrap_or(0) as usize; + if let Some(block) = val.get("content_block") { + if block["type"] == "tool_use" { + let id = block["id"].as_str().unwrap_or("").to_string(); + let name = block["name"].as_str().unwrap_or("").to_string(); + let tool_call = ToolCall { + id, + r#type: "function".to_string(), + index: Some(index), + function: FunctionCall { + name, + arguments: String::new(), + }, + }; + active_tool_calls.insert(index, tool_call.clone()); + chunks.push(StreamChunk::ToolCall { tool_call }); + } + } + } + "content_block_stop" => { + let index = val["index"].as_u64().unwrap_or(0) as usize; + if let Some(tool_call) = active_tool_calls.remove(&index) { + chunks.push(StreamChunk::ToolCall { tool_call }); + } + } + "message_delta" => { + if let Some(usage) = val.get("usage") { + let prompt = usage["input_tokens"].as_u64().unwrap_or(0) as u32; + let completion = usage["output_tokens"].as_u64().unwrap_or(0) as u32; + chunks.push(StreamChunk::Usage { + usage: Usage { + prompt_tokens: prompt, + completion_tokens: completion, + total_tokens: prompt + completion, + }, + }); + } + } + _ => {} + } + } + } + } + chunks +} diff --git a/libs/sdk/src/core/config.rs b/libs/sdk/src/core/config.rs index d61fa58..ea1b2e4 100644 --- a/libs/sdk/src/core/config.rs +++ b/libs/sdk/src/core/config.rs @@ -21,6 +21,24 @@ pub struct Config { pub favorites: Vec, #[serde(default)] pub recent_models: Vec, + #[serde(default = "default_thinking_level")] + pub thinking_level: String, + #[serde(default = "default_logo_animation")] + pub logo_animation: String, + #[serde(default = "default_logo_animation_color")] + pub logo_animation_color: String, +} + +fn default_thinking_level() -> String { + "default".to_string() +} + +fn default_logo_animation() -> String { + "always".to_string() +} + +fn default_logo_animation_color() -> String { + "rainbow".to_string() } impl Default for Config { @@ -34,6 +52,9 @@ impl Default for Config { last_update_check: 0.0, favorites: Vec::new(), recent_models: Vec::new(), + thinking_level: "default".to_string(), + logo_animation: "always".to_string(), + logo_animation_color: "rainbow".to_string(), } } } diff --git a/libs/sdk/src/core/message.rs b/libs/sdk/src/core/message.rs index 845e3b4..0e19b54 100644 --- a/libs/sdk/src/core/message.rs +++ b/libs/sdk/src/core/message.rs @@ -1,4 +1,5 @@ use serde::{Deserialize, Serialize}; +use std::sync::Arc; #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] #[serde(rename_all = "lowercase")] @@ -9,7 +10,7 @@ pub enum Role { Tool, } -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] pub struct ToolCall { #[serde(default)] pub index: Option, @@ -18,18 +19,19 @@ pub struct ToolCall { pub function: FunctionCall, } -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] pub struct FunctionCall { pub name: String, pub arguments: String, } -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] pub struct Message { pub role: Role, - pub content: Option, #[serde(skip_serializing_if = "Option::is_none")] - pub thought: Option, + pub content: Option>, + #[serde(rename = "reasoning_content", skip_serializing_if = "Option::is_none")] + pub thought: Option>, #[serde(skip_serializing_if = "Option::is_none")] pub tool_calls: Option>, #[serde(skip_serializing_if = "Option::is_none")] @@ -39,7 +41,7 @@ pub struct Message { } impl Message { - pub fn user(content: impl Into) -> Self { + pub fn user(content: impl Into>) -> Self { Self { role: Role::User, content: Some(content.into()), @@ -51,8 +53,8 @@ impl Message { } pub fn assistant( - content: Option, - thought: Option, + content: Option>, + thought: Option>, tool_calls: Option>, ) -> Self { Self { @@ -65,10 +67,10 @@ impl Message { } } - pub fn tool(id: String, name: String, content: String) -> Self { + pub fn tool(id: String, name: String, content: impl Into>) -> Self { Self { role: Role::Tool, - content: Some(content), + content: Some(content.into()), thought: None, tool_calls: None, tool_call_id: Some(id), @@ -76,7 +78,7 @@ impl Message { } } - pub fn system(content: impl Into) -> Self { + pub fn system(content: impl Into>) -> Self { Self { role: Role::System, content: Some(content.into()), diff --git a/libs/sdk/src/core/orchestrator.rs b/libs/sdk/src/core/orchestrator.rs index a52f0f3..66a2422 100644 --- a/libs/sdk/src/core/orchestrator.rs +++ b/libs/sdk/src/core/orchestrator.rs @@ -12,6 +12,8 @@ pub struct AgentOrchestrator { tool_registry: Arc, pub config: Arc>, pub usage: Arc>, + pub allow_session_commands: std::sync::atomic::AtomicBool, + pub allow_session_outside_access: std::sync::atomic::AtomicBool, } impl AgentOrchestrator { @@ -25,6 +27,8 @@ impl AgentOrchestrator { tool_registry, config, usage: Arc::new(Mutex::new(Usage::default())), + allow_session_commands: std::sync::atomic::AtomicBool::new(false), + allow_session_outside_access: std::sync::atomic::AtomicBool::new(false), } } @@ -44,15 +48,19 @@ impl AgentOrchestrator { // 1. Build System Prompt with Project Context let mut system_content = String::from( "You are RouteCode, a senior software engineer AI coding assistant.\n\ - You help users with their codebase through a terminal interface.\n", + You help users with their codebase through a terminal interface.\n\ + You have access to tools for file operations, navigation, and bash commands.\n\ + When you need to explore or modify the codebase, use the appropriate tools.\n", ); + let project_root = crate::utils::storage::find_project_root(); + // Inject Project Context - if let Ok(readme) = std::fs::read_to_string("README.md") { + if let Ok(readme) = std::fs::read_to_string(project_root.join("README.md")) { system_content.push_str("\n--- PROJECT README ---\n"); system_content.push_str(&readme); } - if let Ok(routecode_md) = std::fs::read_to_string("ROUTECODE.md") { + if let Ok(routecode_md) = std::fs::read_to_string(project_root.join("ROUTECODE.md")) { system_content.push_str("\n--- PROJECT INSTRUCTIONS (ROUTECODE.md) ---\n"); system_content.push_str(&routecode_md); } @@ -79,14 +87,42 @@ impl AgentOrchestrator { model: &str, tx: Option>, ) -> Result<(), anyhow::Error> { + match self.run_with_depth(history, model, tx.clone(), 0).await { + Ok(_) => Ok(()), + Err(e) => { + if let Some(ref tx) = tx { + let _ = tx.send(StreamChunk::Error { content: e.to_string() }); + let _ = tx.send(StreamChunk::Done); + } + Err(e) + } + } + } + + async fn run_with_depth( + &self, + history: &mut Vec, + model: &str, + tx: Option>, + depth: usize, + ) -> Result<(), anyhow::Error> { + if depth >= 25 { + return Err(anyhow::anyhow!("Maximum tool recursion depth (25) reached. Aborting to prevent infinite loop.")); + } + let tools = Some(self.tool_registry.get_all_schemas()); let messages = self.prepare_messages(history).await; log::debug!("Sending AI request to model: {} (messages: {})", model, messages.len()); + let thinking_level = { + let config = self.config.lock().await; + config.thinking_level.clone() + }; + let stream = { let p = self.provider.lock().await; - p.ask(messages, model, tools).await? + p.ask(messages, model, tools, Some(&thinking_level)).await? }; let mut stream = stream; @@ -96,7 +132,13 @@ impl AgentOrchestrator { let mut tool_calls: Vec = Vec::new(); while let Some(chunk_res) = stream.next().await { - let chunk = chunk_res?; + let chunk = match chunk_res { + Ok(c) => c, + Err(e) => { + log::error!("Stream error: {}", e); + return Err(e); + } + }; log::debug!("Received chunk: {:?}", chunk); if let Some(ref tx) = tx { @@ -127,26 +169,34 @@ impl AgentOrchestrator { } StreamChunk::Usage { usage } => { let mut u = self.usage.lock().await; - u.add(usage.prompt_tokens, usage.completion_tokens, model); + u.add(usage.prompt_tokens, usage.completion_tokens, model).await; } StreamChunk::Error { content } => { return Err(anyhow::anyhow!("Provider error: {}", content)); } StreamChunk::ToolResult { .. } => {} StreamChunk::Done => {} + StreamChunk::FinalHistory { .. } => {} + StreamChunk::Models { .. } => {} + StreamChunk::ModelsDone => {} + StreamChunk::RequestConfirmation { .. } => {} } } let assistant_msg = Message::assistant( if assistant_content.is_empty() { - None + if !assistant_thought.is_empty() || !tool_calls.is_empty() { + Some(std::sync::Arc::from("")) + } else { + None + } } else { - Some(assistant_content) + Some(std::sync::Arc::from(assistant_content)) }, if assistant_thought.is_empty() { None } else { - Some(assistant_thought) + Some(std::sync::Arc::from(assistant_thought)) }, if tool_calls.is_empty() { None @@ -160,8 +210,108 @@ impl AgentOrchestrator { if !tool_calls.is_empty() { for tc in tool_calls { if let Some(tool) = self.tool_registry.get(&tc.function.name) { - let args = serde_json::from_str(&tc.function.arguments)?; - let result = tool.execute(args).await?; + let args: serde_json::Value = match serde_json::from_str(&tc.function.arguments) { + Ok(a) => a, + Err(e) => { + return Err(anyhow::anyhow!( + "Failed to parse tool arguments: {}. \ + This usually means the AI's response was truncated because it reached its output token limit. \ + Try asking for a smaller part of the task or increasing the limit.", + e + )); + } + }; + let mut execute_allowed = true; + let mut custom_error_msg = None; + + use std::sync::atomic::Ordering; + + if tc.function.name == "bash" { + if !self.allow_session_commands.load(Ordering::SeqCst) { + if let Some(ref sender) = tx { + let command_str = args["command"].as_str().unwrap_or("").to_string(); + let (oneshot_tx, oneshot_rx) = tokio::sync::oneshot::channel(); + let tx_wrapped = Arc::new(tokio::sync::Mutex::new(Some(oneshot_tx))); + + if let Err(e) = sender.send(StreamChunk::RequestConfirmation { + message: "The AI agent wants to execute the following bash command:".to_string(), + target: command_str, + tx: Some(tx_wrapped), + }) { + log::error!("Failed to send RequestConfirmation to UI: {}", e); + } + + match oneshot_rx.await { + Ok(crate::agents::types::ConfirmationResponse::AllowOnce) => {} + Ok(crate::agents::types::ConfirmationResponse::AllowSession) | Ok(crate::agents::types::ConfirmationResponse::AllowWorkspace) => { + self.allow_session_commands.store(true, Ordering::SeqCst); + } + Ok(crate::agents::types::ConfirmationResponse::Deny) => { + execute_allowed = false; + custom_error_msg = Some("Command execution denied by user.".to_string()); + } + Ok(crate::agents::types::ConfirmationResponse::Feedback(msg)) => { + execute_allowed = false; + custom_error_msg = Some(format!("Command execution denied by user with feedback: {}", msg)); + } + Err(_) => { + execute_allowed = false; + custom_error_msg = Some("Command execution cancelled (confirmation channel closed).".to_string()); + } + } + } + } + } else if ["file_read", "file_write", "file_edit", "ls", "tree", "grep"].contains(&tc.function.name.as_str()) { + if !self.allow_session_outside_access.load(Ordering::SeqCst) { + let path_str = args["path"].as_str().unwrap_or("."); + if crate::utils::storage::is_path_outside_workspace(path_str) { + if let Some(ref sender) = tx { + let (oneshot_tx, oneshot_rx) = tokio::sync::oneshot::channel(); + let tx_wrapped = Arc::new(tokio::sync::Mutex::new(Some(oneshot_tx))); + + if let Err(e) = sender.send(StreamChunk::RequestConfirmation { + message: "The AI agent wants to access a path OUTSIDE the current workspace:".to_string(), + target: path_str.to_string(), + tx: Some(tx_wrapped), + }) { + log::error!("Failed to send RequestConfirmation to UI: {}", e); + } + + match oneshot_rx.await { + Ok(crate::agents::types::ConfirmationResponse::AllowOnce) => {} + Ok(crate::agents::types::ConfirmationResponse::AllowSession) | Ok(crate::agents::types::ConfirmationResponse::AllowWorkspace) => { + self.allow_session_outside_access.store(true, Ordering::SeqCst); + } + Ok(crate::agents::types::ConfirmationResponse::Deny) => { + execute_allowed = false; + custom_error_msg = Some(format!("Access to outside path '{}' denied by user.", path_str)); + } + Ok(crate::agents::types::ConfirmationResponse::Feedback(msg)) => { + execute_allowed = false; + custom_error_msg = Some(format!("Access to outside path '{}' denied by user with feedback: {}", path_str, msg)); + } + Err(_) => { + execute_allowed = false; + custom_error_msg = Some("Access cancelled (confirmation channel closed).".to_string()); + } + } + } else { + // If there's no UI (headless), just block it by default. + execute_allowed = false; + custom_error_msg = Some(format!("Access to outside path '{}' denied (no UI confirmation available).", path_str)); + } + } + } + } + + let result = if execute_allowed { + match tool.execute(args).await { + Ok(res) => res, + Err(e) => crate::core::ToolResult::error(format!("Tool execution failed: {}", e)), + } + } else { + crate::core::ToolResult::error(custom_error_msg.unwrap_or_default()) + }; let content = serde_json::to_string(&result)?; let tool_msg = @@ -180,10 +330,11 @@ impl AgentOrchestrator { } } // Recurse after tool execution - return Box::pin(self.run(history, model, tx)).await; + return Box::pin(self.run_with_depth(history, model, tx, depth + 1)).await; } if let Some(ref tx) = tx { + let _ = tx.send(StreamChunk::FinalHistory { history: history.clone() }); if let Err(e) = tx.send(StreamChunk::Done) { log::error!("Failed to send Done chunk to UI: {}", e); } @@ -211,7 +362,7 @@ mod tests { impl AIProvider for MockProvider { fn name(&self) -> &str { "Mock" } async fn list_models(&self) -> Result, anyhow::Error> { Ok(vec!["mock".to_string()]) } - async fn ask(&self, _msgs: Vec, _model: &str, _tools: Option>) -> Result { + async fn ask(&self, _msgs: Vec, _model: &str, _tools: Option>, _thinking_level: Option<&str>) -> Result { let mut resps = self.responses.lock().await; if resps.is_empty() { return Err(anyhow::anyhow!("No more mock responses")); @@ -250,7 +401,7 @@ mod tests { assert_eq!(history.len(), 2); assert_eq!(history[1].role, Role::Assistant); - assert_eq!(history[1].content, Some("Hello!".to_string())); + assert_eq!(history[1].content.as_deref(), Some("Hello!")); } #[tokio::test] @@ -294,6 +445,6 @@ mod tests { assert!(history[1].tool_calls.is_some()); assert_eq!(history[2].role, Role::Tool); assert_eq!(history[3].role, Role::Assistant); - assert_eq!(history[3].content, Some("Tool executed!".to_string())); + assert_eq!(history[3].content.as_deref(), Some("Tool executed!")); } } diff --git a/libs/sdk/src/tools/file_ops.rs b/libs/sdk/src/tools/file_ops.rs index 9478dc5..f4daaed 100644 --- a/libs/sdk/src/tools/file_ops.rs +++ b/libs/sdk/src/tools/file_ops.rs @@ -12,12 +12,30 @@ fn normalize_path(path: &str) -> PathBuf { p = &p[11..]; } else if p.starts_with("/workspace") { p = &p[10..]; - } else if p.starts_with("/") { - p = &p[1..]; } PathBuf::from(p) } +fn is_within_workspace(path: &Path) -> Result { + if cfg!(test) { + return Ok(true); + } + let current_dir = std::env::current_dir()?.canonicalize()?; + let mut p = path; + while !p.exists() { + if let Some(parent) = p.parent() { + p = parent; + } else { + break; + } + } + if !p.exists() { + return Ok(true); + } + let target = p.canonicalize()?; + Ok(target.starts_with(current_dir)) +} + fn ensure_parent_dir(path: &Path) -> Result<(), std::io::Error> { if let Some(parent) = path.parent() { if !parent.exists() && !parent.as_os_str().is_empty() { @@ -67,6 +85,11 @@ impl Tool for FileReadTool { .as_str() .ok_or_else(|| anyhow::anyhow!("Missing path"))?; let path = normalize_path(raw_path); + match is_within_workspace(&path) { + Ok(true) => {} + Ok(false) => return Ok(ToolResult::error(format!("Access denied: Path '{}' is outside the workspace boundary", path.display()))), + Err(e) => return Ok(ToolResult::error(format!("Failed to verify path '{}': {}", path.display(), e))), + } match fs::read_to_string(&path) { Ok(content) => Ok(ToolResult::success(content)), Err(e) => Ok(ToolResult::error(format!("Failed to read file '{}': {}", path.display(), e))), @@ -104,6 +127,11 @@ impl Tool for FileWriteTool { .ok_or_else(|| anyhow::anyhow!("Missing content"))?; let path = normalize_path(raw_path); + match is_within_workspace(&path) { + Ok(true) => {} + Ok(false) => return Ok(ToolResult::error(format!("Access denied: Path '{}' is outside the workspace boundary", path.display()))), + Err(e) => return Ok(ToolResult::error(format!("Failed to verify path '{}': {}", path.display(), e))), + } let old_content = fs::read_to_string(&path).unwrap_or_default(); let diff = generate_diff(&old_content, content); @@ -154,6 +182,11 @@ impl Tool for FileEditTool { let allow_multiple = args["allow_multiple"].as_bool().unwrap_or(false); let path = normalize_path(raw_path); + match is_within_workspace(&path) { + Ok(true) => {} + Ok(false) => return Ok(ToolResult::error(format!("Access denied: Path '{}' is outside the workspace boundary", path.display()))), + Err(e) => return Ok(ToolResult::error(format!("Failed to verify path '{}': {}", path.display(), e))), + } let content = match fs::read_to_string(&path) { Ok(c) => c, Err(e) => return Ok(ToolResult::error(format!("Failed to read file '{}': {}", path.display(), e))), diff --git a/libs/sdk/src/tools/navigation.rs b/libs/sdk/src/tools/navigation.rs index 379e1bd..bbc789d 100644 --- a/libs/sdk/src/tools/navigation.rs +++ b/libs/sdk/src/tools/navigation.rs @@ -44,6 +44,74 @@ impl Tool for LsTool { } } +pub struct TreeTool; + +#[async_trait] +impl Tool for TreeTool { + fn name(&self) -> &str { + "tree" + } + fn description(&self) -> &str { + "List files and directories recursively in a tree-like format" + } + fn parameters(&self) -> Value { + json!({ + "type": "object", + "properties": { + "path": { "type": "string", "description": "The directory path to start from (default: .)", "default": "." }, + "depth": { "type": "integer", "description": "Max recursion depth (default: 3)", "default": 3 } + } + }) + } + + async fn execute(&self, args: Value) -> Result { + let path_str = args["path"].as_str().unwrap_or("."); + let max_depth = args["depth"].as_u64().unwrap_or(3) as usize; + + let mut output = String::new(); + let path = std::path::Path::new(path_str); + + if !path.exists() { + return Ok(ToolResult::error(format!("Path '{}' does not exist", path_str))); + } + + fn walk(dir: &std::path::Path, prefix: &str, current_depth: usize, max_depth: usize, output: &mut String) -> std::io::Result<()> { + if current_depth > max_depth { return Ok(()); } + + let entries: Vec<_> = fs::read_dir(dir)? + .flatten() + .filter(|entry| { + let name = entry.file_name().to_string_lossy().to_string(); + name != ".git" && name != "node_modules" && name != "target" + }) + .collect(); + + let count = entries.len(); + for (idx, entry) in entries.into_iter().enumerate() { + let is_last = idx == count - 1; + let path = entry.path(); + let name = entry.file_name().to_string_lossy().to_string(); + + let connector = if is_last { "└── " } else { "├── " }; + output.push_str(&format!("{}{}{}\n", prefix, connector, name)); + + if path.is_dir() { + let new_prefix = format!("{}{}", prefix, if is_last { " " } else { "│ " }); + walk(&path, &new_prefix, current_depth + 1, max_depth, output)?; + } + } + Ok(()) + } + + output.push_str(&format!("{}\n", path_str)); + if let Err(e) = walk(path, "", 1, max_depth, &mut output) { + return Ok(ToolResult::error(format!("Failed to walk directory: {}", e))); + } + + Ok(ToolResult::success(output)) + } +} + pub struct GrepTool; #[async_trait] @@ -71,29 +139,91 @@ impl Tool for GrepTool { .as_str() .ok_or_else(|| anyhow::anyhow!("Missing pattern"))?; let path = args["path"].as_str().unwrap_or("."); + let include = args["include"].as_str(); + + let glob_pattern = if let Some(inc) = include { + Some(glob::Pattern::new(inc).map_err(|e| anyhow::anyhow!("Invalid glob pattern '{}': {}", inc, e))?) + } else { + None + }; + + let regex_pattern = regex::Regex::new(pattern).ok(); // Using a simple recursive walk for grep let mut results = Vec::new(); fn walk_and_search( dir: &std::path::Path, + search_root: &std::path::Path, pattern: &str, + regex_pattern: Option<®ex::Regex>, + glob_pattern: Option<&glob::Pattern>, results: &mut Vec, ) -> io::Result<()> { if dir.is_dir() { for entry in fs::read_dir(dir)? { let entry = entry?; let path = entry.path(); + let name = entry.file_name().to_string_lossy().to_string(); + if name == ".git" || name == "node_modules" || name == "target" { + continue; + } if path.is_dir() { - walk_and_search(&path, pattern, results)?; - } else if let Ok(content) = fs::read_to_string(&path) { - for (idx, line) in content.lines().enumerate() { - if line.contains(pattern) { - results.push(format!( - "{}:{}: {}", - path.display(), - idx + 1, - line.trim() - )); + walk_and_search(&path, search_root, pattern, regex_pattern, glob_pattern, results)?; + } else { + if let Some(glob_pat) = glob_pattern { + let mut matches = false; + if let Some(filename) = path.file_name().and_then(|f| f.to_str()) { + if glob_pat.matches(filename) { + matches = true; + } + } + if !matches { + if let Ok(rel_path) = path.strip_prefix(search_root) { + if glob_pat.matches_path(rel_path) { + matches = true; + } + } + } + if !matches { + if glob_pat.matches_path(&path) { + matches = true; + } + } + if !matches { + continue; + } + } + + let is_binary = || -> bool { + use std::io::Read; + if let Ok(mut file) = fs::File::open(&path) { + let mut buffer = [0; 1024]; + if let Ok(bytes_read) = file.read(&mut buffer) { + return buffer[..bytes_read].contains(&0); + } + } + false + }; + + if is_binary() { + continue; + } + + if let Ok(content) = fs::read_to_string(&path) { + for (idx, line) in content.lines().enumerate() { + let is_match = if let Some(rx) = regex_pattern { + rx.is_match(line) + } else { + line.contains(pattern) + }; + if is_match { + results.push(format!( + "{}:{}: {}", + path.display(), + idx + 1, + line.trim() + )); + } } } } @@ -103,7 +233,8 @@ impl Tool for GrepTool { } use std::io; - if let Err(e) = walk_and_search(std::path::Path::new(path), pattern, &mut results) { + let search_root = std::path::Path::new(path); + if let Err(e) = walk_and_search(search_root, search_root, pattern, regex_pattern.as_ref(), glob_pattern.as_ref(), &mut results) { return Ok(ToolResult::error(format!("Search failed: {}", e))); } @@ -153,7 +284,16 @@ mod tests { ) .unwrap(); + let file_path_rs = dir.path().join("test.rs"); + fs::write( + &file_path_rs, + "line 1: hello in rust", + ) + .unwrap(); + let tool = GrepTool; + + // Test normal grep without include filter let args = json!({ "pattern": "hello", "path": dir.path().to_str().unwrap() @@ -164,6 +304,32 @@ mod tests { let content = res.content.unwrap(); assert!(content.contains("test.txt:1: line 1: hello")); assert!(content.contains("test.txt:3: line 3: hello again")); - assert!(!content.contains("line 2: world")); + assert!(content.contains("test.rs:1: line 1: hello in rust")); + + // Test grep with include filter (*.rs) + let args_inc = json!({ + "pattern": "hello", + "path": dir.path().to_str().unwrap(), + "include": "*.rs" + }); + let res_inc = tool.execute(args_inc).await.unwrap(); + + assert!(res_inc.success); + let content_inc = res_inc.content.unwrap(); + assert!(!content_inc.contains("test.txt")); + assert!(content_inc.contains("test.rs:1: line 1: hello in rust")); + + // Test grep with Regex pattern (e.g. h[e-o]llo) + let args_regex = json!({ + "pattern": "h[e-o]llo", + "path": dir.path().to_str().unwrap() + }); + let res_regex = tool.execute(args_regex).await.unwrap(); + + assert!(res_regex.success); + let content_regex = res_regex.content.unwrap(); + assert!(content_regex.contains("test.txt:1: line 1: hello")); + assert!(content_regex.contains("test.txt:3: line 3: hello again")); + assert!(content_regex.contains("test.rs:1: line 1: hello in rust")); } } diff --git a/libs/sdk/src/utils/costs.rs b/libs/sdk/src/utils/costs.rs index 55d32f3..38565ff 100644 --- a/libs/sdk/src/utils/costs.rs +++ b/libs/sdk/src/utils/costs.rs @@ -1,4 +1,7 @@ use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use std::sync::{Arc, RwLock}; +use once_cell::sync::Lazy; #[derive(Debug, Clone, Serialize, Deserialize, Default)] pub struct Usage { @@ -8,49 +11,136 @@ pub struct Usage { pub total_cost: f64, } +#[derive(Debug, Clone, Serialize, Deserialize)] pub struct ModelRates { - pub input_per_1k: f64, - pub output_per_1k: f64, + pub input_per_1m: f64, + pub output_per_1m: f64, } +#[derive(Debug, Serialize, Deserialize)] +struct ModelsDevProvider { + pub models: HashMap, +} + +#[derive(Debug, Serialize, Deserialize)] +struct ModelsDevModel { + pub cost: Option, +} + +#[derive(Debug, Serialize, Deserialize)] +struct ModelsDevCost { + pub input: f64, + pub output: f64, +} + +static RATE_CACHE: Lazy>>> = Lazy::new(|| { + Arc::new(RwLock::new(HashMap::new())) +}); + impl Usage { - pub fn add(&mut self, input: u32, output: u32, model: &str) { + pub async fn add(&mut self, input: u32, output: u32, model: &str) { self.input_tokens += input; self.output_tokens += output; self.total_tokens += input + output; - let cost = calculate_cost(input, output, model); + let cost = calculate_cost(input, output, model).await; self.total_cost += cost; } } -pub fn calculate_cost(input: u32, output: u32, model: &str) -> f64 { - let rates = get_model_rates(model); - let input_cost = (input as f64 / 1000.0) * rates.input_per_1k; - let output_cost = (output as f64 / 1000.0) * rates.output_per_1k; +pub async fn calculate_cost(input: u32, output: u32, model: &str) -> f64 { + let rates = get_model_rates(model).await; + let input_cost = (input as f64 / 1_000_000.0) * rates.input_per_1m; + let output_cost = (output as f64 / 1_000_000.0) * rates.output_per_1m; input_cost + output_cost } -fn get_model_rates(model: &str) -> ModelRates { - // Default rates (GPT-4o style) - let mut rates = ModelRates { - input_per_1k: 0.005, - output_per_1k: 0.015, +pub async fn refresh_rates() -> anyhow::Result<()> { + log::debug!("Refreshing model rates from models.dev..."); + let client = reqwest::Client::builder() + .timeout(std::time::Duration::from_secs(10)) + .build()?; + + let response = client.get("https://models.dev/api.json").send().await?; + if !response.status().is_success() { + return Err(anyhow::anyhow!("Failed to fetch rates: {}", response.status())); + } + + let data: HashMap = response.json().await?; + let mut new_rates = HashMap::new(); + + for (_provider_id, provider) in data { + for (model_id, model) in provider.models { + if let Some(cost) = model.cost { + new_rates.insert(model_id, ModelRates { + input_per_1m: cost.input, + output_per_1m: cost.output, + }); + } + } + } + + if !new_rates.is_empty() { + let mut cache = RATE_CACHE.write().unwrap_or_else(|e| e.into_inner()); + *cache = new_rates; + log::info!("Successfully updated {} model rates from models.dev", cache.len()); + } + + Ok(()) +} + +use std::sync::atomic::{AtomicBool, Ordering}; + +static REFRESH_TRIGGERED: AtomicBool = AtomicBool::new(false); + +async fn get_model_rates(model: &str) -> ModelRates { + // 1. Check cache + { + let cache = RATE_CACHE.read().unwrap_or_else(|e| e.into_inner()); + if let Some(rates) = cache.get(model) { + return rates.clone(); + } + + // Try fuzzy match if exact match fails (e.g. "gpt-4o-2024-05-13" vs "gpt-4o") + for (cached_id, rates) in cache.iter() { + if model.contains(cached_id) || cached_id.contains(model) { + return rates.clone(); + } + } + } + + // 2. If cache is empty, trigger one-time background refresh + let cache_is_empty = { + let cache = RATE_CACHE.read().unwrap_or_else(|e| e.into_inner()); + cache.is_empty() }; + if cache_is_empty { + if !REFRESH_TRIGGERED.swap(true, Ordering::SeqCst) { + tokio::spawn(async { + let _ = refresh_rates().await; + }); + } + } + + // 3. Fallback to hardcoded defaults for common models if API fails or model is missing + get_fallback_rates(model) +} + +fn get_fallback_rates(model: &str) -> ModelRates { if model.contains("gpt-4o-mini") { - rates.input_per_1k = 0.00015; - rates.output_per_1k = 0.0006; + ModelRates { input_per_1m: 0.15, output_per_1m: 0.60 } + } else if model.contains("gpt-4o") { + ModelRates { input_per_1m: 5.0, output_per_1m: 15.0 } } else if model.contains("claude-3-5-sonnet") { - rates.input_per_1k = 0.003; - rates.output_per_1k = 0.015; - } else if model.contains("claude-3-opus") { - rates.input_per_1k = 0.015; - rates.output_per_1k = 0.075; + ModelRates { input_per_1m: 3.0, output_per_1m: 15.0 } } else if model.contains("deepseek-v3") || model.contains("deepseek-chat") { - rates.input_per_1k = 0.0001; - rates.output_per_1k = 0.0002; + ModelRates { input_per_1m: 0.14, output_per_1m: 0.28 } + } else { + if !model.is_empty() { + log::warn!("Unknown model '{}' for cost calculation. Using default fallback rates.", model); + } + // Default GPT-4o style fallback + ModelRates { input_per_1m: 5.0, output_per_1m: 15.0 } } - - rates } diff --git a/libs/sdk/src/utils/storage.rs b/libs/sdk/src/utils/storage.rs index fd18a04..e213cbe 100644 --- a/libs/sdk/src/utils/storage.rs +++ b/libs/sdk/src/utils/storage.rs @@ -3,6 +3,8 @@ use crate::utils::costs::Usage; use serde::{Deserialize, Serialize}; use std::fs; use std::path::PathBuf; +use std::hash::{Hash, Hasher}; +use std::collections::hash_map::DefaultHasher; #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Session { @@ -12,52 +14,249 @@ pub struct Session { pub timestamp: i64, } +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SessionConfig { + #[serde(default)] + pub allow_all_commands: bool, + #[serde(default)] + pub allowed_commands: Vec, + #[serde(default)] + pub allow_all_outside_access: bool, +} + +impl Default for SessionConfig { + fn default() -> Self { + Self { + allow_all_commands: false, + allowed_commands: Vec::new(), + allow_all_outside_access: false, + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct WorkspaceConfig { + #[serde(default)] + pub allow_all_outside_access: bool, + #[serde(default)] + pub allowed_outside_paths: Vec, +} + +impl Default for WorkspaceConfig { + fn default() -> Self { + Self { + allow_all_outside_access: false, + allowed_outside_paths: Vec::new(), + } + } +} + +pub fn load_workspace_config() -> anyhow::Result { + let path = get_workspace_dir().join("workspace_config.json"); + if !path.exists() { + return Ok(WorkspaceConfig::default()); + } + let json = std::fs::read_to_string(path)?; + let config = serde_json::from_str(&json).unwrap_or_default(); + Ok(config) +} + +pub fn save_workspace_config(config: &WorkspaceConfig) -> anyhow::Result<()> { + let dir = get_workspace_dir(); + if !dir.exists() { + std::fs::create_dir_all(&dir)?; + } + let path = dir.join("workspace_config.json"); + let json = serde_json::to_string_pretty(config)?; + std::fs::write(path, json)?; + Ok(()) +} + pub fn get_base_dir() -> PathBuf { dirs::home_dir() .map(|p| p.join(".routecode")) .unwrap_or_else(|| PathBuf::from(".routecode")) } -pub fn save_session(name: &str, session: &Session) -> anyhow::Result<()> { - let dir = get_base_dir().join("sessions"); - if !dir.exists() { - fs::create_dir_all(&dir)?; +pub fn find_project_root() -> PathBuf { + let mut current = std::env::current_dir().unwrap_or_else(|_| PathBuf::from(".")); + + loop { + if current.join(".git").exists() || current.join("ROUTECODE.md").exists() { + return current; + } + + if let Some(parent) = current.parent() { + current = parent.to_path_buf(); + } else { + // Fallback to CWD + return std::env::current_dir().unwrap_or_else(|_| PathBuf::from(".")); + } } +} + +pub fn get_workspace_dir() -> PathBuf { + let root = find_project_root(); + let root_str = root.to_string_lossy().to_string(); + + let mut hasher = DefaultHasher::new(); + root_str.hash(&mut hasher); + let hash = format!("{:x}", hasher.finish()); + + let folder_name = root.file_name() + .map(|s| s.to_string_lossy().to_string()) + .unwrap_or_else(|| "workspace".to_string()); + + let safe_folder_name = folder_name.replace(|c: char| !c.is_alphanumeric(), "_"); + let workspace_id = format!("{}_{}", safe_folder_name, &hash[..8]); + + get_base_dir().join("workspaces").join(workspace_id) +} - let path = dir.join(format!("{}.json", name)); +pub fn is_path_outside_workspace(path_str: &str) -> bool { + let root = find_project_root(); + let root_canon = root.canonicalize().unwrap_or(root.clone()); + + let mut p_str = path_str; + if p_str.starts_with("/workspace/") { + p_str = &p_str[11..]; + } else if p_str.starts_with("/workspace") { + p_str = &p_str[10..]; + } + + let path = PathBuf::from(p_str); + let absolute_path = if path.is_absolute() { + path + } else { + root.join(path) + }; + + let mut p = absolute_path; + while !p.exists() { + if let Some(parent) = p.parent() { + p = parent.to_path_buf(); + } else { + break; + } + } + + if p.exists() { + if let Ok(canon) = p.canonicalize() { + return !canon.starts_with(&root_canon); + } + } + + false +} + +pub fn sanitize_session_name(name: &str) -> String { + name.chars() + .filter(|c| c.is_alphanumeric() || *c == '-' || *c == '_') + .collect() +} + +pub fn save_session(name: &str, session: &Session) -> anyhow::Result<()> { + let safe_name = sanitize_session_name(name); + if safe_name.is_empty() { return Err(anyhow::anyhow!("Invalid session name")); } + + let workspace_dir = get_workspace_dir(); + let session_dir = workspace_dir.join("sessions").join(&safe_name); + + if !session_dir.exists() { + fs::create_dir_all(&session_dir)?; + } + + let history_path = session_dir.join("history.json"); let json = serde_json::to_string_pretty(session)?; - fs::write(path, json)?; + fs::write(history_path, json)?; Ok(()) } pub fn load_session(name: &str) -> anyhow::Result { - let path = get_base_dir() - .join("sessions") - .join(format!("{}.json", name)); + let safe_name = sanitize_session_name(name); + if safe_name.is_empty() { return Err(anyhow::anyhow!("Invalid session name")); } + + let workspace_dir = get_workspace_dir(); + + let old_path = get_base_dir().join("sessions").join(format!("{}.json", safe_name)); + let new_path = workspace_dir.join("sessions").join(&safe_name).join("history.json"); + + let path = if new_path.exists() { new_path } else { old_path }; + let json = fs::read_to_string(path)?; let session = serde_json::from_str(&json)?; Ok(session) } pub fn list_sessions() -> anyhow::Result> { - let dir = get_base_dir().join("sessions"); - if !dir.exists() { - return Ok(Vec::new()); - } - + let workspace_dir = get_workspace_dir(); + let new_sessions_dir = workspace_dir.join("sessions"); + let old_sessions_dir = get_base_dir().join("sessions"); + let mut sessions = Vec::new(); - for entry in fs::read_dir(dir)? { - let entry = entry?; - let path = entry.path(); - if path.extension().is_some_and(|ext| ext == "json") { - if let Some(name) = path.file_stem().and_then(|s| s.to_str()) { - sessions.push(name.to_string()); + + if new_sessions_dir.exists() { + for entry in fs::read_dir(new_sessions_dir)? { + let entry = entry?; + let path = entry.path(); + if path.is_dir() && path.join("history.json").exists() { + if let Some(name) = path.file_name().and_then(|s| s.to_str()) { + sessions.push(name.to_string()); + } + } + } + } + + if old_sessions_dir.exists() { + for entry in fs::read_dir(old_sessions_dir)? { + let entry = entry?; + let path = entry.path(); + if path.extension().is_some_and(|ext| ext == "json") { + if let Some(name) = path.file_stem().and_then(|s| s.to_str()) { + if !sessions.contains(&name.to_string()) { + sessions.push(name.to_string()); + } + } } } } + Ok(sessions) } +pub fn load_session_config(name: &str) -> anyhow::Result { + let safe_name = sanitize_session_name(name); + if safe_name.is_empty() { return Err(anyhow::anyhow!("Invalid session name")); } + + let workspace_dir = get_workspace_dir(); + let config_path = workspace_dir.join("sessions").join(&safe_name).join("session_config.json"); + + if !config_path.exists() { + return Ok(SessionConfig::default()); + } + + let json = fs::read_to_string(config_path)?; + let config = serde_json::from_str(&json).unwrap_or_default(); + Ok(config) +} + +pub fn save_session_config(name: &str, config: &SessionConfig) -> anyhow::Result<()> { + let safe_name = sanitize_session_name(name); + if safe_name.is_empty() { return Err(anyhow::anyhow!("Invalid session name")); } + + let workspace_dir = get_workspace_dir(); + let session_dir = workspace_dir.join("sessions").join(&safe_name); + + if !session_dir.exists() { + fs::create_dir_all(&session_dir)?; + } + + let config_path = session_dir.join("session_config.json"); + let json = serde_json::to_string_pretty(config)?; + fs::write(config_path, json)?; + Ok(()) +} + pub fn load_config() -> anyhow::Result { let path = get_base_dir().join("config.json"); if !path.exists() { diff --git a/libs/sdk/src/utils/tokens.rs b/libs/sdk/src/utils/tokens.rs index dc39d16..6129a1c 100644 --- a/libs/sdk/src/utils/tokens.rs +++ b/libs/sdk/src/utils/tokens.rs @@ -1,30 +1,53 @@ use crate::core::Message; -use tiktoken_rs::cl100k_base; +use once_cell::sync::Lazy; +use tiktoken_rs::{cl100k_base, CoreBPE}; -pub fn count_tokens(messages: &[Message]) -> usize { - let bpe = match cl100k_base() { - Ok(b) => b, - Err(_) => return 0, - }; +static BPE: Lazy> = Lazy::new(|| { + match cl100k_base() { + Ok(b) => Some(b), + Err(e) => { + log::warn!("tiktoken cl100k_base initialization failed: {}. Fallback estimation will be used.", e); + None + } + } +}); - let mut total_tokens = 0; - for m in messages { - // Role overhead - total_tokens += 4; +pub fn count_tokens(messages: &[Message]) -> usize { + if let Some(bpe) = &*BPE { + let mut total_tokens = 0; + for m in messages { + // Role overhead + total_tokens += 4; - if let Some(content) = &m.content { - total_tokens += bpe.encode_with_special_tokens(content).len(); - } - if let Some(thought) = &m.thought { - total_tokens += bpe.encode_with_special_tokens(thought).len(); + if let Some(content) = &m.content { + total_tokens += bpe.encode_with_special_tokens(content).len(); + } + if let Some(thought) = &m.thought { + total_tokens += bpe.encode_with_special_tokens(thought).len(); + } + if let Some(tool_calls) = &m.tool_calls { + for tc in tool_calls { + total_tokens += bpe.encode_with_special_tokens(&tc.function.name).len(); + total_tokens += bpe.encode_with_special_tokens(&tc.function.arguments).len(); + total_tokens += 10; // overhead + } + } } - if let Some(tool_calls) = &m.tool_calls { - for tc in tool_calls { - total_tokens += bpe.encode_with_special_tokens(&tc.function.name).len(); - total_tokens += bpe.encode_with_special_tokens(&tc.function.arguments).len(); - total_tokens += 10; // overhead + total_tokens + } else { + let mut total = 0; + for m in messages { + total += 4; // Role overhead + if let Some(content) = &m.content { total += content.len() / 4; } + if let Some(thought) = &m.thought { total += thought.len() / 4; } + if let Some(tool_calls) = &m.tool_calls { + for tc in tool_calls { + total += tc.function.name.len() / 4; + total += tc.function.arguments.len() / 4; + total += 10; + } } } + total } - total_tokens }