diff --git a/packages/tauri-app/Cargo.lock b/packages/tauri-app/Cargo.lock index 2268df75..ae157f48 100644 --- a/packages/tauri-app/Cargo.lock +++ b/packages/tauri-app/Cargo.lock @@ -444,6 +444,12 @@ version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9330f8b2ff13f34540b44e946ef35111825727b38d33286ef986142615121801" +[[package]] +name = "cfg_aliases" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724" + [[package]] name = "chrono" version = "0.4.44" @@ -467,6 +473,7 @@ dependencies = [ "once_cell", "parking_lot", "regex", + "reqwest 0.12.28", "serde", "serde_json", "serde_yaml", @@ -1156,6 +1163,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "07bbe89c50d7a535e539b8c17bc0b49bdb77747034daa8087407d655f3f7cc1d" dependencies = [ "futures-core", + "futures-sink", ] [[package]] @@ -1379,8 +1387,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ff2abc00be7fca6ebc474524697ae276ad847ad0a6b3faa4bcb027e9a4614ad0" dependencies = [ "cfg-if", + "js-sys", "libc", "wasi 0.11.1+wasi-snapshot-preview1", + "wasm-bindgen", ] [[package]] @@ -1390,9 +1400,11 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "899def5c37c4fd7b2664648c28120ecec138e4d395b459e5ca34f9cce2dd77fd" dependencies = [ "cfg-if", + "js-sys", "libc", "r-efi 5.3.0", "wasip2", + "wasm-bindgen", ] [[package]] @@ -1710,6 +1722,23 @@ dependencies = [ "want", ] +[[package]] +name = "hyper-rustls" +version = "0.27.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3c93eb611681b207e1fe55d5a71ecf91572ec8a6705cdb6857f7d8d5242cf58" +dependencies = [ + "http", + "hyper", + "hyper-util", + "rustls", + "rustls-pki-types", + "tokio", + "tokio-rustls", + "tower-service", + "webpki-roots", +] + [[package]] name = "hyper-util" version = "0.1.20" @@ -2157,6 +2186,12 @@ version = "0.4.29" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5e5032e24019045c762d3c0f28f5b6b8bbf38563a65908389bf7978758920897" +[[package]] +name = "lru-slab" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "112b39cec0b298b6c1999fee3e31427f74f676e4cb9879ed1a121b43661a4154" + [[package]] name = "mac" version = "0.1.1" @@ -2995,6 +3030,61 @@ dependencies = [ "memchr", ] +[[package]] +name = "quinn" +version = "0.11.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9e20a958963c291dc322d98411f541009df2ced7b5a4f2bd52337638cfccf20" +dependencies = [ + "bytes", + "cfg_aliases", + "pin-project-lite", + "quinn-proto", + "quinn-udp", + "rustc-hash", + "rustls", + "socket2", + "thiserror 2.0.18", + "tokio", + "tracing", + "web-time", +] + +[[package]] +name = "quinn-proto" +version = "0.11.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "434b42fec591c96ef50e21e886936e66d3cc3f737104fdb9b737c40ffb94c098" +dependencies = [ + "bytes", + "getrandom 0.3.4", + "lru-slab", + "rand 0.9.2", + "ring", + "rustc-hash", + "rustls", + "rustls-pki-types", + "slab", + "thiserror 2.0.18", + "tinyvec", + "tracing", + "web-time", +] + +[[package]] +name = "quinn-udp" +version = "0.5.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "addec6a0dcad8a8d96a771f815f0eaf55f9d1805756410b39f5fa81332574cbd" +dependencies = [ + "cfg_aliases", + "libc", + "once_cell", + "socket2", + "tracing", + "windows-sys 0.60.2", +] + [[package]] name = "quote" version = "1.0.45" @@ -3212,6 +3302,46 @@ version = "0.8.10" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dc897dd8d9e8bd1ed8cdad82b5966c3e0ecae09fb1907d58efaa013543185d0a" +[[package]] +name = "reqwest" +version = "0.12.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eddd3ca559203180a307f12d114c268abf583f59b03cb906fd0b3ff8646c1147" +dependencies = [ + "base64 0.22.1", + "bytes", + "futures-channel", + "futures-core", + "futures-util", + "http", + "http-body", + "http-body-util", + "hyper", + "hyper-rustls", + "hyper-util", + "js-sys", + "log", + "percent-encoding", + "pin-project-lite", + "quinn", + "rustls", + "rustls-pki-types", + "serde", + "serde_json", + "serde_urlencoded", + "sync_wrapper", + "tokio", + "tokio-rustls", + "tower", + "tower-http", + "tower-service", + "url", + "wasm-bindgen", + "wasm-bindgen-futures", + "web-sys", + "webpki-roots", +] + [[package]] name = "reqwest" version = "0.13.2" @@ -3270,6 +3400,20 @@ dependencies = [ "windows-sys 0.60.2", ] +[[package]] +name = "ring" +version = "0.17.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a4689e6c2294d81e88dc6261c768b63bc4fcdb852be6d1352498b114f61383b7" +dependencies = [ + "cc", + "cfg-if", + "getrandom 0.2.17", + "libc", + "untrusted", + "windows-sys 0.52.0", +] + [[package]] name = "rustc-hash" version = "2.1.1" @@ -3311,6 +3455,41 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "rustls" +version = "0.23.37" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "758025cb5fccfd3bc2fd74708fd4682be41d99e5dff73c377c0646c6012c73a4" +dependencies = [ + "once_cell", + "ring", + "rustls-pki-types", + "rustls-webpki", + "subtle", + "zeroize", +] + +[[package]] +name = "rustls-pki-types" +version = "1.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "be040f8b0a225e40375822a563fa9524378b9d63112f53e19ffff34df5d33fdd" +dependencies = [ + "web-time", + "zeroize", +] + +[[package]] +name = "rustls-webpki" +version = "0.103.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df33b2b81ac578cabaf06b89b0631153a3f416b0a886e8a7a1707fb51abbd1ef" +dependencies = [ + "ring", + "rustls-pki-types", + "untrusted", +] + [[package]] name = "rustversion" version = "1.0.22" @@ -3531,6 +3710,18 @@ dependencies = [ "serde_core", ] +[[package]] +name = "serde_urlencoded" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3491c14715ca2294c4d6a88f15e84739788c1d030eed8c110436aafdaa2f3fd" +dependencies = [ + "form_urlencoded", + "itoa", + "ryu", + "serde", +] + [[package]] name = "serde_with" version = "3.18.0" @@ -3792,6 +3983,12 @@ version = "0.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" +[[package]] +name = "subtle" +version = "2.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" + [[package]] name = "swift-rs" version = "1.0.7" @@ -3943,7 +4140,7 @@ dependencies = [ "percent-encoding", "plist", "raw-window-handle", - "reqwest", + "reqwest 0.13.2", "serde", "serde_json", "serde_repr", @@ -4367,6 +4564,21 @@ dependencies = [ "zerovec", ] +[[package]] +name = "tinyvec" +version = "1.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3e61e67053d25a4e82c844e8424039d9745781b3fc4f32b8d55ed50f5f667ef3" +dependencies = [ + "tinyvec_macros", +] + +[[package]] +name = "tinyvec_macros" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" + [[package]] name = "tokio" version = "1.50.0" @@ -4381,6 +4593,16 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "tokio-rustls" +version = "0.26.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1729aa945f29d91ba541258c8df89027d5792d85a8841fb65e8bf0f4ede4ef61" +dependencies = [ + "rustls", + "tokio", +] + [[package]] name = "tokio-util" version = "0.7.18" @@ -4691,6 +4913,12 @@ version = "0.2.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "673aac59facbab8a9007c7f6108d11f63b603f7cabff99fabf650fea5c32b861" +[[package]] +name = "untrusted" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" + [[package]] name = "url" version = "2.5.8" @@ -4937,6 +5165,16 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "web-time" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a6580f308b1fad9207618087a65c04e7a10bc77e02c8e84e9b00dd4b12fa0bb" +dependencies = [ + "js-sys", + "wasm-bindgen", +] + [[package]] name = "web_atoms" version = "0.2.3" @@ -4993,6 +5231,15 @@ dependencies = [ "system-deps", ] +[[package]] +name = "webpki-roots" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "22cfaf3c063993ff62e73cb4311efde4db1efb31ab78a3e5c457939ad5cc0bed" +dependencies = [ + "rustls-pki-types", +] + [[package]] name = "webview2-com" version = "0.38.2" @@ -5286,6 +5533,15 @@ dependencies = [ "windows-targets 0.48.5", ] +[[package]] +name = "windows-sys" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" +dependencies = [ + "windows-targets 0.52.6", +] + [[package]] name = "windows-sys" version = "0.59.0" @@ -5927,6 +6183,12 @@ dependencies = [ "synstructure", ] +[[package]] +name = "zeroize" +version = "1.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b97154e67e32c85465826e8bcc1c59429aaaf107c1e4a9e53c8d8ccd5eff88d0" + [[package]] name = "zerotrie" version = "0.2.3" diff --git a/packages/tauri-app/Cargo.toml b/packages/tauri-app/Cargo.toml index 5322fc26..6fdaca53 100644 --- a/packages/tauri-app/Cargo.toml +++ b/packages/tauri-app/Cargo.toml @@ -1,3 +1,10 @@ [workspace] members = ["src-tauri"] resolver = "2" + +[profile.release] +panic = "abort" +lto = true +codegen-units = 1 +opt-level = "s" +strip = true diff --git a/packages/tauri-app/src-tauri/Cargo.toml b/packages/tauri-app/src-tauri/Cargo.toml index 13a89b1f..9d798ef8 100644 --- a/packages/tauri-app/src-tauri/Cargo.toml +++ b/packages/tauri-app/src-tauri/Cargo.toml @@ -26,6 +26,7 @@ tauri-plugin-opener = "2" tauri-plugin-global-shortcut = "2" url = "2" tauri-plugin-notification = "2" +reqwest = { version = "0.12", default-features = false, features = ["blocking", "json", "rustls-tls"] } [target.'cfg(windows)'.dependencies] windows-sys = { version = "0.59", features = ["Win32_UI_Shell"] } diff --git a/packages/tauri-app/src-tauri/src/cli_manager.rs b/packages/tauri-app/src-tauri/src/cli_manager.rs index 5844566e..22ecdc13 100644 --- a/packages/tauri-app/src-tauri/src/cli_manager.rs +++ b/packages/tauri-app/src-tauri/src/cli_manager.rs @@ -1,3 +1,4 @@ +use crate::desktop_event_transport::DesktopEventStreamConfig; use dirs::home_dir; use parking_lot::Mutex; use regex::Regex; @@ -365,6 +366,8 @@ pub struct CliProcessManager { child: Arc>>, ready: Arc, bootstrap_token: Arc>>, + session_cookie: Arc>>, + auth_cookie_name: Arc>>, } impl CliProcessManager { @@ -374,6 +377,8 @@ impl CliProcessManager { child: Arc::new(Mutex::new(None)), ready: Arc::new(AtomicBool::new(false)), bootstrap_token: Arc::new(Mutex::new(None)), + session_cookie: Arc::new(Mutex::new(None)), + auth_cookie_name: Arc::new(Mutex::new(None)), } } @@ -382,6 +387,8 @@ impl CliProcessManager { self.stop()?; self.ready.store(false, Ordering::SeqCst); *self.bootstrap_token.lock() = None; + *self.session_cookie.lock() = None; + *self.auth_cookie_name.lock() = None; { let mut status = self.status.lock(); status.state = CliState::Starting; @@ -396,6 +403,8 @@ impl CliProcessManager { let child_arc = self.child.clone(); let ready_flag = self.ready.clone(); let token_arc = self.bootstrap_token.clone(); + let session_cookie_arc = self.session_cookie.clone(); + let auth_cookie_name_arc = self.auth_cookie_name.clone(); thread::spawn(move || { if let Err(err) = Self::spawn_cli( app.clone(), @@ -403,6 +412,8 @@ impl CliProcessManager { child_arc, ready_flag, token_arc, + session_cookie_arc, + auth_cookie_name_arc, dev, ) { log_line(&format!("cli spawn failed: {err}")); @@ -499,6 +510,7 @@ impl CliProcessManager { status.port = None; status.url = None; status.error = None; + *self.session_cookie.lock() = None; Ok(()) } @@ -507,12 +519,33 @@ impl CliProcessManager { self.status.lock().clone() } + pub fn desktop_event_stream_config(&self) -> Option { + let base_url = self.status.lock().url.clone()?; + let events_url = format!("{}/api/events", base_url.trim_end_matches('/')); + let client_id = format!("tauri-{}", std::process::id()); + let cookie_name = self + .auth_cookie_name + .lock() + .clone() + .unwrap_or_else(|| SESSION_COOKIE_NAME_PREFIX.to_string()); + + Some(DesktopEventStreamConfig { + base_url, + events_url, + client_id, + cookie_name, + session_cookie: self.session_cookie.lock().clone(), + }) + } + fn spawn_cli( app: AppHandle, status: Arc>, child_holder: Arc>>, ready: Arc, bootstrap_token: Arc>>, + session_cookie: Arc>>, + auth_cookie_name_holder: Arc>>, dev: bool, ) -> anyhow::Result<()> { log_line("resolving CLI entry"); @@ -523,6 +556,8 @@ impl CliProcessManager { resolution.runner, resolution.entry, host )); let auth_cookie_name = Arc::new(generate_auth_cookie_name()); + // Store the generated cookie name so desktop_event_stream_config() can use it later. + *auth_cookie_name_holder.lock() = Some(auth_cookie_name.as_str().to_string()); let args = resolution.build_args(dev, &host, auth_cookie_name.as_str()); log_line(&format!("CLI args: {:?}", args)); if dev { @@ -608,6 +643,7 @@ impl CliProcessManager { let app_clone = app.clone(); let ready_clone = ready.clone(); let token_clone = bootstrap_token.clone(); + let session_cookie_clone = session_cookie.clone(); let auth_cookie_name_clone = auth_cookie_name.clone(); thread::spawn(move || { @@ -627,6 +663,7 @@ impl CliProcessManager { let status = status_clone.clone(); let ready = ready_clone.clone(); let token = token_clone.clone(); + let session_cookie = session_cookie_clone.clone(); let auth_cookie_name = auth_cookie_name_clone.clone(); thread::spawn(move || { Self::process_stream( @@ -636,6 +673,7 @@ impl CliProcessManager { &status, &ready, &token, + &session_cookie, auth_cookie_name.as_str(), ); }); @@ -646,6 +684,7 @@ impl CliProcessManager { let status = status_clone.clone(); let ready = ready_clone.clone(); let token = token_clone.clone(); + let session_cookie = session_cookie_clone.clone(); let auth_cookie_name = auth_cookie_name_clone.clone(); thread::spawn(move || { Self::process_stream( @@ -655,6 +694,7 @@ impl CliProcessManager { &status, &ready, &token, + &session_cookie, auth_cookie_name.as_str(), ); }); @@ -773,10 +813,12 @@ impl CliProcessManager { status: &Arc>, ready: &Arc, bootstrap_token: &Arc>>, + session_cookie: &Arc>>, auth_cookie_name: &str, ) { let mut buffer = String::new(); - let local_url_regex = Regex::new(r"^Local\s+Connection\s+URL\s*:\s*(https?://\S+)\s*$").ok(); + let local_url_regex = + Regex::new(r"^Local\s+Connection\s+URL\s*:\s*(https?://\S+)\s*$").ok(); let token_prefix = "CODENOMAD_BOOTSTRAP_TOKEN:"; loop { @@ -813,12 +855,12 @@ impl CliProcessManager { status, ready, bootstrap_token, + session_cookie, auth_cookie_name, url, ); continue; } - } } Err(_) => break, @@ -831,6 +873,7 @@ impl CliProcessManager { status: &Arc>, ready: &Arc, bootstrap_token: &Arc>>, + session_cookie: &Arc>>, auth_cookie_name: &str, base_url: String, ) { @@ -855,14 +898,15 @@ impl CliProcessManager { if scheme.as_deref() != Some("http") { navigate_main(app, &base_url); } else { - match exchange_bootstrap_token(&base_url, &token, &auth_cookie_name) { + match exchange_bootstrap_token(&base_url, &token, auth_cookie_name) { Ok(Some(session_id)) => { if let Err(err) = - set_session_cookie(app, &base_url, &auth_cookie_name, &session_id) + set_session_cookie(app, &base_url, auth_cookie_name, &session_id) { log_line(&format!("failed to set session cookie: {err}")); navigate_main(app, &format!("{base_url}/login")); } else { + *session_cookie.lock() = Some(session_id.clone()); navigate_main(app, &base_url); } } @@ -1022,15 +1066,23 @@ fn resolve_tsx(_app: &AppHandle) -> Option { let cwd = std::env::current_dir().ok(); let workspace = workspace_root(); let mut candidates = vec![ - cwd.as_ref().map(|p| p.join("node_modules/tsx/dist/cli.mjs")), - cwd.as_ref().map(|p| p.join("node_modules/tsx/dist/cli.cjs")), + cwd.as_ref() + .map(|p| p.join("node_modules/tsx/dist/cli.mjs")), + cwd.as_ref() + .map(|p| p.join("node_modules/tsx/dist/cli.cjs")), cwd.as_ref().map(|p| p.join("node_modules/tsx/dist/cli.js")), - cwd.as_ref().map(|p| p.join("../node_modules/tsx/dist/cli.mjs")), - cwd.as_ref().map(|p| p.join("../node_modules/tsx/dist/cli.cjs")), - cwd.as_ref().map(|p| p.join("../node_modules/tsx/dist/cli.js")), - cwd.as_ref().map(|p| p.join("../../node_modules/tsx/dist/cli.mjs")), - cwd.as_ref().map(|p| p.join("../../node_modules/tsx/dist/cli.cjs")), - cwd.as_ref().map(|p| p.join("../../node_modules/tsx/dist/cli.js")), + cwd.as_ref() + .map(|p| p.join("../node_modules/tsx/dist/cli.mjs")), + cwd.as_ref() + .map(|p| p.join("../node_modules/tsx/dist/cli.cjs")), + cwd.as_ref() + .map(|p| p.join("../node_modules/tsx/dist/cli.js")), + cwd.as_ref() + .map(|p| p.join("../../node_modules/tsx/dist/cli.mjs")), + cwd.as_ref() + .map(|p| p.join("../../node_modules/tsx/dist/cli.cjs")), + cwd.as_ref() + .map(|p| p.join("../../node_modules/tsx/dist/cli.js")), workspace .as_ref() .map(|p| p.join("node_modules/tsx/dist/cli.mjs")), diff --git a/packages/tauri-app/src-tauri/src/desktop_event_transport.rs b/packages/tauri-app/src-tauri/src/desktop_event_transport.rs new file mode 100644 index 00000000..49cc7552 --- /dev/null +++ b/packages/tauri-app/src-tauri/src/desktop_event_transport.rs @@ -0,0 +1,724 @@ +use parking_lot::Mutex; +use reqwest::blocking::{Client, Response}; +use reqwest::StatusCode; +use serde::{Deserialize, Serialize}; +use serde_json::Value; +use std::collections::HashMap; +use std::io::{BufRead, BufReader}; +use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; +use std::sync::mpsc::{self, RecvTimeoutError, SyncSender}; +use std::sync::Arc; +use std::thread; +use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH}; +use tauri::{AppHandle, Emitter, Manager, Url}; + +mod assembler; +mod stream; +mod transport; + +use stream::*; +use transport::*; + +const EVENT_BATCH_NAME: &str = "desktop:event-batch"; +const EVENT_STATUS_NAME: &str = "desktop:event-stream-status"; +const FLUSH_INTERVAL_MS: u64 = 16; +const DELTA_STREAM_WINDOW_MS: u64 = 48; +const ACTIVE_STREAM_DISPLAY_WINDOW_MS: u64 = 16; +const ACTIVE_STREAM_DISPLAY_CHUNK_MAX: usize = 96; +const ACTIVE_STREAM_STORE_WINDOW_MS: u64 = 250; +const ACTIVE_STREAM_SNAPSHOT_WINDOW_MS: u64 = 200; +const ACTIVE_STREAM_HOLD_WINDOW_MS: u64 = 12; +const ACTIVE_SESSION_MAX_BATCH_EVENTS: usize = 64; +const MAX_BATCH_EVENTS: usize = 256; +const DEFAULT_RECONNECT_INITIAL_DELAY_MS: u64 = 1_000; +const DEFAULT_RECONNECT_MAX_DELAY_MS: u64 = 10_000; +const DEFAULT_RECONNECT_MULTIPLIER: f64 = 2.0; +const STREAM_CONNECT_TIMEOUT_MS: u64 = 5_000; +const STREAM_TCP_KEEPALIVE_MS: u64 = 30_000; +const STREAM_STALL_TIMEOUT_MS: u64 = 30_000; + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct DesktopEventStreamConfig { + pub base_url: String, + pub events_url: String, + pub client_id: String, + pub cookie_name: String, + pub session_cookie: Option, +} + +#[derive(Clone, Debug, Default, Deserialize)] +#[serde(default, rename_all = "camelCase")] +pub struct DesktopEventsStartRequest { + pub reconnect: Option, +} + +#[derive(Clone, Debug, Default, Deserialize)] +#[serde(default, rename_all = "camelCase")] +pub struct DesktopEventReconnectPolicy { + pub initial_delay_ms: Option, + pub max_delay_ms: Option, + pub multiplier: Option, + pub max_attempts: Option, +} + +#[derive(Clone, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct DesktopEventsStartResult { + pub started: bool, + pub generation: Option, + pub reason: Option, +} + +#[derive(Clone, Debug, PartialEq)] +struct ResolvedDesktopEventReconnectPolicy { + initial_delay_ms: u64, + max_delay_ms: u64, + multiplier: f64, + max_attempts: Option, +} + +impl ResolvedDesktopEventReconnectPolicy { + fn resolve(policy: Option<&DesktopEventReconnectPolicy>) -> Self { + let initial_delay_ms = policy + .and_then(|value| value.initial_delay_ms) + .unwrap_or(DEFAULT_RECONNECT_INITIAL_DELAY_MS) + .max(1); + let max_delay_ms = policy + .and_then(|value| value.max_delay_ms) + .unwrap_or(DEFAULT_RECONNECT_MAX_DELAY_MS) + .max(initial_delay_ms); + let multiplier = policy + .and_then(|value| value.multiplier) + .filter(|value| value.is_finite() && *value >= 1.0) + .unwrap_or(DEFAULT_RECONNECT_MULTIPLIER); + let max_attempts = policy + .and_then(|value| value.max_attempts) + .filter(|value| *value > 0); + + Self { + initial_delay_ms, + max_delay_ms, + multiplier, + max_attempts, + } + } +} + +#[derive(Clone, Debug, PartialEq)] +struct DesktopEventTransportConfig { + stream: DesktopEventStreamConfig, + reconnect: ResolvedDesktopEventReconnectPolicy, +} + +impl DesktopEventTransportConfig { + fn new(stream: DesktopEventStreamConfig, request: &DesktopEventsStartRequest) -> Self { + Self { + stream, + reconnect: ResolvedDesktopEventReconnectPolicy::resolve(request.reconnect.as_ref()), + } + } +} + +#[derive(Clone, Serialize)] +#[serde(rename_all = "camelCase")] +struct WorkspaceEventBatchPayload { + generation: u64, + sequence: u64, + emitted_at: u128, + events: Vec, +} + +#[derive(Clone, Serialize)] +#[serde(rename_all = "camelCase")] +struct DesktopEventStreamStatusPayload { + generation: u64, + state: &'static str, + reconnect_attempt: u32, + terminal: bool, + reason: Option, + next_delay_ms: Option, + status_code: Option, + stats: DesktopEventTransportStats, +} + +#[derive(Clone, Default, Serialize)] +#[serde(rename_all = "camelCase")] +struct DesktopEventTransportStats { + raw_events: u64, + emitted_events: u64, + emitted_batches: u64, + delta_coalesces: u64, + snapshot_coalesces: u64, + status_coalesces: u64, + superseded_deltas_dropped: u64, +} + +struct DesktopEventTransportState { + stop: Option>, + config: Option, + active_target: Option, +} + +#[derive(Clone, PartialEq, Eq)] +pub struct ActiveSessionTarget { + pub instance_id: String, + pub session_id: String, +} + +pub struct DesktopEventTransportManager { + state: Arc>, + generation: Arc, +} + +enum ReaderMessage { + Activity, + Event(Value), + End(Option), +} + +enum PendingEntry { + Delta { + key: String, + scope: String, + instance_id: String, + session_id: Option, + event: Value, + started_at: Instant, + }, + Status { + key: String, + event: Value, + }, + Snapshot { + key: String, + event: Value, + }, + Event(Value), +} + +enum EventDeliveryPolicy { + CoalesceDelta(String), + CoalesceStatus(String), + CoalesceSnapshot(String), + Passthrough, +} + +enum OpenStreamErrorKind { + Unauthorized, + Http, + Transport, +} + +struct OpenStreamError { + kind: OpenStreamErrorKind, + message: String, + status_code: Option, +} + +#[derive(Default)] +struct PendingBatch { + events: Vec, +} + +#[derive(Clone)] +struct ActiveTextDelta { + instance_id: String, + session_id: String, + message_id: String, + part_id: String, + delta: String, +} + +struct ActiveTextPartBuffer { + instance_id: String, + session_id: String, + message_id: String, + part_id: String, + display_pending: String, + store_pending: String, + last_display_emit: Instant, + last_store_emit: Instant, +} + +impl ActiveTextPartBuffer { + fn new(delta: ActiveTextDelta, now: Instant) -> Self { + Self { + instance_id: delta.instance_id, + session_id: delta.session_id, + message_id: delta.message_id, + part_id: delta.part_id, + display_pending: delta.delta.clone(), + store_pending: delta.delta, + last_display_emit: now, + last_store_emit: now, + } + } +} + +#[derive(Clone)] +struct ActiveTextSnapshot { + key: String, + instance_id: String, + session_id: String, + message_id: String, + part_id: String, + event: Value, +} + +struct BufferedTextSnapshot { + instance_id: String, + session_id: String, + message_id: String, + part_id: String, + event: Value, + buffered_at: Instant, +} + +#[derive(Default)] +struct ActiveTextAssembler { + parts: HashMap, +} + +#[derive(Default)] +struct ActiveTextSnapshotBuffer { + parts: HashMap, +} + +impl DesktopEventTransportManager { + pub fn new() -> Self { + Self { + state: Arc::new(Mutex::new(DesktopEventTransportState { + stop: None, + config: None, + active_target: None, + })), + generation: Arc::new(AtomicU64::new(0)), + } + } + + pub fn set_active_session_target(&self, target: Option) { + let mut state = self.state.lock(); + state.active_target = target; + } + + pub fn start( + &self, + app: AppHandle, + stream_config: Option, + request: Option, + ) -> DesktopEventsStartResult { + let Some(stream_config) = stream_config else { + return DesktopEventsStartResult { + started: false, + generation: None, + reason: Some("desktop event stream unavailable".to_string()), + }; + }; + + let request = request.unwrap_or_default(); + let transport_config = DesktopEventTransportConfig::new(stream_config, &request); + + let mut state = self.state.lock(); + if state.config.as_ref() == Some(&transport_config) { + if let Some(stop) = &state.stop { + if !stop.load(Ordering::SeqCst) { + return DesktopEventsStartResult { + started: true, + generation: Some(self.generation.load(Ordering::SeqCst)), + reason: None, + }; + } + } + } + + if let Some(stop) = state.stop.take() { + stop.store(true, Ordering::SeqCst); + } + + let generation = self.generation.fetch_add(1, Ordering::SeqCst) + 1; + let stop = Arc::new(AtomicBool::new(false)); + state.stop = Some(stop.clone()); + state.config = Some(transport_config.clone()); + let shared_state = self.state.clone(); + let shared_generation = self.generation.clone(); + drop(state); + + thread::spawn(move || { + run_transport_loop( + app, + shared_state, + shared_generation, + generation, + stop, + transport_config, + ) + }); + + DesktopEventsStartResult { + started: true, + generation: Some(generation), + reason: None, + } + } + + pub fn stop(&self) { + let mut state = self.state.lock(); + if let Some(stop) = state.stop.take() { + stop.store(true, Ordering::SeqCst); + } + state.config = None; + state.active_target = None; + self.generation.fetch_add(1, Ordering::SeqCst); + } +} + +fn classify_event(event: &Value) -> EventDeliveryPolicy { + if let Some(key) = delta_key(event) { + return EventDeliveryPolicy::CoalesceDelta(key); + } + + if let Some(key) = status_key(event) { + return EventDeliveryPolicy::CoalesceStatus(key); + } + + if let Some(key) = snapshot_key(event) { + return EventDeliveryPolicy::CoalesceSnapshot(key); + } + + EventDeliveryPolicy::Passthrough +} + +fn coalesced_payload_event<'a>(event: &'a Value) -> &'a Value { + if event.get("type").and_then(Value::as_str) == Some("instance.event") { + event.get("event").unwrap_or(event) + } else { + event + } +} + +fn coalesced_instance_id(event: &Value) -> &str { + event + .get("instanceId") + .and_then(Value::as_str) + .unwrap_or_default() +} + +fn event_session_id(event: &Value) -> Option<&str> { + let inner = coalesced_payload_event(event); + let inner_type = inner.get("type")?.as_str()?; + let props = inner.get("properties")?; + + match inner_type { + "session.updated" => props + .get("info") + .and_then(|info| info.get("id")) + .and_then(Value::as_str) + .or_else(|| { + props + .get("sessionID") + .or_else(|| props.get("sessionId")) + .and_then(Value::as_str) + }), + "message.updated" => props + .get("info") + .and_then(|info| info.get("sessionID").or_else(|| info.get("sessionId"))) + .and_then(Value::as_str), + "message.part.updated" => props + .get("part") + .and_then(|part| part.get("sessionID").or_else(|| part.get("sessionId"))) + .and_then(Value::as_str), + "message.part.delta" + | "message.removed" + | "message.part.removed" + | "session.compacted" + | "session.diff" + | "session.idle" + | "session.status" => props + .get("sessionID") + .or_else(|| props.get("sessionId")) + .and_then(Value::as_str), + _ => None, + } +} + +fn parse_active_text_delta( + event: &Value, + active_target: Option<&ActiveSessionTarget>, +) -> Option { + let active_target = active_target?; + let instance_id = coalesced_instance_id(event); + if instance_id != active_target.instance_id { + return None; + } + let inner = coalesced_payload_event(event); + if inner.get("type")?.as_str()? != "message.part.delta" { + return None; + } + + let props = inner.get("properties")?; + let field = props.get("field")?.as_str()?; + if field != "text" { + return None; + } + + let event_session = props + .get("sessionID") + .or_else(|| props.get("sessionId")) + .and_then(Value::as_str)?; + if event_session != active_target.session_id { + return None; + } + + Some(ActiveTextDelta { + instance_id: instance_id.to_string(), + session_id: event_session.to_string(), + message_id: props + .get("messageID") + .or_else(|| props.get("messageId")) + .and_then(Value::as_str)? + .to_string(), + part_id: props + .get("partID") + .or_else(|| props.get("partId")) + .and_then(Value::as_str)? + .to_string(), + delta: props.get("delta")?.as_str()?.to_string(), + }) +} + +fn make_assistant_stream_chunk_event(entry: &ActiveTextPartBuffer, delta: &str) -> Value { + serde_json::json!({ + "type": "instance.event", + "instanceId": entry.instance_id, + "event": { + "type": "assistant.stream.chunk", + "properties": { + "sessionID": entry.session_id, + "messageID": entry.message_id, + "partID": entry.part_id, + "field": "text", + "delta": delta, + } + } + }) +} + +fn make_message_part_delta_event(entry: &ActiveTextPartBuffer, delta: &str) -> Value { + serde_json::json!({ + "type": "instance.event", + "instanceId": entry.instance_id, + "event": { + "type": "message.part.delta", + "properties": { + "sessionID": entry.session_id, + "messageID": entry.message_id, + "partID": entry.part_id, + "field": "text", + "delta": delta, + } + } + }) +} + +fn parse_active_text_snapshot( + event: &Value, + active_target: Option<&ActiveSessionTarget>, +) -> Option { + let active_target = active_target?; + let instance_id = coalesced_instance_id(event); + if instance_id != active_target.instance_id { + return None; + } + + let inner = coalesced_payload_event(event); + if inner.get("type")?.as_str()? != "message.part.updated" { + return None; + } + + let part = inner.get("properties")?.get("part")?; + if part.get("type")?.as_str()? != "text" { + return None; + } + if part.get("text")?.as_str().is_none() { + return None; + } + + let event_session = part + .get("sessionID") + .or_else(|| part.get("sessionId")) + .and_then(Value::as_str)?; + if event_session != active_target.session_id { + return None; + } + + let message_id = part + .get("messageID") + .or_else(|| part.get("messageId")) + .and_then(Value::as_str)?; + let part_id = part.get("id")?.as_str()?; + + Some(ActiveTextSnapshot { + key: format!( + "{}:{}:{}:{}", + instance_id, event_session, message_id, part_id + ), + instance_id: instance_id.to_string(), + session_id: event_session.to_string(), + message_id: message_id.to_string(), + part_id: part_id.to_string(), + event: event.clone(), + }) +} + +fn snapshot_key(event: &Value) -> Option { + let instance_id = coalesced_instance_id(event); + let inner = coalesced_payload_event(event); + let inner_type = inner.get("type")?.as_str()?; + let props = inner.get("properties")?; + + match inner_type { + "message.part.updated" => { + let session_id = props + .get("part") + .and_then(|part| part.get("sessionID").or_else(|| part.get("sessionId"))) + .and_then(Value::as_str)?; + let message_id = props + .get("part") + .and_then(|part| part.get("messageID").or_else(|| part.get("messageId"))) + .and_then(Value::as_str)?; + let part_id = props + .get("part") + .and_then(|part| part.get("id")) + .and_then(Value::as_str)?; + + Some(format!( + "message.part.updated:{}:{}:{}:{}", + instance_id, session_id, message_id, part_id + )) + } + "message.updated" => { + let info = props.get("info")?; + let session_id = info + .get("sessionID") + .or_else(|| info.get("sessionId")) + .and_then(Value::as_str)?; + let message_id = info.get("id").and_then(Value::as_str)?; + + Some(format!( + "message.updated:{}:{}:{}", + instance_id, session_id, message_id + )) + } + "session.updated" | "session.status" => { + let session_id = props + .get("info") + .and_then(|info| info.get("id")) + .and_then(Value::as_str) + .or_else(|| { + props + .get("sessionID") + .or_else(|| props.get("sessionId")) + .and_then(Value::as_str) + })?; + + Some(format!("{}:{}:{}", inner_type, instance_id, session_id)) + } + _ => None, + } +} + +fn delta_scope(event: &Value) -> Option { + let instance_id = coalesced_instance_id(event); + let inner = coalesced_payload_event(event); + if inner.get("type")?.as_str()? != "message.part.delta" { + return None; + } + + let props = inner.get("properties")?; + let session_id = props + .get("sessionID") + .or_else(|| props.get("sessionId")) + .and_then(Value::as_str) + .unwrap_or_default(); + let message_id = props + .get("messageID") + .or_else(|| props.get("messageId")) + .and_then(Value::as_str)?; + let part_id = props + .get("partID") + .or_else(|| props.get("partId")) + .and_then(Value::as_str)?; + + Some(format!( + "message.part:{}:{}:{}:{}", + instance_id, session_id, message_id, part_id + )) +} + +fn delta_key(event: &Value) -> Option { + let scope = delta_scope(event)?; + let props = coalesced_payload_event(event).get("properties")?; + let field = props.get("field")?.as_str()?; + + Some(format!("{}:{}", scope, field)) +} + +fn snapshot_superseded_delta_scope(event: &Value) -> Option { + let instance_id = coalesced_instance_id(event); + let inner = coalesced_payload_event(event); + if inner.get("type")?.as_str()? != "message.part.updated" { + return None; + } + + let part = inner.get("properties")?.get("part")?; + let session_id = part + .get("sessionID") + .or_else(|| part.get("sessionId")) + .and_then(Value::as_str)?; + let message_id = part + .get("messageID") + .or_else(|| part.get("messageId")) + .and_then(Value::as_str)?; + let part_id = part.get("id")?.as_str()?; + + Some(format!( + "message.part:{}:{}:{}:{}", + instance_id, session_id, message_id, part_id + )) +} + +fn append_delta(target: &mut Value, event: &Value) { + let next_delta = coalesced_payload_event(event) + .get("properties") + .and_then(|value| value.get("delta")) + .and_then(Value::as_str) + .unwrap_or_default(); + + if let Some(existing_delta) = coalesced_payload_event_mut(target) + .and_then(|event| event.get_mut("properties")) + .and_then(Value::as_object_mut) + .and_then(|props| props.get_mut("delta")) + { + let combined = existing_delta.as_str().unwrap_or_default().to_string() + next_delta; + *existing_delta = Value::String(combined); + } +} + +fn coalesced_payload_event_mut(event: &mut Value) -> Option<&mut serde_json::Map> { + if event.get("type").and_then(Value::as_str) == Some("instance.event") { + event.get_mut("event").and_then(Value::as_object_mut) + } else { + event.as_object_mut() + } +} + +fn status_key(event: &Value) -> Option { + match event.get("type")?.as_str()? { + "instance.eventStatus" => Some(coalesced_instance_id(event).to_string()), + "session.status" => snapshot_key(event), + _ => None, + } +} + +#[cfg(test)] +mod tests; diff --git a/packages/tauri-app/src-tauri/src/desktop_event_transport/assembler.rs b/packages/tauri-app/src-tauri/src/desktop_event_transport/assembler.rs new file mode 100644 index 00000000..82299452 --- /dev/null +++ b/packages/tauri-app/src-tauri/src/desktop_event_transport/assembler.rs @@ -0,0 +1,501 @@ +use super::*; + +impl PendingBatch { + pub(super) fn push(&mut self, event: Value, stats: &mut DesktopEventTransportStats) { + match classify_event(&event) { + EventDeliveryPolicy::CoalesceDelta(key) => { + let Some(scope) = delta_scope(&event) else { + self.events.push(PendingEntry::Event(event)); + return; + }; + + if let Some(PendingEntry::Delta { + key: existing_key, + event: existing_event, + .. + }) = self.events.last_mut() + { + if existing_key == &key { + append_delta(existing_event, &event); + stats.delta_coalesces = stats.delta_coalesces.saturating_add(1); + return; + } + } + + self.events.push(PendingEntry::Delta { + key, + scope, + instance_id: coalesced_instance_id(&event).to_string(), + session_id: event_session_id(&event).map(|value| value.to_string()), + event, + started_at: Instant::now(), + }); + } + EventDeliveryPolicy::CoalesceStatus(key) => { + if let Some(PendingEntry::Status { + key: existing_key, + event: existing_event, + }) = self.events.last_mut() + { + if existing_key == &key { + *existing_event = event; + stats.status_coalesces = stats.status_coalesces.saturating_add(1); + return; + } + } + + self.events.push(PendingEntry::Status { key, event }); + } + EventDeliveryPolicy::CoalesceSnapshot(key) => { + if let Some(part_scope) = snapshot_superseded_delta_scope(&event) { + let mut dropped = 0_u64; + while matches!( + self.events.last(), + Some(PendingEntry::Delta { scope, .. }) if scope == &part_scope + ) { + self.events.pop(); + dropped = dropped.saturating_add(1); + } + if dropped > 0 { + stats.superseded_deltas_dropped = + stats.superseded_deltas_dropped.saturating_add(dropped); + } + } + + if let Some(PendingEntry::Snapshot { + key: existing_key, + event: existing_event, + }) = self.events.last_mut() + { + if existing_key == &key { + *existing_event = event; + stats.snapshot_coalesces = stats.snapshot_coalesces.saturating_add(1); + return; + } + } + + self.events.push(PendingEntry::Snapshot { key, event }); + } + EventDeliveryPolicy::Passthrough => { + self.events.push(PendingEntry::Event(event)); + } + } + } + + pub(super) fn take_events(&mut self) -> Vec { + let pending = std::mem::take(&mut self.events); + pending + .into_iter() + .map(|entry| match entry { + PendingEntry::Delta { event, .. } => event, + PendingEntry::Status { event, .. } => event, + PendingEntry::Snapshot { event, .. } => event, + PendingEntry::Event(event) => event, + }) + .collect() + } + + pub(super) fn is_empty(&self) -> bool { + self.events.is_empty() + } + + pub(super) fn pending_len(&self) -> usize { + self.events.len() + } + + pub(super) fn should_hold_single_delta( + &self, + now: Instant, + active_target: Option<&ActiveSessionTarget>, + ) -> bool { + matches!( + self.events.as_slice(), + [PendingEntry::Delta { started_at, instance_id, session_id, .. }] + if now.duration_since(*started_at) < Duration::from_millis( + if active_target + .map(|target| { + target.instance_id.as_str() == instance_id.as_str() + && target.session_id.as_str() == session_id.as_deref().unwrap_or_default() + }) + .unwrap_or(false) + { + ACTIVE_STREAM_HOLD_WINDOW_MS + } else { + DELTA_STREAM_WINDOW_MS + } + ) + ) + } +} + +impl ActiveTextAssembler { + pub(super) fn absorb(&mut self, delta: ActiveTextDelta, now: Instant) -> Vec { + let key = format!( + "{}:{}:{}:{}", + delta.instance_id, delta.session_id, delta.message_id, delta.part_id + ); + + match self.parts.entry(key) { + std::collections::hash_map::Entry::Occupied(mut occupied) => { + let entry = occupied.get_mut(); + if entry.display_pending.is_empty() && entry.store_pending.is_empty() { + entry.instance_id = delta.instance_id.clone(); + entry.session_id = delta.session_id.clone(); + entry.message_id = delta.message_id.clone(); + entry.part_id = delta.part_id.clone(); + } + + entry.display_pending.push_str(&delta.delta); + entry.store_pending.push_str(&delta.delta); + Self::collect_due_for_part(entry, now) + } + std::collections::hash_map::Entry::Vacant(vacant) => { + let mut entry = ActiveTextPartBuffer::new(delta, now); + entry.last_display_emit = now + .checked_sub(Duration::from_millis(ACTIVE_STREAM_DISPLAY_WINDOW_MS)) + .unwrap_or(now); + let emitted = Self::collect_due_for_part(&mut entry, now); + vacant.insert(entry); + emitted + } + } + } + + pub(super) fn take_due(&mut self, now: Instant) -> Vec { + let mut emitted = Vec::new(); + let mut empty_keys = Vec::new(); + + for (key, entry) in self.parts.iter_mut() { + emitted.extend(Self::collect_due_for_part(entry, now)); + if entry.display_pending.is_empty() && entry.store_pending.is_empty() { + empty_keys.push(key.clone()); + } + } + + for key in empty_keys { + self.parts.remove(&key); + } + + emitted + } + + pub(super) fn flush_for_event(&mut self, event: &Value, now: Instant) -> Vec { + let instance_id = coalesced_instance_id(event); + let payload = coalesced_payload_event(event); + let event_type = payload.get("type").and_then(Value::as_str); + + match event_type { + Some("message.updated") | Some("message.removed") => { + let props = payload.get("properties"); + let session_id = event_session_id(event); + let message_id = props + .and_then(|value| { + value + .get("info") + .and_then(|info| info.get("id")) + .or_else(|| value.get("messageID")) + .or_else(|| value.get("messageId")) + }) + .and_then(Value::as_str); + if let (Some(session_id), Some(message_id)) = (session_id, message_id) { + return self.flush_message(instance_id, session_id, message_id, now); + } + } + Some("message.part.updated") | Some("message.part.removed") => { + let props = payload.get("properties"); + let session_id = event_session_id(event); + let message_id = props + .and_then(|value| { + value + .get("part") + .and_then(|part| { + part.get("messageID").or_else(|| part.get("messageId")) + }) + .or_else(|| value.get("messageID")) + .or_else(|| value.get("messageId")) + }) + .and_then(Value::as_str); + let part_id = props + .and_then(|value| { + value + .get("part") + .and_then(|part| part.get("id")) + .or_else(|| value.get("partID")) + .or_else(|| value.get("partId")) + }) + .and_then(Value::as_str); + if let (Some(session_id), Some(message_id), Some(part_id)) = + (session_id, message_id, part_id) + { + return self.flush_part(instance_id, session_id, message_id, part_id, now); + } + } + _ => {} + } + + Vec::new() + } + + pub(super) fn flush_message( + &mut self, + instance_id: &str, + session_id: &str, + message_id: &str, + now: Instant, + ) -> Vec { + let keys: Vec = self + .parts + .iter() + .filter(|(_, entry)| { + entry.instance_id == instance_id + && entry.session_id == session_id + && entry.message_id == message_id + }) + .map(|(key, _)| key.clone()) + .collect(); + + let mut emitted = Vec::new(); + for key in keys { + if let Some(mut entry) = self.parts.remove(&key) { + emitted.extend(Self::flush_all_for_part(&mut entry, now)); + } + } + emitted + } + + pub(super) fn flush_part( + &mut self, + instance_id: &str, + session_id: &str, + message_id: &str, + part_id: &str, + now: Instant, + ) -> Vec { + let key = format!("{}:{}:{}:{}", instance_id, session_id, message_id, part_id); + if let Some(mut entry) = self.parts.remove(&key) { + return Self::flush_all_for_part(&mut entry, now); + } + Vec::new() + } + + pub(super) fn flush_store_only_all(&mut self, now: Instant) -> Vec { + let mut emitted = Vec::new(); + for entry in self.parts.values_mut() { + if !entry.store_pending.is_empty() { + emitted.push(make_message_part_delta_event(entry, &entry.store_pending)); + entry.store_pending.clear(); + entry.last_store_emit = now; + } + entry.display_pending.clear(); + entry.last_display_emit = now; + } + self.parts.clear(); + emitted + } + + fn collect_due_for_part(entry: &mut ActiveTextPartBuffer, now: Instant) -> Vec { + let mut emitted = Vec::new(); + + // Display lane — emit preview chunks frequently (~16ms / 96 chars). + if !entry.display_pending.is_empty() + && (now.duration_since(entry.last_display_emit) + >= Duration::from_millis(ACTIVE_STREAM_DISPLAY_WINDOW_MS) + || entry.display_pending.len() >= ACTIVE_STREAM_DISPLAY_CHUNK_MAX) + { + emitted.push(make_assistant_stream_chunk_event( + entry, + &entry.display_pending, + )); + entry.display_pending.clear(); + entry.last_display_emit = now; + } + + // Store lane — emit canonical deltas infrequently (~250ms) to avoid + // flooding the JS reactive graph with store mutations that + // trigger expensive re-render cascades during active streaming. + // Explicit flush triggers (message.updated, message.part.updated, + // session change, disconnect) still flush immediately via + // flush_for_event / flush_all_for_part / flush_store_only_all. + if !entry.store_pending.is_empty() + && now.duration_since(entry.last_store_emit) + >= Duration::from_millis(ACTIVE_STREAM_STORE_WINDOW_MS) + { + emitted.push(make_message_part_delta_event(entry, &entry.store_pending)); + entry.store_pending.clear(); + entry.last_store_emit = now; + } + + emitted + } + + fn flush_all_for_part(entry: &mut ActiveTextPartBuffer, now: Instant) -> Vec { + let mut emitted = Vec::new(); + if !entry.display_pending.is_empty() { + emitted.push(make_assistant_stream_chunk_event( + entry, + &entry.display_pending, + )); + entry.display_pending.clear(); + entry.last_display_emit = now; + } + if !entry.store_pending.is_empty() { + emitted.push(make_message_part_delta_event(entry, &entry.store_pending)); + entry.store_pending.clear(); + entry.last_store_emit = now; + } + emitted + } +} + +impl ActiveTextSnapshotBuffer { + pub(super) fn buffer(&mut self, snapshot: ActiveTextSnapshot, now: Instant) { + match self.parts.entry(snapshot.key) { + std::collections::hash_map::Entry::Occupied(mut occupied) => { + let entry = occupied.get_mut(); + entry.instance_id = snapshot.instance_id; + entry.session_id = snapshot.session_id; + entry.message_id = snapshot.message_id; + entry.part_id = snapshot.part_id; + entry.event = snapshot.event; + } + std::collections::hash_map::Entry::Vacant(vacant) => { + vacant.insert(BufferedTextSnapshot { + instance_id: snapshot.instance_id, + session_id: snapshot.session_id, + message_id: snapshot.message_id, + part_id: snapshot.part_id, + event: snapshot.event, + buffered_at: now, + }); + } + } + } + + pub(super) fn take_due(&mut self, now: Instant) -> Vec { + let keys: Vec = self + .parts + .iter() + .filter(|(_, entry)| { + now.duration_since(entry.buffered_at) + >= Duration::from_millis(ACTIVE_STREAM_SNAPSHOT_WINDOW_MS) + }) + .map(|(key, _)| key.clone()) + .collect(); + + self.take_entries(keys) + } + + pub(super) fn flush_for_event(&mut self, event: &Value) -> Vec { + let instance_id = coalesced_instance_id(event); + let payload = coalesced_payload_event(event); + let event_type = payload.get("type").and_then(Value::as_str); + + match event_type { + Some("message.updated") | Some("message.removed") => { + let props = payload.get("properties"); + let session_id = event_session_id(event); + let message_id = props + .and_then(|value| { + value + .get("info") + .and_then(|info| info.get("id")) + .or_else(|| value.get("messageID")) + .or_else(|| value.get("messageId")) + }) + .and_then(Value::as_str); + if let (Some(session_id), Some(message_id)) = (session_id, message_id) { + return self.flush_message(instance_id, session_id, message_id); + } + } + Some("message.part.removed") => { + let props = payload.get("properties"); + let session_id = event_session_id(event); + let message_id = props + .and_then(|value| { + value + .get("part") + .and_then(|part| { + part.get("messageID").or_else(|| part.get("messageId")) + }) + .or_else(|| value.get("messageID")) + .or_else(|| value.get("messageId")) + }) + .and_then(Value::as_str); + let part_id = props + .and_then(|value| { + value + .get("part") + .and_then(|part| part.get("id")) + .or_else(|| value.get("partID")) + .or_else(|| value.get("partId")) + }) + .and_then(Value::as_str); + if let (Some(session_id), Some(message_id), Some(part_id)) = + (session_id, message_id, part_id) + { + return self.flush_part(instance_id, session_id, message_id, part_id); + } + } + _ => {} + } + + Vec::new() + } + + pub(super) fn flush_message( + &mut self, + instance_id: &str, + session_id: &str, + message_id: &str, + ) -> Vec { + let keys: Vec = self + .parts + .iter() + .filter(|(_, entry)| { + entry.instance_id == instance_id + && entry.session_id == session_id + && entry.message_id == message_id + }) + .map(|(key, _)| key.clone()) + .collect(); + + self.take_entries(keys) + } + + pub(super) fn flush_part( + &mut self, + instance_id: &str, + session_id: &str, + message_id: &str, + part_id: &str, + ) -> Vec { + let keys: Vec = self + .parts + .iter() + .filter(|(_, entry)| { + entry.instance_id == instance_id + && entry.session_id == session_id + && entry.message_id == message_id + && entry.part_id == part_id + }) + .map(|(key, _)| key.clone()) + .collect(); + + self.take_entries(keys) + } + + pub(super) fn flush_all(&mut self) -> Vec { + let keys: Vec = self.parts.keys().cloned().collect(); + self.take_entries(keys) + } + + fn take_entries(&mut self, keys: Vec) -> Vec { + let mut emitted = Vec::new(); + for key in keys { + if let Some(entry) = self.parts.remove(&key) { + emitted.push(entry.event); + } + } + emitted + } +} diff --git a/packages/tauri-app/src-tauri/src/desktop_event_transport/stream.rs b/packages/tauri-app/src-tauri/src/desktop_event_transport/stream.rs new file mode 100644 index 00000000..e861a1b3 --- /dev/null +++ b/packages/tauri-app/src-tauri/src/desktop_event_transport/stream.rs @@ -0,0 +1,185 @@ +use super::*; + +pub(super) fn build_stream_client() -> Result { + Client::builder() + .connect_timeout(Duration::from_millis(STREAM_CONNECT_TIMEOUT_MS)) + .tcp_keepalive(Duration::from_millis(STREAM_TCP_KEEPALIVE_MS)) + // Note: reqwest's blocking client doesn't expose a per-read timeout. + // The global `.timeout()` would kill the entire SSE stream, so we + // rely on: + // 1. tcp_keepalive to detect dead connections (OS will RST after + // several unacked probes, typically ~2 min). + // 2. Consumer-side stall detection (STREAM_STALL_TIMEOUT_MS). + // 3. Reader thread breaking on channel send error (consumer dropped). + .build() + .map_err(|error: reqwest::Error| OpenStreamError { + kind: OpenStreamErrorKind::Transport, + message: error.to_string(), + status_code: None, + }) +} + +pub(super) fn open_stream( + app: &AppHandle, + client: &Client, + config: &DesktopEventStreamConfig, +) -> Result { + let connection_id = generate_connection_id(); + let url = format!( + "{}?clientId={}&connectionId={}", + config.events_url, config.client_id, connection_id + ); + + let mut request = client.get(&url).header("Accept", "text/event-stream"); + + if let Some(session_cookie) = resolve_session_cookie(app, config) { + request = request.header( + "Cookie", + format!("{}={}", config.cookie_name, session_cookie), + ); + } + + let response = request.send().map_err(|error| OpenStreamError { + kind: OpenStreamErrorKind::Transport, + message: error.to_string(), + status_code: None, + })?; + + if response.status().is_success() { + return Ok(response); + } + + let status = response.status(); + let kind = if matches!(status, StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN) { + OpenStreamErrorKind::Unauthorized + } else { + OpenStreamErrorKind::Http + }; + + Err(OpenStreamError { + kind, + message: format!("desktop event stream unavailable ({status})"), + status_code: Some(status.as_u16()), + }) +} + +fn resolve_session_cookie(app: &AppHandle, config: &DesktopEventStreamConfig) -> Option { + read_session_cookie_from_webview(app, &config.base_url, &config.cookie_name) + .or_else(|| config.session_cookie.clone()) + .filter(|value| !value.is_empty()) +} + +fn generate_connection_id() -> String { + let ts = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or_default() + .as_millis(); + let tid = std::thread::current().id(); + format!("tauri-{}-{:?}", ts, tid) +} + +fn read_session_cookie_from_webview( + app: &AppHandle, + base_url: &str, + cookie_name: &str, +) -> Option { + let url = Url::parse(base_url).ok()?; + let host = url.host_str()?.to_ascii_lowercase(); + let path = url.path(); + let windows = app.webview_windows(); + let window = windows.get("main")?; + let cookies = window.cookies().ok()?; + cookies + .into_iter() + .filter(|cookie: &tauri::webview::cookie::Cookie<'static>| cookie.name() == cookie_name) + .filter(|cookie: &tauri::webview::cookie::Cookie<'static>| { + let Some(domain) = cookie.domain() else { + return true; + }; + + let normalized_domain = domain.trim_start_matches('.').to_ascii_lowercase(); + host == normalized_domain || host.ends_with(&format!(".{}", normalized_domain)) + }) + .filter(|cookie: &tauri::webview::cookie::Cookie<'static>| { + let Some(cookie_path) = cookie.path() else { + return true; + }; + + path.starts_with(cookie_path) + }) + .map(|cookie: tauri::webview::cookie::Cookie<'static>| cookie.value().to_string()) + .next() +} + +pub(super) fn read_sse( + response: Response, + tx: SyncSender, + stop: Arc, + generation_atomic: Arc, + generation: u64, +) { + let mut reader = BufReader::new(response); + let mut line = String::new(); + let mut data_lines: Vec = Vec::new(); + + loop { + if stop.load(Ordering::SeqCst) || !generation_matches(&generation_atomic, generation) { + let _ = tx.send(ReaderMessage::End(Some("stopped".to_string()))); + return; + } + + line.clear(); + match reader.read_line(&mut line) { + Ok(0) => { + if let Some(event) = parse_sse_payload(&data_lines) { + let _ = tx.send(ReaderMessage::Event(event)); + } + let _ = tx.send(ReaderMessage::End(Some("stream closed".to_string()))); + return; + } + Ok(_) => { + if tx.send(ReaderMessage::Activity).is_err() { + return; // consumer dropped — stop reading + } + let trimmed = line.trim_end_matches(['\r', '\n']); + if trimmed.is_empty() { + if let Some(event) = parse_sse_payload(&data_lines) { + if tx.send(ReaderMessage::Event(event)).is_err() { + return; // consumer dropped + } + } + data_lines.clear(); + continue; + } + + if trimmed.starts_with(':') { + continue; + } + + if let Some(data) = trimmed.strip_prefix("data:") { + data_lines.push(data.strip_prefix(' ').unwrap_or(data).to_string()); + } + } + Err(error) => { + if let Some(event) = parse_sse_payload(&data_lines) { + let _ = tx.send(ReaderMessage::Event(event)); + } + let _ = tx.send(ReaderMessage::End(Some(error.to_string()))); + return; + } + } + } +} + +fn parse_sse_payload(lines: &[String]) -> Option { + if lines.is_empty() { + return None; + } + + let payload = lines.join("\n").trim().to_string(); + if payload.is_empty() { + return None; + } + + serde_json::from_str::(&payload).ok() +} diff --git a/packages/tauri-app/src-tauri/src/desktop_event_transport/tests.rs b/packages/tauri-app/src-tauri/src/desktop_event_transport/tests.rs new file mode 100644 index 00000000..e464ef8c --- /dev/null +++ b/packages/tauri-app/src-tauri/src/desktop_event_transport/tests.rs @@ -0,0 +1,575 @@ +use super::*; +use serde_json::json; + +fn fresh_stats() -> DesktopEventTransportStats { + DesktopEventTransportStats::default() +} + +fn delta_event(delta: &str) -> Value { + json!({ + "type": "instance.event", + "instanceId": "inst-1", + "event": { + "type": "message.part.delta", + "properties": { + "sessionID": "sess-1", + "messageID": "msg-1", + "partID": "part-1", + "field": "text", + "delta": delta, + } + } + }) +} + +fn delta_event_for(part_id: &str, delta: &str) -> Value { + json!({ + "type": "instance.event", + "instanceId": "inst-1", + "event": { + "type": "message.part.delta", + "properties": { + "sessionID": "sess-1", + "messageID": "msg-1", + "partID": part_id, + "field": "text", + "delta": delta, + } + } + }) +} + +fn direct_delta_event(delta: &str) -> Value { + json!({ + "type": "message.part.delta", + "properties": { + "sessionID": "sess-1", + "messageID": "msg-1", + "partID": "part-1", + "field": "text", + "delta": delta, + } + }) +} + +fn direct_message_part_updated_event(text: &str) -> Value { + json!({ + "type": "message.part.updated", + "properties": { + "part": { + "id": "part-1", + "type": "text", + "text": text, + "sessionID": "sess-1", + "messageID": "msg-1" + } + } + }) +} + +fn message_part_updated_event(text: &str) -> Value { + json!({ + "type": "instance.event", + "instanceId": "inst-1", + "event": { + "type": "message.part.updated", + "properties": { + "part": { + "id": "part-1", + "type": "text", + "text": text, + "sessionID": "sess-1", + "messageID": "msg-1" + } + } + } + }) +} + +fn active_target() -> ActiveSessionTarget { + ActiveSessionTarget { + instance_id: "inst-1".to_string(), + session_id: "sess-1".to_string(), + } +} + +#[test] +fn coalesces_message_part_delta_events() { + let mut pending = PendingBatch::default(); + let mut stats = fresh_stats(); + pending.push(delta_event("Hello"), &mut stats); + pending.push(delta_event(" world"), &mut stats); + + let events = pending.take_events(); + assert_eq!(events.len(), 1); + assert_eq!( + events[0]["event"]["properties"]["delta"].as_str(), + Some("Hello world") + ); +} + +#[test] +fn last_write_wins_for_status_events() { + let mut pending = PendingBatch::default(); + let mut stats = fresh_stats(); + pending.push( + json!({ + "type": "instance.eventStatus", + "instanceId": "inst-1", + "status": "connecting" + }), + &mut stats, + ); + pending.push( + json!({ + "type": "instance.eventStatus", + "instanceId": "inst-1", + "status": "connected" + }), + &mut stats, + ); + + let events = pending.take_events(); + assert_eq!(events.len(), 1); + assert_eq!(events[0]["status"].as_str(), Some("connected")); +} + +#[test] +fn last_write_wins_for_consecutive_snapshot_events() { + let mut pending = PendingBatch::default(); + let mut stats = fresh_stats(); + pending.push(message_part_updated_event("Hello"), &mut stats); + pending.push(message_part_updated_event("Hello world"), &mut stats); + + let events = pending.take_events(); + assert_eq!(events.len(), 1); + assert_eq!( + events[0]["event"]["properties"]["part"]["text"].as_str(), + Some("Hello world") + ); +} + +#[test] +fn interleaved_snapshot_keys_keep_order() { + let mut pending = PendingBatch::default(); + let mut stats = fresh_stats(); + pending.push(message_part_updated_event("A1"), &mut stats); + pending.push( + json!({ + "type": "instance.event", + "instanceId": "inst-1", + "event": { + "type": "message.part.updated", + "properties": { + "part": { + "id": "part-2", + "type": "text", + "text": "B1", + "sessionID": "sess-1", + "messageID": "msg-1" + } + } + } + }), + &mut stats, + ); + pending.push(message_part_updated_event("A2"), &mut stats); + + let events = pending.take_events(); + assert_eq!(events.len(), 3); + assert_eq!( + events[0]["event"]["properties"]["part"]["id"].as_str(), + Some("part-1") + ); + assert_eq!( + events[1]["event"]["properties"]["part"]["id"].as_str(), + Some("part-2") + ); + assert_eq!( + events[2]["event"]["properties"]["part"]["text"].as_str(), + Some("A2") + ); +} + +#[test] +fn snapshot_replaces_trailing_deltas_for_same_part() { + let mut pending = PendingBatch::default(); + let mut stats = fresh_stats(); + pending.push(delta_event("Hello"), &mut stats); + pending.push(message_part_updated_event("Hello world"), &mut stats); + + let events = pending.take_events(); + assert_eq!(events.len(), 1); + assert_eq!( + events[0]["event"]["type"].as_str(), + Some("message.part.updated") + ); + assert_eq!( + events[0]["event"]["properties"]["part"]["text"].as_str(), + Some("Hello world") + ); +} + +#[test] +fn structural_events_force_coalesced_flush_before_append() { + let mut pending = PendingBatch::default(); + let mut stats = fresh_stats(); + pending.push(delta_event("Hello"), &mut stats); + pending.push( + json!({ + "type": "instance.event", + "instanceId": "inst-1", + "event": { + "type": "message.updated", + "properties": { + "id": "msg-1" + } + } + }), + &mut stats, + ); + + let events = pending.take_events(); + assert_eq!(events.len(), 2); + assert_eq!( + events[0]["event"]["type"].as_str(), + Some("message.part.delta") + ); + assert_eq!(events[1]["event"]["type"].as_str(), Some("message.updated")); +} + +#[test] +fn interleaved_delta_keys_keep_order() { + let mut pending = PendingBatch::default(); + let mut stats = fresh_stats(); + pending.push(delta_event_for("part-1", "A1"), &mut stats); + pending.push(delta_event_for("part-2", "B1"), &mut stats); + pending.push(delta_event_for("part-1", "A2"), &mut stats); + + let events = pending.take_events(); + assert_eq!(events.len(), 3); + assert_eq!( + events[0]["event"]["properties"]["partID"].as_str(), + Some("part-1") + ); + assert_eq!( + events[0]["event"]["properties"]["delta"].as_str(), + Some("A1") + ); + assert_eq!( + events[1]["event"]["properties"]["partID"].as_str(), + Some("part-2") + ); + assert_eq!( + events[1]["event"]["properties"]["delta"].as_str(), + Some("B1") + ); + assert_eq!( + events[2]["event"]["properties"]["partID"].as_str(), + Some("part-1") + ); + assert_eq!( + events[2]["event"]["properties"]["delta"].as_str(), + Some("A2") + ); +} + +#[test] +fn reconnect_delay_grows_and_caps() { + let policy = ResolvedDesktopEventReconnectPolicy { + initial_delay_ms: 100, + max_delay_ms: 500, + multiplier: 2.0, + max_attempts: None, + }; + + assert_eq!(compute_reconnect_delay_ms(1, &policy), 100); + assert_eq!(compute_reconnect_delay_ms(2, &policy), 200); + assert_eq!(compute_reconnect_delay_ms(3, &policy), 400); + assert_eq!(compute_reconnect_delay_ms(4, &policy), 500); +} + +#[test] +fn holds_single_delta_within_stream_window() { + let pending = PendingBatch { + events: vec![PendingEntry::Delta { + key: "delta-key".to_string(), + scope: "delta-scope".to_string(), + instance_id: "inst-1".to_string(), + session_id: Some("sess-1".to_string()), + event: delta_event("Hello"), + started_at: Instant::now(), + }], + }; + + assert!(pending.should_hold_single_delta(Instant::now(), None)); +} + +#[test] +fn flushes_single_delta_after_stream_window() { + let started_at = Instant::now() - Duration::from_millis(DELTA_STREAM_WINDOW_MS + 1); + let pending = PendingBatch { + events: vec![PendingEntry::Delta { + key: "delta-key".to_string(), + scope: "delta-scope".to_string(), + instance_id: "inst-1".to_string(), + session_id: Some("sess-1".to_string()), + event: delta_event("Hello"), + started_at, + }], + }; + + assert!(!pending.should_hold_single_delta(Instant::now(), None)); +} + +#[test] +fn active_session_uses_shorter_hold_window() { + // Delta aged past the active-stream window but within the base window. + // Active session should flush faster, so this should NOT be held. + let started_at = Instant::now() - Duration::from_millis(ACTIVE_STREAM_HOLD_WINDOW_MS + 10); + let pending = PendingBatch { + events: vec![PendingEntry::Delta { + key: "delta-key".to_string(), + scope: "delta-scope".to_string(), + instance_id: "inst-1".to_string(), + session_id: Some("sess-1".to_string()), + event: delta_event("Hello"), + started_at, + }], + }; + + let active_target = active_target(); + let other_target = ActiveSessionTarget { + instance_id: "inst-1".to_string(), + session_id: "sess-2".to_string(), + }; + + // Active session: uses ACTIVE_STREAM_HOLD_WINDOW_MS, so this stale delta is not held. + assert!(!pending.should_hold_single_delta(Instant::now(), Some(&active_target))); + // Non-matching session: still uses the wider base window. + assert!(pending.should_hold_single_delta(Instant::now(), Some(&other_target))); +} + +#[test] +fn active_session_holds_fresh_delta() { + // A very fresh delta should be held even for the active session's shorter window. + let started_at = Instant::now() - Duration::from_millis(5); + let pending = PendingBatch { + events: vec![PendingEntry::Delta { + key: "delta-key".to_string(), + scope: "delta-scope".to_string(), + instance_id: "inst-1".to_string(), + session_id: Some("sess-1".to_string()), + event: delta_event("Hello"), + started_at, + }], + }; + + let active_target = active_target(); + + assert!(pending.should_hold_single_delta(Instant::now(), Some(&active_target))); +} + +#[test] +fn assembler_emits_first_preview_chunk_immediately() { + let mut assembler = ActiveTextAssembler::default(); + let now = Instant::now(); + + let emitted = assembler.absorb( + ActiveTextDelta { + instance_id: "inst-1".to_string(), + session_id: "sess-1".to_string(), + message_id: "msg-1".to_string(), + part_id: "part-1".to_string(), + delta: "Hello".to_string(), + }, + now, + ); + + assert_eq!(emitted.len(), 1); + assert_eq!( + coalesced_payload_event(&emitted[0]) + .get("type") + .and_then(Value::as_str), + Some("assistant.stream.chunk") + ); + assert_eq!( + coalesced_payload_event(&emitted[0]) + .get("properties") + .and_then(|props| props.get("delta")) + .and_then(Value::as_str), + Some("Hello") + ); +} + +#[test] +fn snapshot_buffer_coalesces_updates_within_window() { + let mut buffer = ActiveTextSnapshotBuffer::default(); + let now = Instant::now(); + + buffer.buffer( + parse_active_text_snapshot(&message_part_updated_event("A"), Some(&active_target())) + .unwrap(), + now, + ); + buffer.buffer( + parse_active_text_snapshot(&message_part_updated_event("AB"), Some(&active_target())) + .unwrap(), + now + Duration::from_millis(40), + ); + buffer.buffer( + parse_active_text_snapshot(&message_part_updated_event("ABC"), Some(&active_target())) + .unwrap(), + now + Duration::from_millis(80), + ); + + let early = buffer.take_due(now + Duration::from_millis(ACTIVE_STREAM_SNAPSHOT_WINDOW_MS - 1)); + assert!(early.is_empty()); + + let emitted = + buffer.take_due(now + Duration::from_millis(ACTIVE_STREAM_SNAPSHOT_WINDOW_MS + 1)); + assert_eq!(emitted.len(), 1); + assert_eq!( + emitted[0]["event"]["properties"]["part"]["text"].as_str(), + Some("ABC") + ); +} + +#[test] +fn snapshot_buffer_flushes_latest_snapshot_before_message_update() { + let mut buffer = ActiveTextSnapshotBuffer::default(); + let now = Instant::now(); + + buffer.buffer( + parse_active_text_snapshot(&message_part_updated_event("Hello"), Some(&active_target())) + .unwrap(), + now, + ); + buffer.buffer( + parse_active_text_snapshot( + &message_part_updated_event("Hello world"), + Some(&active_target()), + ) + .unwrap(), + now + Duration::from_millis(25), + ); + + let flushed = buffer.flush_for_event(&json!({ + "type": "instance.event", + "instanceId": "inst-1", + "event": { + "type": "message.updated", + "properties": { + "info": { + "id": "msg-1", + "sessionID": "sess-1" + } + } + } + })); + + assert_eq!(flushed.len(), 1); + assert_eq!( + flushed[0]["event"]["properties"]["part"]["text"].as_str(), + Some("Hello world") + ); +} + +#[test] +fn assembler_keeps_first_delta_after_full_flush() { + let mut assembler = ActiveTextAssembler::default(); + let now = Instant::now(); + let delta = ActiveTextDelta { + instance_id: "inst-1".to_string(), + session_id: "sess-1".to_string(), + message_id: "msg-1".to_string(), + part_id: "part-1".to_string(), + delta: "Hello".to_string(), + }; + + let _ = assembler.absorb(delta.clone(), now); + let _ = assembler.flush_message("inst-1", "sess-1", "msg-1", now); + let _ = assembler.absorb( + ActiveTextDelta { + delta: " world".to_string(), + ..delta + }, + now, + ); + let emitted = assembler.flush_store_only_all(now + Duration::from_millis(1)); + + assert!(emitted.iter().any(|event| { + coalesced_payload_event(event) + .get("type") + .and_then(Value::as_str) + == Some("message.part.delta") + && coalesced_payload_event(event) + .get("properties") + .and_then(|props| props.get("delta")) + .and_then(Value::as_str) + == Some(" world") + })); +} + +#[test] +fn flush_store_only_all_preserves_canonical_text_without_preview() { + let mut assembler = ActiveTextAssembler::default(); + let now = Instant::now(); + let _ = assembler.absorb( + ActiveTextDelta { + instance_id: "inst-1".to_string(), + session_id: "sess-1".to_string(), + message_id: "msg-1".to_string(), + part_id: "part-1".to_string(), + delta: "Hello".to_string(), + }, + now, + ); + + let emitted = assembler.flush_store_only_all(now + Duration::from_millis(1)); + assert_eq!(emitted.len(), 1); + assert_eq!( + coalesced_payload_event(&emitted[0]) + .get("type") + .and_then(Value::as_str), + Some("message.part.delta") + ); + assert_eq!( + coalesced_payload_event(&emitted[0]) + .get("properties") + .and_then(|props| props.get("delta")) + .and_then(Value::as_str), + Some("Hello") + ); +} + +#[test] +fn coalesces_direct_message_part_delta_events() { + let mut pending = PendingBatch::default(); + let mut stats = fresh_stats(); + pending.push(direct_delta_event("Hello"), &mut stats); + pending.push(direct_delta_event(" world"), &mut stats); + + let events = pending.take_events(); + assert_eq!(events.len(), 1); + assert_eq!( + events[0]["properties"]["delta"].as_str(), + Some("Hello world") + ); +} + +#[test] +fn direct_snapshot_replaces_trailing_direct_deltas_for_same_part() { + let mut pending = PendingBatch::default(); + let mut stats = fresh_stats(); + pending.push(direct_delta_event("Hello"), &mut stats); + pending.push(direct_message_part_updated_event("Hello world"), &mut stats); + + let events = pending.take_events(); + assert_eq!(events.len(), 1); + assert_eq!(events[0]["type"].as_str(), Some("message.part.updated")); + assert_eq!( + events[0]["properties"]["part"]["text"].as_str(), + Some("Hello world") + ); +} diff --git a/packages/tauri-app/src-tauri/src/desktop_event_transport/transport.rs b/packages/tauri-app/src-tauri/src/desktop_event_transport/transport.rs new file mode 100644 index 00000000..0c7da3a2 --- /dev/null +++ b/packages/tauri-app/src-tauri/src/desktop_event_transport/transport.rs @@ -0,0 +1,569 @@ +use super::*; + +pub(super) fn run_transport_loop( + app: AppHandle, + state: Arc>, + generation_atomic: Arc, + generation: u64, + stop: Arc, + config: DesktopEventTransportConfig, +) { + let mut reconnect_attempt = 0_u32; + let mut stats = DesktopEventTransportStats::default(); + + let client = match build_stream_client() { + Ok(client) => client, + Err(error) => { + emit_status( + &app, + generation, + "error", + 0, + true, + Some(error.message), + None, + None, + &stats, + ); + return; + } + }; + + loop { + if stop.load(Ordering::SeqCst) || !generation_matches(&generation_atomic, generation) { + break; + } + + emit_status( + &app, + generation, + "connecting", + reconnect_attempt, + false, + None, + None, + None, + &stats, + ); + + match open_stream(&app, &client, &config.stream) { + Ok(response) => { + reconnect_attempt = 0; + emit_status( + &app, + generation, + "connected", + reconnect_attempt, + false, + None, + None, + None, + &stats, + ); + + let disconnect_reason = consume_stream( + &app, + response, + &state, + &generation_atomic, + generation, + stop.clone(), + &mut stats, + ); + if stop.load(Ordering::SeqCst) + || !generation_matches(&generation_atomic, generation) + { + break; + } + + if !schedule_retry( + &app, + &generation_atomic, + generation, + stop.clone(), + &config.reconnect, + &mut reconnect_attempt, + "disconnected", + disconnect_reason, + None, + &stats, + ) { + break; + } + } + Err(error) => { + let state_name = match error.kind { + OpenStreamErrorKind::Unauthorized => "unauthorized", + OpenStreamErrorKind::Http | OpenStreamErrorKind::Transport => "error", + }; + + if !schedule_retry( + &app, + &generation_atomic, + generation, + stop.clone(), + &config.reconnect, + &mut reconnect_attempt, + state_name, + Some(error.message), + error.status_code, + &stats, + ) { + break; + } + } + } + } + + emit_status( + &app, + generation, + "stopped", + reconnect_attempt, + true, + None, + None, + None, + &stats, + ); +} + +fn schedule_retry( + app: &AppHandle, + generation_atomic: &Arc, + generation: u64, + stop: Arc, + policy: &ResolvedDesktopEventReconnectPolicy, + reconnect_attempt: &mut u32, + state_name: &'static str, + reason: Option, + status_code: Option, + stats: &DesktopEventTransportStats, +) -> bool { + *reconnect_attempt = reconnect_attempt.saturating_add(1); + let terminal = policy + .max_attempts + .map(|max_attempts| *reconnect_attempt >= max_attempts) + .unwrap_or(false); + let next_delay_ms = if terminal { + None + } else { + Some(compute_reconnect_delay_ms(*reconnect_attempt, policy)) + }; + + emit_status( + app, + generation, + state_name, + *reconnect_attempt, + terminal, + reason, + next_delay_ms, + status_code, + stats, + ); + + if terminal { + return false; + } + + if let Some(delay_ms) = next_delay_ms { + wait_with_cancellation(generation_atomic, generation, stop, delay_ms); + } + + true +} + +fn wait_with_cancellation( + generation_atomic: &Arc, + generation: u64, + stop: Arc, + delay_ms: u64, +) { + let mut remaining_ms = delay_ms; + while remaining_ms > 0 { + if stop.load(Ordering::SeqCst) || !generation_matches(generation_atomic, generation) { + return; + } + + let chunk_ms = remaining_ms.min(100); + thread::sleep(Duration::from_millis(chunk_ms)); + remaining_ms -= chunk_ms; + } +} + +fn consume_stream( + app: &AppHandle, + response: Response, + state: &Arc>, + generation_atomic: &Arc, + generation: u64, + stop: Arc, + stats: &mut DesktopEventTransportStats, +) -> Option { + let (tx, rx) = mpsc::sync_channel::(4096); + let reader_stop = stop.clone(); + let reader_generation_atomic = generation_atomic.clone(); + thread::spawn(move || { + read_sse( + response, + tx, + reader_stop, + reader_generation_atomic, + generation, + ) + }); + + let mut pending = PendingBatch::default(); + let mut active_text_assembler = ActiveTextAssembler::default(); + let mut active_text_snapshots = ActiveTextSnapshotBuffer::default(); + let mut sequence = 0_u64; + let mut last_active_target: Option = None; + let mut last_reader_activity = Instant::now(); + + loop { + if stop.load(Ordering::SeqCst) || !generation_matches(generation_atomic, generation) { + return Some("stopped".to_string()); + } + + match rx.recv_timeout(Duration::from_millis(FLUSH_INTERVAL_MS)) { + Ok(ReaderMessage::Activity) => { + last_reader_activity = Instant::now(); + } + Ok(ReaderMessage::Event(event)) => { + last_reader_activity = Instant::now(); + stats.raw_events = stats.raw_events.saturating_add(1); + + let now = Instant::now(); + let active_target = state.lock().active_target.clone(); + let max_batch_events = if active_target.is_some() { + ACTIVE_SESSION_MAX_BATCH_EVENTS + } else { + MAX_BATCH_EVENTS + }; + let mut should_flush_active = false; + if active_target != last_active_target { + for flushed in active_text_assembler.flush_store_only_all(now) { + pending.push(flushed, stats); + } + for flushed in active_text_snapshots.flush_all() { + pending.push(flushed, stats); + } + last_active_target = active_target.clone(); + } + + let due = active_text_assembler.take_due(now); + if !due.is_empty() { + should_flush_active = true; + } + for flushed in due { + pending.push(flushed, stats); + } + + let snapshot_due = active_text_snapshots.take_due(now); + if !snapshot_due.is_empty() { + should_flush_active = true; + } + for flushed in snapshot_due { + pending.push(flushed, stats); + } + + let flushes = active_text_assembler.flush_for_event(&event, now); + if !flushes.is_empty() { + should_flush_active = true; + } + for flushed in flushes { + pending.push(flushed, stats); + } + + let snapshot_flushes = active_text_snapshots.flush_for_event(&event); + if !snapshot_flushes.is_empty() { + should_flush_active = true; + } + for flushed in snapshot_flushes { + pending.push(flushed, stats); + } + + if let Some(snapshot) = parse_active_text_snapshot(&event, active_target.as_ref()) { + active_text_snapshots.buffer(snapshot, now); + + if should_flush_active { + emit_pending_batch( + app, + generation, + &mut pending, + &mut sequence, + generation_atomic, + stats, + ); + } + + if pending.pending_len() >= max_batch_events { + emit_pending_batch( + app, + generation, + &mut pending, + &mut sequence, + generation_atomic, + stats, + ); + } + continue; + } + + if let Some(delta) = parse_active_text_delta(&event, active_target.as_ref()) { + let assembled_events = active_text_assembler.absorb(delta, now); + if !assembled_events.is_empty() { + should_flush_active = true; + } + for assembled in assembled_events { + pending.push(assembled, stats); + } + + if should_flush_active { + emit_pending_batch( + app, + generation, + &mut pending, + &mut sequence, + generation_atomic, + stats, + ); + } + + if pending.pending_len() >= max_batch_events { + emit_pending_batch( + app, + generation, + &mut pending, + &mut sequence, + generation_atomic, + stats, + ); + } + continue; + } + + pending.push(event, stats); + if should_flush_active { + emit_pending_batch( + app, + generation, + &mut pending, + &mut sequence, + generation_atomic, + stats, + ); + } + if pending.pending_len() >= max_batch_events { + emit_pending_batch( + app, + generation, + &mut pending, + &mut sequence, + generation_atomic, + stats, + ); + } + } + Ok(ReaderMessage::End(reason)) => { + for flushed in active_text_assembler.take_due(Instant::now()) { + pending.push(flushed, stats); + } + for flushed in active_text_assembler.flush_store_only_all(Instant::now()) { + pending.push(flushed, stats); + } + for flushed in active_text_snapshots.take_due(Instant::now()) { + pending.push(flushed, stats); + } + for flushed in active_text_snapshots.flush_all() { + pending.push(flushed, stats); + } + if !pending.is_empty() { + emit_pending_batch( + app, + generation, + &mut pending, + &mut sequence, + generation_atomic, + stats, + ); + } + return reason; + } + Err(RecvTimeoutError::Timeout) => { + if last_reader_activity.elapsed() >= Duration::from_millis(STREAM_STALL_TIMEOUT_MS) + { + for flushed in active_text_assembler.take_due(Instant::now()) { + pending.push(flushed, stats); + } + for flushed in active_text_assembler.flush_store_only_all(Instant::now()) { + pending.push(flushed, stats); + } + for flushed in active_text_snapshots.take_due(Instant::now()) { + pending.push(flushed, stats); + } + for flushed in active_text_snapshots.flush_all() { + pending.push(flushed, stats); + } + if !pending.is_empty() { + sequence += 1; + emit_batch( + app, + generation, + &mut pending, + sequence, + generation_atomic, + stats, + ); + } + return Some("stream stalled".to_string()); + } + + for flushed in active_text_assembler.take_due(Instant::now()) { + pending.push(flushed, stats); + } + for flushed in active_text_snapshots.take_due(Instant::now()) { + pending.push(flushed, stats); + } + if !pending.is_empty() { + if pending.should_hold_single_delta( + Instant::now(), + state.lock().active_target.as_ref(), + ) { + continue; + } + emit_pending_batch( + app, + generation, + &mut pending, + &mut sequence, + generation_atomic, + stats, + ); + } + } + Err(RecvTimeoutError::Disconnected) => { + for flushed in active_text_assembler.take_due(Instant::now()) { + pending.push(flushed, stats); + } + for flushed in active_text_assembler.flush_store_only_all(Instant::now()) { + pending.push(flushed, stats); + } + for flushed in active_text_snapshots.take_due(Instant::now()) { + pending.push(flushed, stats); + } + for flushed in active_text_snapshots.flush_all() { + pending.push(flushed, stats); + } + if !pending.is_empty() { + emit_pending_batch( + app, + generation, + &mut pending, + &mut sequence, + generation_atomic, + stats, + ); + } + return Some("reader disconnected".to_string()); + } + } + } +} + +fn emit_pending_batch( + app: &AppHandle, + generation: u64, + pending: &mut PendingBatch, + sequence: &mut u64, + generation_atomic: &Arc, + stats: &mut DesktopEventTransportStats, +) { + if pending.is_empty() { + return; + } + + *sequence += 1; + emit_batch( + app, + generation, + pending, + *sequence, + generation_atomic, + stats, + ); +} + +fn emit_batch( + app: &AppHandle, + generation: u64, + pending: &mut PendingBatch, + sequence: u64, + generation_atomic: &Arc, + stats: &mut DesktopEventTransportStats, +) { + if !generation_matches(generation_atomic, generation) { + return; + } + + let events = pending.take_events(); + if events.is_empty() { + return; + } + + stats.emitted_batches = stats.emitted_batches.saturating_add(1); + stats.emitted_events = stats.emitted_events.saturating_add(events.len() as u64); + + let _ = app.emit( + EVENT_BATCH_NAME, + WorkspaceEventBatchPayload { + generation, + sequence, + emitted_at: SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or_default() + .as_millis(), + events, + }, + ); +} + +fn emit_status( + app: &AppHandle, + generation: u64, + state_name: &'static str, + reconnect_attempt: u32, + terminal: bool, + reason: Option, + next_delay_ms: Option, + status_code: Option, + stats: &DesktopEventTransportStats, +) { + let _ = app.emit( + EVENT_STATUS_NAME, + DesktopEventStreamStatusPayload { + generation, + state: state_name, + reconnect_attempt, + terminal, + reason, + next_delay_ms, + status_code, + stats: stats.clone(), + }, + ); +} + +pub(super) fn generation_matches(generation_atomic: &Arc, generation: u64) -> bool { + generation_atomic.load(Ordering::SeqCst) == generation +} + +pub(super) fn compute_reconnect_delay_ms( + attempt: u32, + policy: &ResolvedDesktopEventReconnectPolicy, +) -> u64 { + let exponent = attempt.saturating_sub(1) as i32; + let scaled = (policy.initial_delay_ms as f64) * policy.multiplier.powi(exponent); + (scaled.round().max(policy.initial_delay_ms as f64) as u64).min(policy.max_delay_ms) +} diff --git a/packages/tauri-app/src-tauri/src/main.rs b/packages/tauri-app/src-tauri/src/main.rs index 1bbc83ef..b3f95c1a 100644 --- a/packages/tauri-app/src-tauri/src/main.rs +++ b/packages/tauri-app/src-tauri/src/main.rs @@ -1,8 +1,13 @@ #![cfg_attr(not(debug_assertions), windows_subsystem = "windows")] mod cli_manager; +mod desktop_event_transport; use cli_manager::{CliProcessManager, CliStatus}; +use desktop_event_transport::{ + ActiveSessionTarget, DesktopEventTransportManager, DesktopEventsStartRequest, + DesktopEventsStartResult, +}; use keepawake::KeepAwake; use serde::Deserialize; use serde_json::json; @@ -40,6 +45,7 @@ const WINDOWS_APP_USER_MODEL_ID: &str = "ai.neuralnomads.codenomad.client"; pub struct AppState { pub manager: CliProcessManager, + pub desktop_events: DesktopEventTransportManager, pub wake_lock: Mutex>, pub zoom_level: Mutex, pub remote_origins: Mutex>, @@ -70,6 +76,7 @@ fn cli_get_status(state: tauri::State) -> CliStatus { #[tauri::command] fn cli_restart(app: AppHandle, state: tauri::State) -> Result { let dev_mode = is_dev_mode(); + state.desktop_events.stop(); state.manager.stop().map_err(|e| e.to_string())?; state .manager @@ -78,6 +85,42 @@ fn cli_restart(app: AppHandle, state: tauri::State) -> Result, + request: Option, +) -> DesktopEventsStartResult { + let config = state.manager.desktop_event_stream_config(); + state.desktop_events.start(app, config, request) +} + +#[tauri::command] +fn desktop_events_stop(state: tauri::State) { + state.desktop_events.stop(); +} + +#[tauri::command] +fn desktop_events_set_active_session( + state: tauri::State, + instance_id: Option, + session_id: Option, +) { + let target = match (instance_id, session_id) { + (Some(instance_id), Some(session_id)) + if !instance_id.trim().is_empty() && !session_id.trim().is_empty() => + { + Some(ActiveSessionTarget { + instance_id, + session_id, + }) + } + _ => None, + }; + + state.desktop_events.set_active_session_target(target); +} + #[tauri::command] fn wake_lock_start( state: tauri::State, @@ -359,6 +402,7 @@ fn main() { .plugin(navigation_guard) .manage(AppState { manager: CliProcessManager::new(), + desktop_events: DesktopEventTransportManager::new(), wake_lock: Mutex::new(None), zoom_level: Mutex::new(DEFAULT_ZOOM_LEVEL), remote_origins: Mutex::new(HashMap::new()), @@ -398,6 +442,9 @@ fn main() { .invoke_handler(tauri::generate_handler![ cli_get_status, cli_restart, + desktop_events_start, + desktop_events_stop, + desktop_events_set_active_session, wake_lock_start, wake_lock_stop, open_remote_window @@ -502,6 +549,7 @@ fn main() { let app = app_handle.clone(); std::thread::spawn(move || { if let Some(state) = app.try_state::() { + state.desktop_events.stop(); let _ = state.manager.stop(); } app.exit(0); @@ -553,6 +601,7 @@ fn main() { let app = app_handle.clone(); std::thread::spawn(move || { if let Some(state) = app.try_state::() { + state.desktop_events.stop(); let _ = state.manager.stop(); } app.exit(0); diff --git a/packages/ui/src/App.tsx b/packages/ui/src/App.tsx index 462c0c16..b0270896 100644 --- a/packages/ui/src/App.tsx +++ b/packages/ui/src/App.tsx @@ -25,6 +25,8 @@ import { initReleaseNotifications } from "./stores/releases" import { runtimeEnv } from "./lib/runtime-env" import { useI18n } from "./lib/i18n" import { setWakeLockDesired } from "./lib/native/wake-lock" +import { setTauriDesktopActiveSession } from "./lib/native/desktop-events" +import { clearAssistantStreamSession } from "./stores/assistant-stream" import { isSelectingFolder, setIsSelectingFolder, @@ -249,6 +251,35 @@ const App: Component = () => { return activeSessionId().get(instance.id) || null }) + const activeStreamTarget = createMemo(() => { + const instance = activeInstance() + const sessionId = activeSessionIdForInstance() + if (!instance || !sessionId) return null + return { + instanceId: instance.id, + sessionId, + } + }) + + let previousActiveStreamTarget: ReturnType = null + + createEffect(() => { + if (runtimeEnv.host !== "tauri") { + return + } + + const currentTarget = activeStreamTarget() + if (previousActiveStreamTarget) { + clearAssistantStreamSession( + previousActiveStreamTarget.instanceId, + previousActiveStreamTarget.sessionId, + ) + } + previousActiveStreamTarget = currentTarget + + void setTauriDesktopActiveSession(currentTarget) + }) + const launchErrorPath = () => { const value = launchError()?.binaryPath if (!value) return "opencode" diff --git a/packages/ui/src/components/markdown.tsx b/packages/ui/src/components/markdown.tsx index 56889ca3..c2ee9393 100644 --- a/packages/ui/src/components/markdown.tsx +++ b/packages/ui/src/components/markdown.tsx @@ -1,4 +1,4 @@ -import { createEffect, createMemo, createSignal, onCleanup, onMount } from "solid-js" +import { createEffect, createMemo, createSignal, onCleanup, onMount, untrack } from "solid-js" import { useGlobalCache } from "../lib/hooks/use-global-cache" import type { TextPart, RenderCache } from "../types/message" import { getLogger } from "../lib/logger" @@ -87,12 +87,17 @@ interface MarkdownProps { onRendered?: () => void } +/** Default throttle delay for expensive Shiki re-renders (ms). */ +const MARKDOWN_RENDER_THROTTLE_MS = 120 + export function Markdown(props: MarkdownProps) { const { t } = useI18n() const [html, setHtml] = createSignal("") let containerRef: HTMLDivElement | undefined let latestRequestKey = "" let cleanupLanguageListener: (() => void) | undefined + let renderTimer: ReturnType | undefined + let hasRenderedOnce = false const notifyRendered = () => { Promise.resolve().then(() => props.onRendered?.()) @@ -155,6 +160,45 @@ export function Markdown(props: MarkdownProps) { } } + /** Schedule a Shiki render, throttled after the first paint. */ + let pendingRenderSnapshot: ReturnType | undefined + + const scheduleRender = (snapshot: ReturnType) => { + const doRender = (snap: ReturnType) => { + latestRequestKey = snap.requestKey + void renderSnapshot(snap).catch((error) => { + log.error("Failed to render markdown:", error) + if (latestRequestKey === snap.requestKey) { + commitCacheEntry(snap, renderFallbackHtml(snap.text)) + } + }) + } + + // First render is always immediate to avoid a prolonged fallback flash. + if (!hasRenderedOnce) { + hasRenderedOnce = true + doRender(snapshot) + return + } + + // Subsequent renders are throttled: the timer fires at a fixed cadence + // and always uses the latest pending snapshot. Unlike a debounce, the + // timer is NOT reset when new snapshots arrive, so Shiki re-renders + // periodically (~every MARKDOWN_RENDER_THROTTLE_MS) even during + // continuous streaming — preventing the raw↔markdown flash. + pendingRenderSnapshot = snapshot + if (!renderTimer) { + renderTimer = setTimeout(() => { + renderTimer = undefined + const snap = pendingRenderSnapshot + if (snap) { + pendingRenderSnapshot = undefined + doRender(snap) + } + }, MARKDOWN_RENDER_THROTTLE_MS) + } + } + createEffect(() => { const snapshot = resolved() latestRequestKey = snapshot.requestKey @@ -179,15 +223,15 @@ export function Markdown(props: MarkdownProps) { return } - setHtml(renderFallbackHtml(snapshot.text)) + // Keep the previous rendered markdown visible while Shiki re-renders. + // Only fall back to escaped plain text on the initial render (no prior + // content). This eliminates the raw↔markdown flash during streaming. + if (!untrack(html)) { + setHtml(renderFallbackHtml(snapshot.text)) + } notifyRendered() - void renderSnapshot(snapshot).catch((error) => { - log.error("Failed to render markdown:", error) - if (latestRequestKey === snapshot.requestKey) { - commitCacheEntry(snapshot, renderFallbackHtml(snapshot.text)) - } - }) + scheduleRender(snapshot) }) onMount(() => { @@ -234,9 +278,7 @@ export function Markdown(props: MarkdownProps) { } latestRequestKey = snapshot.requestKey - void renderSnapshot(snapshot).catch((error) => { - log.error("Failed to re-render markdown after language load:", error) - }) + scheduleRender(snapshot) }) }) .catch((error) => { @@ -245,6 +287,9 @@ export function Markdown(props: MarkdownProps) { onCleanup(() => { disposed = true + if (renderTimer) clearTimeout(renderTimer) + renderTimer = undefined + pendingRenderSnapshot = undefined containerRef?.removeEventListener("click", handleClick) cleanupLanguageListener?.() cleanupLanguageListener = undefined diff --git a/packages/ui/src/components/message-part.tsx b/packages/ui/src/components/message-part.tsx index c0bd07cf..7a06a3be 100644 --- a/packages/ui/src/components/message-part.tsx +++ b/packages/ui/src/components/message-part.tsx @@ -3,7 +3,6 @@ import { isItemExpanded, toggleItemExpanded } from "../stores/tool-call-state" import { Markdown } from "./markdown" import { useTheme } from "../lib/theme" import { partHasRenderableText, SDKPart, TextPart, ClientPart } from "../types/message" - type ToolCallPart = Extract const LazyToolCall = lazy(() => import("./tool-call")) @@ -98,6 +97,7 @@ export default function MessagePart(props: MessagePartProps) { const createTextPartForMarkdown = (): TextPart => { const part = props.part + if (part.type === "text" && typeof part.text === "string") { // Pass through the original part so `renderCache` updates persist. return part as unknown as TextPart diff --git a/packages/ui/src/components/prompt-input/usePromptState.ts b/packages/ui/src/components/prompt-input/usePromptState.ts index 3b326f2c..c7c52715 100644 --- a/packages/ui/src/components/prompt-input/usePromptState.ts +++ b/packages/ui/src/components/prompt-input/usePromptState.ts @@ -40,7 +40,6 @@ type PromptState = { } const HISTORY_LIMIT = 100 - export function usePromptState(options: PromptStateOptions): PromptState { const [prompt, setPromptInternal] = createSignal("") const [history, setHistory] = createSignal([]) diff --git a/packages/ui/src/components/virtual-follow-list.tsx b/packages/ui/src/components/virtual-follow-list.tsx index a0dce708..de007486 100644 --- a/packages/ui/src/components/virtual-follow-list.tsx +++ b/packages/ui/src/components/virtual-follow-list.tsx @@ -145,6 +145,31 @@ export default function VirtualFollowList(props: VirtualFollowListProps) { let suppressAutoScrollOnce = false let pendingInitialScroll = true + // PERF: Coalescing rAF wrapper for scrollToBottom. + // Multiple reactive effects (followToken, notifyContentRendered) can trigger + // scrollToBottom synchronously within the same microtask. Each call forces a + // synchronous layout reflow (setting scrollTop reads back computed layout). + // In WebView2 this reflow is extremely expensive (~4.9s). + // By deferring to a single requestAnimationFrame, we batch all scroll requests + // and pay the reflow cost only once, after the DOM has settled. + let pendingAutoScrollRaf: number | undefined + function scheduleAutoScroll() { + if (pendingAutoScrollRaf !== undefined) return + pendingAutoScrollRaf = requestAnimationFrame(() => { + pendingAutoScrollRaf = undefined + if (autoScroll()) { + scrollToBottom(true) + } + }) + } + + onCleanup(() => { + if (pendingAutoScrollRaf !== undefined) { + cancelAnimationFrame(pendingAutoScrollRaf) + pendingAutoScrollRaf = undefined + } + }) + const state: VirtualFollowListState = { autoScroll, showScrollTopButton, @@ -282,7 +307,7 @@ export default function VirtualFollowList(props: VirtualFollowListProps) { }, notifyContentRendered: () => { if (autoScroll()) { - scrollToBottom(true) + scheduleAutoScroll() } }, setAutoScroll: (enabled) => setAutoScroll(Boolean(enabled)), @@ -305,7 +330,7 @@ export default function VirtualFollowList(props: VirtualFollowListProps) { // Handle followToken change createEffect(on(() => props.followToken?.(), () => { if (autoScroll()) { - scrollToBottom(true) + scheduleAutoScroll() } }, { defer: true })) diff --git a/packages/ui/src/lib/event-transport-contract.ts b/packages/ui/src/lib/event-transport-contract.ts new file mode 100644 index 00000000..f7230128 --- /dev/null +++ b/packages/ui/src/lib/event-transport-contract.ts @@ -0,0 +1,78 @@ +export interface DesktopEventTransportReconnectPolicy { + initialDelayMs: number + maxDelayMs: number + multiplier: number + maxAttempts?: number +} + +export interface DesktopEventTransportStartOptions { + reconnect?: Partial +} + +export type DesktopEventTransportState = + | "connecting" + | "connected" + | "disconnected" + | "unauthorized" + | "error" + | "stopped" + +export interface DesktopEventTransportStats { + rawEvents: number + emittedEvents: number + emittedBatches: number + deltaCoalesces: number + snapshotCoalesces: number + statusCoalesces: number + supersededDeltasDropped: number +} + +export interface DesktopEventTransportStatusPayload { + generation: number + state: DesktopEventTransportState + reconnectAttempt: number + terminal: boolean + reason?: string + nextDelayMs?: number + statusCode?: number + stats?: DesktopEventTransportStats +} + +export interface DesktopEventsStartResult { + started: boolean + generation?: number + reason?: string +} + +export interface DesktopEventActiveSessionTarget { + instanceId: string + sessionId: string +} + +export interface AssistantStreamChunkEvent { + type: "assistant.stream.chunk" + properties: { + sessionID: string + messageID: string + partID: string + field: "text" + delta: string + } +} + +export const DEFAULT_DESKTOP_EVENT_RECONNECT_POLICY: DesktopEventTransportReconnectPolicy = { + initialDelayMs: 1000, + maxDelayMs: 10000, + multiplier: 2, +} + +export function resolveDesktopEventTransportStartOptions( + options?: DesktopEventTransportStartOptions, +): Required { + return { + reconnect: { + ...DEFAULT_DESKTOP_EVENT_RECONNECT_POLICY, + ...options?.reconnect, + }, + } +} diff --git a/packages/ui/src/lib/event-transport.ts b/packages/ui/src/lib/event-transport.ts new file mode 100644 index 00000000..fe8f89b6 --- /dev/null +++ b/packages/ui/src/lib/event-transport.ts @@ -0,0 +1,67 @@ +import type { WorkspaceEventPayload } from "../../../server/src/api-types" +import { serverApi } from "./api-client" +import { + resolveDesktopEventTransportStartOptions, + type DesktopEventTransportStartOptions, +} from "./event-transport-contract" +import { getLogger } from "./logger" +import { runtimeEnv } from "./runtime-env" +import { connectTauriWorkspaceEvents } from "./native/desktop-events" + +const log = getLogger("sse") + +export interface WorkspaceEventTransportCallbacks { + onBatch: (events: WorkspaceEventPayload[]) => void + onError?: () => void + onOpen?: () => void + onPing?: (payload: { ts?: number }) => void +} + +export interface WorkspaceEventConnection { + disconnect: () => void +} + +async function connectBrowserWorkspaceEvents( + callbacks: WorkspaceEventTransportCallbacks, +): Promise { + const source = serverApi.connectEvents((event) => { + callbacks.onBatch([event]) + }, callbacks.onError, callbacks.onPing) + source.onopen = () => callbacks.onOpen?.() + return { + disconnect() { + source.close() + }, + } +} + +export async function connectWorkspaceEvents( + callbacks: WorkspaceEventTransportCallbacks, + options?: DesktopEventTransportStartOptions, +): Promise { + if (runtimeEnv.host === "tauri") { + try { + const conn = await connectTauriWorkspaceEvents( + callbacks, + resolveDesktopEventTransportStartOptions(options), + ) + ;(globalThis as any).__TRANSPORT_TYPE = "rust-native" + log.info("Event transport: rust-native (desktop_event_transport)") + return conn + } catch (error) { + log.warn("Failed to start native desktop event transport, falling back to browser EventSource", error) + } + } + + ;(globalThis as any).__TRANSPORT_TYPE = "browser-eventsource" + log.info(`Event transport: browser-eventsource (host=${runtimeEnv.host})`) + return connectBrowserWorkspaceEvents(callbacks) +} + +export type { + DesktopEventsStartResult, + DesktopEventTransportReconnectPolicy, + DesktopEventTransportStartOptions, + DesktopEventTransportState, + DesktopEventTransportStatusPayload, +} from "./event-transport-contract" diff --git a/packages/ui/src/lib/markdown.ts b/packages/ui/src/lib/markdown.ts index c83c6894..c3adf62a 100644 --- a/packages/ui/src/lib/markdown.ts +++ b/packages/ui/src/lib/markdown.ts @@ -22,6 +22,9 @@ const queuedLanguages = new Set() const languageLoadQueue: Array<() => Promise> = [] let isQueueRunning = false +// Cache for ensureLanguages to avoid redundant marked.lexer() calls +let lastEnsuredContent = "" + // Pub/sub mechanism for language loading notifications const languageListeners: Array<() => void> = [] @@ -165,6 +168,24 @@ async function ensureLanguages(content: string) { return } + if (content === lastEnsuredContent) { + return + } + + if ( + lastEnsuredContent.length > 0 && + content.length > lastEnsuredContent.length && + content.startsWith(lastEnsuredContent) + ) { + const newPortion = content.slice(lastEnsuredContent.length) + if (!newPortion.includes("```")) { + lastEnsuredContent = content + return + } + } + + lastEnsuredContent = content + // Extract code-fence language tokens via `marked` so we correctly handle code blocks // that contain backticks (e.g. JS template literals). Regex-based fence scans tend // to miss these and prevent languages from loading. @@ -227,10 +248,11 @@ async function runLanguageLoadQueue() { isQueueRunning = true while (languageLoadQueue.length > 0) { - const task = languageLoadQueue.shift() - if (task) { - await task() - } + // Drain the current queue and run all pending loads concurrently. + // New languages queued while this batch runs will be picked up by + // the outer while-loop on the next iteration. + const batch = languageLoadQueue.splice(0) + await Promise.allSettled(batch.map((task) => task())) } isQueueRunning = false diff --git a/packages/ui/src/lib/native/desktop-events.ts b/packages/ui/src/lib/native/desktop-events.ts new file mode 100644 index 00000000..b033e19a --- /dev/null +++ b/packages/ui/src/lib/native/desktop-events.ts @@ -0,0 +1,163 @@ +import { invoke } from "@tauri-apps/api/core" +import { listen } from "@tauri-apps/api/event" +import type { WorkspaceEventPayload } from "../../../../server/src/api-types" +import type { + DesktopEventActiveSessionTarget, + DesktopEventsStartResult, + DesktopEventTransportStartOptions, + DesktopEventTransportStatusPayload, +} from "../event-transport-contract" +import type { WorkspaceEventConnection, WorkspaceEventTransportCallbacks } from "../event-transport" +import { getLogger } from "../logger" + +const log = getLogger("sse") + +interface WorkspaceEventBatchPayload { + generation: number + sequence: number + emittedAt: number + events: WorkspaceEventPayload[] +} + +export async function connectTauriWorkspaceEvents( + callbacks: WorkspaceEventTransportCallbacks, + options: DesktopEventTransportStartOptions, +): Promise { + let closed = false + let opened = false + let expectedGeneration: number | null = null + let terminalErrorRaised = false + const pendingBatches: WorkspaceEventBatchPayload[] = [] + const pendingStatuses: DesktopEventTransportStatusPayload[] = [] + + const matchesGeneration = (generation: number) => expectedGeneration === generation + + const handleBatchPayload = (payload: WorkspaceEventBatchPayload) => { + if (!payload || !matchesGeneration(payload.generation)) return + + if (!opened) { + opened = true + callbacks.onOpen?.() + } + + const events = payload.events ?? [] + if (events.length === 0) { + return + } + + callbacks.onBatch(events) + } + + const handleStatusPayload = (payload: DesktopEventTransportStatusPayload) => { + if (!payload || !matchesGeneration(payload.generation)) return + + if (payload.state === "connected" && !opened) { + opened = true + callbacks.onOpen?.() + } + + if (payload.state === "unauthorized") { + log.warn("Native desktop event transport is waiting for authentication", { + reason: payload.reason, + reconnectAttempt: payload.reconnectAttempt, + nextDelayMs: payload.nextDelayMs, + stats: payload.stats, + }) + } else if (payload.state === "error") { + log.warn("Native desktop event transport reported an error", { + reason: payload.reason, + reconnectAttempt: payload.reconnectAttempt, + nextDelayMs: payload.nextDelayMs, + statusCode: payload.statusCode, + stats: payload.stats, + }) + } else if ((payload.state === "disconnected" || payload.state === "stopped") && payload.stats) { + log.info("Native desktop event transport stats", { + state: payload.state, + reconnectAttempt: payload.reconnectAttempt, + stats: payload.stats, + }) + } + + if (payload.state === "stopped") { + callbacks.onError?.() + return + } + + if (payload.terminal && !terminalErrorRaised) { + terminalErrorRaised = true + callbacks.onError?.() + } + } + + const flushPending = () => { + if (expectedGeneration === null) return + for (const payload of pendingStatuses.splice(0, pendingStatuses.length)) { + handleStatusPayload(payload) + } + for (const payload of pendingBatches.splice(0, pendingBatches.length)) { + handleBatchPayload(payload) + } + } + + const unlistenBatch = await listen("desktop:event-batch", (event) => { + if (closed) return + const payload = event.payload + if (!payload) return + if (expectedGeneration === null) { + pendingBatches.push(payload) + return + } + handleBatchPayload(payload) + }) + + const unlistenStatus = await listen("desktop:event-stream-status", (event) => { + if (closed) return + const payload = event.payload + if (!payload) return + if (expectedGeneration === null) { + pendingStatuses.push(payload) + return + } + handleStatusPayload(payload) + }) + + try { + const result = await invoke("desktop_events_start", { request: options }) + if (!result?.started) { + throw new Error(result?.reason ?? "desktop event transport unavailable") + } + expectedGeneration = result.generation ?? null + flushPending() + } catch (error) { + unlistenBatch() + unlistenStatus() + throw error + } + + return { + disconnect() { + if (closed) { + return + } + + closed = true + unlistenBatch() + unlistenStatus() + void invoke("desktop_events_stop").catch((error) => { + log.warn("Failed to stop native desktop event transport", error) + }) + }, + } +} + +export async function setTauriDesktopActiveSession(target: DesktopEventActiveSessionTarget | null): Promise { + try { + await invoke("desktop_events_set_active_session", { + instanceId: target?.instanceId ?? null, + sessionId: target?.sessionId ?? null, + }) + } catch (error) { + log.warn("Failed to update native desktop active session", error) + } +} diff --git a/packages/ui/src/lib/sdk-manager.ts b/packages/ui/src/lib/sdk-manager.ts index 2e4ee3c1..b4a6d530 100644 --- a/packages/ui/src/lib/sdk-manager.ts +++ b/packages/ui/src/lib/sdk-manager.ts @@ -1,5 +1,40 @@ import { createOpencodeClient, type OpencodeClient } from "@opencode-ai/sdk/v2/client" import { CODENOMAD_API_BASE } from "./api-client" +import { getLogger } from "./logger" + +const log = getLogger("api") + +/** + * Instrumented fetch wrapper for the SDK client. + * Logs method, URL, status, and elapsed time (ms) using performance.now() + * so we can compare WebView2 (Tauri) vs Chromium (Electron) fetch latency. + */ +function createInstrumentedFetch(): typeof globalThis.fetch { + return (async (input: RequestInfo | URL, init?: RequestInit): Promise => { + // The SDK always passes a Request object, but we handle other forms too + const request = input instanceof Request ? input : new Request(input, init) + ;(request as any).timeout = false + const method = request.method + const url = request.url + const t0 = performance.now() + log.info(`[sdk-fetch] ${method} ${url}`) + try { + const response = await fetch(request) + const elapsed = performance.now() - t0 + log.info(`[sdk-fetch] ${method} ${url} -> ${response.status}`, { + durationMs: Math.round(elapsed * 100) / 100, + }) + return response + } catch (error) { + const elapsed = performance.now() - t0 + log.info(`[sdk-fetch] ${method} ${url} FAILED`, { + durationMs: Math.round(elapsed * 100) / 100, + error, + }) + throw error + } + }) as typeof globalThis.fetch +} class SDKManager { private clients = new Map() @@ -16,7 +51,7 @@ class SDKManager { } const baseUrl = buildInstanceBaseUrl(proxyPath) - const client = createOpencodeClient({ baseUrl }) + const client = createOpencodeClient({ baseUrl, fetch: createInstrumentedFetch() }) this.clients.set(key, client) diff --git a/packages/ui/src/lib/server-events.ts b/packages/ui/src/lib/server-events.ts index 68db4476..6095c697 100644 --- a/packages/ui/src/lib/server-events.ts +++ b/packages/ui/src/lib/server-events.ts @@ -1,6 +1,6 @@ +import { batch as solidBatch } from "solid-js" import type { WorkspaceEventPayload, WorkspaceEventType } from "../../../server/src/api-types" -import { serverApi } from "./api-client" -import { getClientIdentity } from "./client-identity" +import { connectWorkspaceEvents, type WorkspaceEventConnection } from "./event-transport" import { getLogger } from "./logger" const RETRY_BASE_DELAY = 1000 @@ -18,57 +18,105 @@ function logSse(message: string, context?: Record) { class ServerEvents { private handlers = new Map void>>() private openHandlers = new Set<() => void>() - private source: EventSource | null = null + private connection: WorkspaceEventConnection | null = null + private connectGeneration = 0 private retryDelay = RETRY_BASE_DELAY + private retryTimer: ReturnType | null = null constructor() { - this.connect() + void this.connect() } - private connect() { - if (this.source) { - this.source.close() + private async connect() { + const generation = ++this.connectGeneration + this.clearReconnectTimer() + + if (this.connection) { + this.connection.disconnect() + this.connection = null } + logSse("Connecting to backend events stream") - this.source = serverApi.connectEvents( - (event) => this.dispatch(event), - () => this.scheduleReconnect(), - (payload) => { - void serverApi - .sendClientConnectionPong({ - ...getClientIdentity(), - pingTs: payload.ts, - }) - .catch((error) => { - log.error("Failed to send client connection pong", error) - }) - }, - ) - this.source.onopen = () => { - logSse("Events stream connected") - this.retryDelay = RETRY_BASE_DELAY - this.openHandlers.forEach((handler) => handler()) + + try { + const connection = await connectWorkspaceEvents({ + onBatch: (events) => this.dispatchBatch(events), + onError: () => { + if (generation !== this.connectGeneration) { + return + } + this.scheduleReconnect() + }, + onOpen: () => { + if (generation !== this.connectGeneration) { + return + } + logSse("Events stream connected") + this.retryDelay = RETRY_BASE_DELAY + this.openHandlers.forEach((handler) => handler()) + }, + }) + + if (generation !== this.connectGeneration) { + connection.disconnect() + return + } + + this.connection = connection + } catch (error) { + logSse("Events stream failed to connect, scheduling reconnect", { + error: error instanceof Error ? error.message : String(error), + }) + this.scheduleReconnect() } } private scheduleReconnect() { - if (this.source) { - this.source.close() - this.source = null + if (this.retryTimer) { + return } + + if (this.connection) { + this.connection.disconnect() + this.connection = null + } + logSse("Events stream disconnected, scheduling reconnect", { delayMs: this.retryDelay }) - setTimeout(() => { + this.retryTimer = setTimeout(() => { + this.retryTimer = null this.retryDelay = Math.min(this.retryDelay * 2, RETRY_MAX_DELAY) - this.connect() + void this.connect() }, this.retryDelay) } + private clearReconnectTimer() { + if (!this.retryTimer) { + return + } + + clearTimeout(this.retryTimer) + this.retryTimer = null + } + private dispatch(event: WorkspaceEventPayload) { - logSse(`event ${event.type}`) this.handlers.get("*")?.forEach((handler) => handler(event)) this.handlers.get(event.type)?.forEach((handler) => handler(event)) } + private dispatchBatch(events: WorkspaceEventPayload[]) { + if (events.length === 0) { + return + } + + logSse("event batch", { size: events.length }) + + solidBatch(() => { + for (const event of events) { + this.dispatch(event) + } + }) + } + on(type: WorkspaceEventType | "*", handler: (event: WorkspaceEventPayload) => void): () => void { if (!this.handlers.has(type)) { this.handlers.set(type, new Set()) diff --git a/packages/ui/src/lib/sse-manager.ts b/packages/ui/src/lib/sse-manager.ts index e6354e2c..04240d2a 100644 --- a/packages/ui/src/lib/sse-manager.ts +++ b/packages/ui/src/lib/sse-manager.ts @@ -1,4 +1,5 @@ import { createSignal } from "solid-js" +import type { AssistantStreamChunkEvent } from "./event-transport-contract" import { MessageUpdateEvent, MessageRemovedEvent, @@ -67,6 +68,7 @@ type SSEEvent = | MessagePartUpdatedEvent | MessagePartRemovedEvent | MessagePartDeltaEvent + | AssistantStreamChunkEvent | EventSessionUpdated | EventSessionCompacted | EventSessionDiff @@ -119,8 +121,6 @@ class SSEManager { return } - log.info("Received event", { type: event.type, event }) - switch (event.type) { case "message.updated": this.onMessageUpdate?.(instanceId, event as MessageUpdateEvent) @@ -131,6 +131,9 @@ class SSEManager { case "message.part.delta": this.onMessagePartDelta?.(instanceId, event as MessagePartDeltaEvent) break + case "assistant.stream.chunk": + this.onAssistantStreamChunk?.(instanceId, event as AssistantStreamChunkEvent) + break case "message.removed": this.onMessageRemoved?.(instanceId, event as MessageRemovedEvent) break @@ -201,6 +204,7 @@ class SSEManager { onMessageRemoved?: (instanceId: string, event: MessageRemovedEvent) => void onMessagePartUpdated?: (instanceId: string, event: MessagePartUpdatedEvent) => void onMessagePartDelta?: (instanceId: string, event: MessagePartDeltaEvent) => void + onAssistantStreamChunk?: (instanceId: string, event: AssistantStreamChunkEvent) => void onMessagePartRemoved?: (instanceId: string, event: MessagePartRemovedEvent) => void onSessionUpdate?: (instanceId: string, event: EventSessionUpdated) => void onSessionCompacted?: (instanceId: string, event: EventSessionCompacted) => void diff --git a/packages/ui/src/stores/assistant-stream.ts b/packages/ui/src/stores/assistant-stream.ts new file mode 100644 index 00000000..e457b6c3 --- /dev/null +++ b/packages/ui/src/stores/assistant-stream.ts @@ -0,0 +1,108 @@ +import { createSignal } from "solid-js" +import type { AssistantStreamChunkEvent } from "../lib/event-transport-contract" + +interface StreamEntry { + text: string + get: () => string + set: (value: string) => void +} + +const streamEntries = new Map() + +function makeKey(instanceId: string, sessionId: string, messageId: string, partId: string) { + return `${instanceId}:${sessionId}:${messageId}:${partId}` +} + +function getOrCreateEntry(key: string): StreamEntry { + let entry = streamEntries.get(key) + if (!entry) { + const [get, set] = createSignal("") + entry = { text: "", get, set: (v: string) => set(v) } + streamEntries.set(key, entry) + } + return entry +} + +export function appendAssistantStreamChunk(instanceId: string, event: AssistantStreamChunkEvent) { + const props = event.properties + if (!props?.sessionID || !props?.messageID || !props.partID || typeof props.delta !== "string") { + return + } + if (props.delta.length === 0) return + + const key = makeKey(instanceId, props.sessionID, props.messageID, props.partID) + const entry = getOrCreateEntry(key) + entry.text += props.delta + entry.set(entry.text) +} + +export function getAssistantStreamPreviewText( + instanceId: string, + sessionId: string | undefined, + messageId: string | undefined, + partId: string | undefined, +) { + if (!sessionId || !messageId || !partId) return undefined + const key = makeKey(instanceId, sessionId, messageId, partId) + const entry = streamEntries.get(key) + if (!entry) return undefined + // Subscribe to this specific key's signal + return entry.get() || undefined +} + +export function clearAssistantStreamMessage( + instanceId: string, + sessionId: string | undefined, + messageId: string | undefined, +) { + if (!sessionId || !messageId) return + const prefix = `${instanceId}:${sessionId}:${messageId}:` + for (const [key, entry] of streamEntries) { + if (key.startsWith(prefix)) { + entry.set("") + streamEntries.delete(key) + } + } +} + +export function clearAssistantStreamPart( + instanceId: string, + sessionId: string | undefined, + messageId: string | undefined, + partId: string | undefined, +) { + if (!sessionId || !messageId || !partId) return + const key = makeKey(instanceId, sessionId, messageId, partId) + const entry = streamEntries.get(key) + if (!entry) return + entry.set("") + streamEntries.delete(key) +} + +export function clearAssistantStreamAll() { + for (const entry of streamEntries.values()) { + entry.set("") + } + streamEntries.clear() +} + +export function clearAssistantStreamInstance(instanceId: string) { + const prefix = `${instanceId}:` + for (const [key, entry] of streamEntries) { + if (key.startsWith(prefix)) { + entry.set("") + streamEntries.delete(key) + } + } +} + +export function clearAssistantStreamSession(instanceId: string, sessionId: string | undefined) { + if (!sessionId) return + const prefix = `${instanceId}:${sessionId}:` + for (const [key, entry] of streamEntries) { + if (key.startsWith(prefix)) { + entry.set("") + streamEntries.delete(key) + } + } +} diff --git a/packages/ui/src/stores/instances.ts b/packages/ui/src/stores/instances.ts index c6b49482..19b9943c 100644 --- a/packages/ui/src/stores/instances.ts +++ b/packages/ui/src/stores/instances.ts @@ -32,6 +32,7 @@ import { setSessionPendingPermission, setSessionPendingQuestion } from "./sessio import { setHasInstances } from "./ui" import { messageStoreBus } from "./message-v2/bus" import { upsertPermissionV2, removePermissionV2, upsertQuestionV2, removeQuestionV2 } from "./message-v2/bridge" +import { clearAssistantStreamInstance } from "./assistant-stream" import { clearCacheForInstance } from "../lib/global-cache" import { getLogger } from "../lib/logger" import { mergeInstanceMetadata, clearInstanceMetadata } from "./instance-metadata" @@ -1048,6 +1049,8 @@ sseManager.onConnectionLost = (instanceId, reason) => { return } + clearAssistantStreamInstance(instanceId) + setDisconnectedInstance({ id: instanceId, folder: instance.folder, @@ -1083,11 +1086,13 @@ sseManager.onInstanceDisposed = (sourceInstanceId, event) => { } if (matchingInstanceIds.length === 0) { + clearAssistantStreamInstance(sourceInstanceId) void rehydrateInstance(sourceInstanceId, { reason: "disposed" }) return } for (const instanceId of matchingInstanceIds) { + clearAssistantStreamInstance(instanceId) void rehydrateInstance(instanceId, { reason: "disposed" }) } } diff --git a/packages/ui/src/stores/message-v2/bridge.ts b/packages/ui/src/stores/message-v2/bridge.ts index 7b1b35be..b80a572a 100644 --- a/packages/ui/src/stores/message-v2/bridge.ts +++ b/packages/ui/src/stores/message-v2/bridge.ts @@ -1,7 +1,7 @@ import type { PermissionRequestLike } from "../../types/permission" -import { getPermissionCallId, getPermissionMessageId } from "../../types/permission" +import { getPermissionCallId, getPermissionMessageId, getPermissionSessionId } from "../../types/permission" import type { QuestionRequest } from "../../types/question" -import { getQuestionCallId, getQuestionMessageId } from "../../types/question" +import { getQuestionCallId, getQuestionMessageId, getQuestionSessionId } from "../../types/question" import type { Message, MessageInfo, ClientPart } from "../../types/message" import type { Session } from "../../types/session" import { messageStoreBus } from "./bus" @@ -70,6 +70,7 @@ export function seedSessionMessagesV2( interface MessageInfoOptions { status?: MessageStatus bumpRevision?: boolean + isEphemeral?: boolean } export function upsertMessageInfoV2(instanceId: string, info: MessageInfo | null | undefined, options?: MessageInfoOptions): void { @@ -89,6 +90,7 @@ export function upsertMessageInfoV2(instanceId: string, info: MessageInfo | null createdAt, updatedAt: endAt ?? createdAt, bumpRevision: Boolean(options?.bumpRevision), + ...(typeof options?.isEphemeral === "boolean" ? { isEphemeral: options.isEphemeral } : {}), }) store.setMessageInfo(info.id, info) } @@ -148,6 +150,8 @@ function extractPermissionCallId(permission: PermissionRequestLike): string | un function resolvePartIdFromCallId(store: ReturnType, messageId?: string, callId?: string): string | undefined { if (!messageId || !callId) return undefined + const indexed = store.resolveToolCallPartId(messageId, callId) + if (indexed) return indexed const record = store.getMessage(messageId) if (!record) return undefined for (const partId of record.partIds) { @@ -170,36 +174,39 @@ export function upsertPermissionV2(instanceId: string, permission: PermissionReq if (!permission) return const store = messageStoreBus.getOrCreate(instanceId) const messageId = extractPermissionMessageId(permission) + const sessionId = getPermissionSessionId(permission) + const callId = extractPermissionCallId(permission) let partId = extractPermissionPartId(permission) if (!partId) { - const callId = extractPermissionCallId(permission) partId = resolvePartIdFromCallId(store, messageId, callId) } store.upsertPermission({ permission, messageId, partId, + sessionId, + callId, enqueuedAt: (permission as any).time?.created ?? Date.now(), }) } export function reconcilePendingPermissionsV2(instanceId: string, sessionId?: string): void { const store = messageStoreBus.getOrCreate(instanceId) - const pending = store.state.permissions.queue + const pending = store.getPendingPermissionEntries(sessionId) if (!pending || pending.length === 0) return for (const entry of pending) { - if (!entry || entry.partId) continue + if (!entry) continue const permission = entry.permission if (!permission) continue - const permissionSessionId = (permission as any)?.sessionID ?? (permission as any)?.sessionId ?? undefined + const permissionSessionId = entry.sessionId ?? (permission as any)?.sessionID ?? (permission as any)?.sessionId ?? undefined if (sessionId && permissionSessionId && permissionSessionId !== sessionId) { continue } const messageId = entry.messageId ?? extractPermissionMessageId(permission) - const callId = extractPermissionCallId(permission) + const callId = entry.callId ?? extractPermissionCallId(permission) const resolvedPartId = resolvePartIdFromCallId(store, messageId, callId) if (!resolvedPartId) continue @@ -223,6 +230,7 @@ export function upsertQuestionV2(instanceId: string, request: QuestionRequest): if (!request) return const store = messageStoreBus.getOrCreate(instanceId) const messageId = extractQuestionMessageId(request) + const sessionId = getQuestionSessionId(request) let partId: string | undefined = undefined const callId = extractQuestionCallId(request) if (callId) { @@ -232,27 +240,29 @@ export function upsertQuestionV2(instanceId: string, request: QuestionRequest): request, messageId, partId, + sessionId, + callId, enqueuedAt: (request as any).time?.created ?? Date.now(), }) } export function reconcilePendingQuestionsV2(instanceId: string, sessionId?: string): void { const store = messageStoreBus.getOrCreate(instanceId) - const pending = store.state.questions.queue + const pending = store.getPendingQuestionEntries(sessionId) if (!pending || pending.length === 0) return for (const entry of pending) { - if (!entry || entry.partId) continue + if (!entry) continue const request = entry.request if (!request) continue - const questionSessionId = request.sessionID + const questionSessionId = entry.sessionId ?? request.sessionID if (sessionId && questionSessionId && questionSessionId !== sessionId) { continue } const messageId = entry.messageId ?? extractQuestionMessageId(request) - const callId = extractQuestionCallId(request) + const callId = entry.callId ?? extractQuestionCallId(request) const resolvedPartId = resolvePartIdFromCallId(store, messageId, callId) if (!resolvedPartId) continue diff --git a/packages/ui/src/stores/message-v2/instance-store.ts b/packages/ui/src/stores/message-v2/instance-store.ts index ae2a3b3c..e09cd964 100644 --- a/packages/ui/src/stores/message-v2/instance-store.ts +++ b/packages/ui/src/stores/message-v2/instance-store.ts @@ -206,9 +206,11 @@ export interface InstanceMessageStore { setMessageInfo: (messageId: string, info: MessageInfo) => void getMessageInfo: (messageId: string) => MessageInfo | undefined upsertPermission: (entry: PermissionEntry) => void + getPendingPermissionEntries: (sessionId?: string) => PermissionEntry[] removePermission: (permissionId: string) => void getPermissionState: (messageId?: string, partId?: string) => { entry: PermissionEntry; active: boolean } | null upsertQuestion: (entry: QuestionEntry) => void + getPendingQuestionEntries: (sessionId?: string) => QuestionEntry[] removeQuestion: (requestId: string) => void getQuestionState: (messageId?: string, partId?: string) => { entry: QuestionEntry; active: boolean } | null setSessionRevert: (sessionId: string, revert?: SessionRecord["revert"] | null) => void @@ -220,6 +222,8 @@ export interface InstanceMessageStore { getSessionRevision: (sessionId: string) => number getSessionMessageIds: (sessionId: string) => string[] getLastAssistantMessageId: (sessionId: string) => string | undefined + getPendingSyntheticMessageId: (sessionId: string, role: MessageRecord["role"]) => string | undefined + resolveToolCallPartId: (messageId: string, callId?: string) => string | undefined // Index of the most recent message in the session that contains a compaction part. // Returns -1 if there has been no compaction. getLastCompactionMessageIndex: (sessionId: string) => number @@ -235,6 +239,157 @@ export function createInstanceMessageStore(instanceId: string, hooks?: MessageSt const TODO_TOOL_NAME = "todowrite" const messageInfoCache = new Map() + const toolCallPartIndex = new Map() + const pendingPermissionEntriesById = new Map() + const pendingPermissionIdsBySession = new Map>() + const pendingQuestionEntriesById = new Map() + const pendingQuestionIdsBySession = new Map>() + const pendingSyntheticMessageIdsBySessionRole = new Map() + const pendingSyntheticKeyByMessageId = new Map() + const pendingSessionRevisionBumps = new Set() + let sessionRevisionFlushQueued = false + + function makePendingSyntheticKey(sessionId: string, role: MessageRecord["role"]) { + return `${sessionId}:${role}` + } + + function isPendingSyntheticRecord(record: MessageRecord | undefined) { + return Boolean(record?.isEphemeral && (record?.status === "sending" || record?.status === "sent")) + } + + function removePendingSyntheticMessageId(messageId: string | undefined) { + if (!messageId) return + const key = pendingSyntheticKeyByMessageId.get(messageId) + if (!key) return + pendingSyntheticKeyByMessageId.delete(messageId) + const existing = pendingSyntheticMessageIdsBySessionRole.get(key) + if (!existing) return + const next = existing.filter((id) => id !== messageId) + if (next.length === 0) { + pendingSyntheticMessageIdsBySessionRole.delete(key) + return + } + pendingSyntheticMessageIdsBySessionRole.set(key, next) + } + + function addPendingSyntheticRecord(record: MessageRecord | undefined) { + if (!record || !isPendingSyntheticRecord(record)) return + const key = makePendingSyntheticKey(record.sessionId, record.role) + const existing = pendingSyntheticMessageIdsBySessionRole.get(key) ?? [] + if (!existing.includes(record.id)) { + pendingSyntheticMessageIdsBySessionRole.set(key, [...existing, record.id]) + } + pendingSyntheticKeyByMessageId.set(record.id, key) + } + + function syncPendingSyntheticRecord(nextRecord: MessageRecord | undefined, previousMessageId?: string) { + removePendingSyntheticMessageId(previousMessageId ?? nextRecord?.id) + addPendingSyntheticRecord(nextRecord) + } + + function getPendingSyntheticMessageId(sessionId: string, role: MessageRecord["role"]) { + if (!sessionId) return undefined + const ids = pendingSyntheticMessageIdsBySessionRole.get(makePendingSyntheticKey(sessionId, role)) + return ids?.[0] + } + + function addIdToSessionIndex(index: Map>, sessionId: string | undefined, entryId: string) { + if (!sessionId || !entryId) return + const set = index.get(sessionId) ?? new Set() + set.add(entryId) + index.set(sessionId, set) + } + + function removeIdFromSessionIndex(index: Map>, sessionId: string | undefined, entryId: string) { + if (!sessionId || !entryId) return + const set = index.get(sessionId) + if (!set) return + set.delete(entryId) + if (set.size === 0) { + index.delete(sessionId) + } + } + + function updatePermissionPendingIndex(previous: PermissionEntry | undefined, next: PermissionEntry) { + const permissionId = next.permission?.id + if (!permissionId) return + if (previous) { + removeIdFromSessionIndex(pendingPermissionIdsBySession, previous.sessionId, permissionId) + } + pendingPermissionEntriesById.set(permissionId, next) + if (!next.partId) { + addIdToSessionIndex(pendingPermissionIdsBySession, next.sessionId, permissionId) + } + } + + function removePermissionPendingIndex(entry: PermissionEntry | undefined) { + const permissionId = entry?.permission?.id + if (!permissionId) return + removeIdFromSessionIndex(pendingPermissionIdsBySession, entry?.sessionId, permissionId) + pendingPermissionEntriesById.delete(permissionId) + } + + function updateQuestionPendingIndex(previous: QuestionEntry | undefined, next: QuestionEntry) { + const requestId = next.request?.id + if (!requestId) return + if (previous) { + removeIdFromSessionIndex(pendingQuestionIdsBySession, previous.sessionId, requestId) + } + pendingQuestionEntriesById.set(requestId, next) + if (!next.partId) { + addIdToSessionIndex(pendingQuestionIdsBySession, next.sessionId, requestId) + } + } + + function removeQuestionPendingIndex(entry: QuestionEntry | undefined) { + const requestId = entry?.request?.id + if (!requestId) return + removeIdFromSessionIndex(pendingQuestionIdsBySession, entry?.sessionId, requestId) + pendingQuestionEntriesById.delete(requestId) + } + + function getToolCallId(part: ClientPart | undefined): string | undefined { + if (!part || part.type !== "tool") { + return undefined + } + + return ( + (part as any).callID ?? + (part as any).callId ?? + (part as any).toolCallID ?? + (part as any).toolCallId ?? + undefined + ) + } + + function makeToolCallPartKey(messageId: string, callId: string) { + return `${messageId}:${callId}` + } + + function clearToolCallPartIndexForMessage(record: MessageRecord | undefined) { + if (!record) return + for (const partId of record.partIds) { + const part = record.parts[partId]?.data + const callId = getToolCallId(part) + if (!callId) continue + toolCallPartIndex.delete(makeToolCallPartKey(record.id, callId)) + } + } + + function indexToolCallPartsForMessage(record: MessageRecord | undefined) { + if (!record) return + for (const partId of record.partIds) { + const part = record.parts[partId]?.data + const callId = getToolCallId(part) + if (!callId) continue + toolCallPartIndex.set(makeToolCallPartKey(record.id, callId), partId) + } + } + + function resolveToolCallPartId(messageId: string, callId?: string) { + if (!messageId || !callId) return undefined + return toolCallPartIndex.get(makeToolCallPartKey(messageId, callId)) + } function findLastAssistantMessageId(messageIds: readonly string[]): string | undefined { for (let index = messageIds.length - 1; index >= 0; index -= 1) { @@ -319,6 +474,33 @@ export function createInstanceMessageStore(instanceId: string, hooks?: MessageSt setState("sessionRevisions", sessionId, (value = 0) => value + 1) } + function flushPendingSessionRevisionBumps() { + sessionRevisionFlushQueued = false + if (pendingSessionRevisionBumps.size === 0) { + return + } + + const pending = Array.from(pendingSessionRevisionBumps) + pendingSessionRevisionBumps.clear() + + batch(() => { + for (const sessionId of pending) { + bumpSessionRevision(sessionId) + } + }) + } + + function scheduleSessionRevisionBump(sessionId: string) { + if (!sessionId) return + pendingSessionRevisionBumps.add(sessionId) + if (sessionRevisionFlushQueued) { + return + } + + sessionRevisionFlushQueued = true + queueMicrotask(flushPendingSessionRevisionBumps) + } + function getSessionRevisionValue(sessionId: string) { return state.sessionRevisions[sessionId] ?? 0 } @@ -440,7 +622,10 @@ export function createInstanceMessageStore(instanceId: string, hooks?: MessageSt } Object.entries(normalizedRecords).forEach(([id, record]) => { + clearToolCallPartIndexForMessage(state.messages[id]) + syncPendingSyntheticRecord(record, id) nextMessages[id] = record + indexToolCallPartsForMessage(record) }) if (infoList) { @@ -473,7 +658,7 @@ export function createInstanceMessageStore(instanceId: string, hooks?: MessageSt maybeUpdateLatestTodoFromRecord(record) }) - bumpSessionRevision(sessionId) + scheduleSessionRevisionBump(sessionId) }) } @@ -512,35 +697,44 @@ export function createInstanceMessageStore(instanceId: string, hooks?: MessageSt const normalizedParts = normalizeParts(input.id, input.parts) const shouldBump = Boolean(input.bumpRevision || normalizedParts) const now = Date.now() + const previousRecord = state.messages[input.id] let nextRecord: MessageRecord | undefined - setState("messages", input.id, (previous) => { - const revision = previous ? previous.revision + (shouldBump ? 1 : 0) : 0 - const record: MessageRecord = { - id: input.id, - sessionId: input.sessionId, - role: input.role, - status: input.status, - createdAt: input.createdAt ?? previous?.createdAt ?? now, - updatedAt: input.updatedAt ?? now, - isEphemeral: input.isEphemeral ?? previous?.isEphemeral ?? false, - revision, - partIds: normalizedParts ? normalizedParts.ids : previous?.partIds ?? [], - parts: normalizedParts ? normalizedParts.map : previous?.parts ?? {}, + batch(() => { + setState("messages", input.id, (previous) => { + const revision = previous ? previous.revision + (shouldBump ? 1 : 0) : 0 + const record: MessageRecord = { + id: input.id, + sessionId: input.sessionId, + role: input.role, + status: input.status, + createdAt: input.createdAt ?? previous?.createdAt ?? now, + updatedAt: input.updatedAt ?? now, + isEphemeral: input.isEphemeral ?? previous?.isEphemeral ?? false, + revision, + partIds: normalizedParts ? normalizedParts.ids : previous?.partIds ?? [], + parts: normalizedParts ? normalizedParts.map : previous?.parts ?? {}, + } + nextRecord = record + return record + }) + + if (nextRecord) { + syncPendingSyntheticRecord(nextRecord, previousRecord?.id) + if (normalizedParts) { + clearToolCallPartIndexForMessage(previousRecord) + indexToolCallPartsForMessage(nextRecord) + } + maybeUpdateLatestTodoFromRecord(nextRecord) } - nextRecord = record - return record - }) - if (nextRecord) { - maybeUpdateLatestTodoFromRecord(nextRecord) - } + insertMessageIntoSession(input.sessionId, input.id) + flushPendingParts(input.id) + }) - insertMessageIntoSession(input.sessionId, input.id) - flushPendingParts(input.id) recomputeLastAssistantMessageId(input.sessionId) - bumpSessionRevision(input.sessionId) + scheduleSessionRevisionBump(input.sessionId) } function bufferPendingPart(entry: PendingPartEntry) { @@ -613,13 +807,72 @@ export function createInstanceMessageStore(instanceId: string, hooks?: MessageSt const partId = ensurePartId(input.messageId, input.part, message.partIds.length) const cloned = clonePart(input.part) + const previousPart = message.parts[partId]?.data as ClientPart | undefined setState( "messages", input.messageId, produce((draft: MessageRecord) => { + // When replacing a synthetic (optimistic) message ID with the server ID + // (clearParts: false), synthetic parts are kept to avoid a visible flash. + // When the first real server part arrives for the same type, evict the + // synthetic duplicate so the user's text doesn't appear twice. if (!draft.partIds.includes(partId)) { - draft.partIds = [...draft.partIds, partId] + const syntheticIds = draft.partIds.filter((pid) => { + const entry = draft.parts[pid] + if (!entry) return false + const data = entry.data as any + return data?.synthetic === true && data?.type === cloned.type + }) + for (const sid of syntheticIds) { + draft.partIds = draft.partIds.filter((pid) => pid !== sid) + delete draft.parts[sid] + } + + // When a reasoning/thinking part arrives, insert it before any + // existing text-placeholder parts. The Rust active-text assembler + // can emit `assistant.stream.chunk` events (which create a text + // part placeholder) before the reasoning snapshot arrives via + // `message.part.updated`. Without this, text appears before + // reasoning in the partIds array, breaking display order. + // + // When multiple reasoning parts arrive, insert after the last + // existing reasoning part (rather than before the first text part) + // to preserve chronological order among reasoning parts. + const isReasoningLike = cloned.type === "reasoning" + if (isReasoningLike && draft.partIds.length > 0) { + // Find the last existing reasoning part + let lastReasoningIdx = -1 + for (let i = draft.partIds.length - 1; i >= 0; i--) { + const data = draft.parts[draft.partIds[i]]?.data as any + if (data?.type === "reasoning") { + lastReasoningIdx = i + break + } + } + + if (lastReasoningIdx !== -1) { + // Insert after the last reasoning part + const next = [...draft.partIds] + next.splice(lastReasoningIdx + 1, 0, partId) + draft.partIds = next + } else { + // No reasoning yet — insert before first text part + const insertIdx = draft.partIds.findIndex((pid) => { + const data = draft.parts[pid]?.data as any + return data?.type === "text" + }) + if (insertIdx !== -1) { + const next = [...draft.partIds] + next.splice(insertIdx, 0, partId) + draft.partIds = next + } else { + draft.partIds = [...draft.partIds, partId] + } + } + } else { + draft.partIds = [...draft.partIds, partId] + } } const existing = draft.parts[partId] const nextRevision = existing ? existing.revision + 1 : (cloned as any).version ?? 0 @@ -636,6 +889,14 @@ export function createInstanceMessageStore(instanceId: string, hooks?: MessageSt ) rebindPermissionForPart(input.messageId, partId, cloned) + const previousCallId = getToolCallId(previousPart) + if (previousCallId) { + toolCallPartIndex.delete(makeToolCallPartKey(input.messageId, previousCallId)) + } + const nextCallId = getToolCallId(cloned) + if (nextCallId) { + toolCallPartIndex.set(makeToolCallPartKey(input.messageId, nextCallId), partId) + } if (isCompletedTodoPart(cloned)) { recordLatestTodoSnapshot(message.sessionId, { @@ -647,7 +908,7 @@ export function createInstanceMessageStore(instanceId: string, hooks?: MessageSt // Any part update can change the rendered height of the message // list, so we treat it as a session revision for scroll purposes. - bumpSessionRevision(message.sessionId) + scheduleSessionRevisionBump(message.sessionId) } function applyPartDelta(input: { @@ -692,7 +953,7 @@ export function createInstanceMessageStore(instanceId: string, hooks?: MessageSt ) if (applied && (input.bumpSessionRevision ?? true)) { - bumpSessionRevision(message.sessionId) + scheduleSessionRevisionBump(message.sessionId) } } @@ -713,6 +974,8 @@ export function createInstanceMessageStore(instanceId: string, hooks?: MessageSt } clearRecordDisplayCacheForMessages(instanceId, [messageId]) + clearToolCallPartIndexForMessage(record) + removePendingSyntheticMessageId(messageId) batch(() => { sessionIds.forEach((sessionId) => { @@ -766,6 +1029,11 @@ export function createInstanceMessageStore(instanceId: string, hooks?: MessageSt if (!message) return clearRecordDisplayCacheForMessages(instanceId, [messageId]) + const part = message.parts[partId]?.data as ClientPart | undefined + const callId = getToolCallId(part) + if (callId) { + toolCallPartIndex.delete(makeToolCallPartKey(messageId, callId)) + } batch(() => { setState( @@ -821,6 +1089,11 @@ export function createInstanceMessageStore(instanceId: string, hooks?: MessageSt parts: options.clearParts ? {} : existing.parts, } + clearToolCallPartIndexForMessage(existing) + removePendingSyntheticMessageId(options.oldId) + indexToolCallPartsForMessage(cloned) + addPendingSyntheticRecord(cloned) + setState("messages", options.newId, cloned) setState("messages", (prev) => { const next = { ...prev } @@ -843,7 +1116,7 @@ export function createInstanceMessageStore(instanceId: string, hooks?: MessageSt affectedSessions.forEach((sessionId) => { recomputeLastAssistantMessageId(sessionId) - bumpSessionRevision(sessionId) + scheduleSessionRevisionBump(sessionId) }) const infoEntry = messageInfoCache.get(options.oldId) @@ -905,6 +1178,7 @@ export function createInstanceMessageStore(instanceId: string, hooks?: MessageSt } function upsertPermission(entry: PermissionEntry) { + const previous = pendingPermissionEntriesById.get(entry.permission.id) const messageKey = entry.messageId ?? "__global__" const partKey = entry.partId ?? entry.permission?.id ?? "__global__" @@ -924,9 +1198,31 @@ export function createInstanceMessageStore(instanceId: string, hooks?: MessageSt } }), ) + updatePermissionPendingIndex(previous, entry) + } + + function getPendingPermissionEntries(sessionId?: string) { + if (!sessionId) { + return state.permissions.queue.filter((entry) => !entry?.partId) + } + + const ids = pendingPermissionIdsBySession.get(sessionId) + if (!ids || ids.size === 0) { + return [] + } + + const result: PermissionEntry[] = [] + ids.forEach((id) => { + const entry = pendingPermissionEntriesById.get(id) + if (entry && !entry.partId) { + result.push(entry) + } + }) + return result } function removePermission(permissionId: string) { + const previous = pendingPermissionEntriesById.get(permissionId) setState( "permissions", produce((draft) => { @@ -947,6 +1243,7 @@ export function createInstanceMessageStore(instanceId: string, hooks?: MessageSt }) }), ) + removePermissionPendingIndex(previous) } function getPermissionState(messageId?: string, partId?: string) { @@ -959,6 +1256,7 @@ export function createInstanceMessageStore(instanceId: string, hooks?: MessageSt } function upsertQuestion(entry: QuestionEntry) { + const previous = pendingQuestionEntriesById.get(entry.request.id) const messageKey = entry.messageId ?? "__global__" const partKey = entry.partId ?? entry.request?.id ?? "__global__" @@ -978,9 +1276,31 @@ export function createInstanceMessageStore(instanceId: string, hooks?: MessageSt } }), ) + updateQuestionPendingIndex(previous, entry) + } + + function getPendingQuestionEntries(sessionId?: string) { + if (!sessionId) { + return state.questions.queue.filter((entry) => !entry?.partId) + } + + const ids = pendingQuestionIdsBySession.get(sessionId) + if (!ids || ids.size === 0) { + return [] + } + + const result: QuestionEntry[] = [] + ids.forEach((id) => { + const entry = pendingQuestionEntriesById.get(id) + if (entry && !entry.partId) { + result.push(entry) + } + }) + return result } function removeQuestion(requestId: string) { + const previous = pendingQuestionEntriesById.get(requestId) setState( "questions", produce((draft) => { @@ -1001,6 +1321,7 @@ export function createInstanceMessageStore(instanceId: string, hooks?: MessageSt }) }), ) + removeQuestionPendingIndex(previous) } function getQuestionState(messageId?: string, partId?: string) { @@ -1021,6 +1342,9 @@ export function createInstanceMessageStore(instanceId: string, hooks?: MessageSt const keptIds = session.messageIds.slice(0, stopIndex) if (removedIds.length === 0) return + removedIds.forEach((id) => clearToolCallPartIndexForMessage(state.messages[id])) + removedIds.forEach((id) => removePendingSyntheticMessageId(id)) + setState("sessions", sessionId, "messageIds", keptIds) setState("messages", (prev) => { @@ -1099,12 +1423,17 @@ export function createInstanceMessageStore(instanceId: string, hooks?: MessageSt function clearSession(sessionId: string) { if (!sessionId) return - const messageIds = Object.values(state.messages) - .filter((record) => record.sessionId === sessionId) - .map((record) => record.id) + const session = state.sessions[sessionId] + const messageIds = session?.messageIds?.length + ? [...session.messageIds] + : Object.values(state.messages) + .filter((record) => record.sessionId === sessionId) + .map((record) => record.id) storeLog.info("Clearing session data", { instanceId, sessionId, messageCount: messageIds.length }) clearRecordDisplayCacheForMessages(instanceId, messageIds) + messageIds.forEach((id) => clearToolCallPartIndexForMessage(state.messages[id])) + messageIds.forEach((id) => removePendingSyntheticMessageId(id)) batch(() => { setState("messages", (prev) => { @@ -1194,10 +1523,17 @@ export function createInstanceMessageStore(instanceId: string, hooks?: MessageSt } - function clearInstance() { - messageInfoCache.clear() - setState(reconcile(createInitialState(instanceId))) - } + function clearInstance() { + toolCallPartIndex.clear() + pendingPermissionEntriesById.clear() + pendingPermissionIdsBySession.clear() + pendingQuestionEntriesById.clear() + pendingQuestionIdsBySession.clear() + pendingSyntheticMessageIdsBySessionRole.clear() + pendingSyntheticKeyByMessageId.clear() + messageInfoCache.clear() + setState(reconcile(createInitialState(instanceId))) + } return { @@ -1216,11 +1552,13 @@ export function createInstanceMessageStore(instanceId: string, hooks?: MessageSt replaceMessageId, setMessageInfo, getMessageInfo, - upsertPermission, - removePermission, - getPermissionState, - upsertQuestion, - removeQuestion, + upsertPermission, + getPendingPermissionEntries, + removePermission, + getPermissionState, + upsertQuestion, + getPendingQuestionEntries, + removeQuestion, getQuestionState, setSessionRevert, @@ -1232,10 +1570,12 @@ export function createInstanceMessageStore(instanceId: string, hooks?: MessageSt getSessionRevision: getSessionRevisionValue, getSessionMessageIds: (sessionId: string) => state.sessions[sessionId]?.messageIds ?? [], getLastAssistantMessageId: getLastAssistantMessageIdValue, + getPendingSyntheticMessageId, + resolveToolCallPartId, getLastCompactionMessageIndex, getMessage: (messageId: string) => state.messages[messageId], getLatestTodoSnapshot: (sessionId: string) => state.latestTodos[sessionId], clearSession, clearInstance, - } - } + } + } diff --git a/packages/ui/src/stores/message-v2/record-display-cache.ts b/packages/ui/src/stores/message-v2/record-display-cache.ts index 9e89ee71..7e7a865c 100644 --- a/packages/ui/src/stores/message-v2/record-display-cache.ts +++ b/packages/ui/src/stores/message-v2/record-display-cache.ts @@ -1,7 +1,7 @@ import type { ClientPart } from "../../types/message" import type { MessageRecord } from "./types" -type ClientPartWithRevision = ClientPart & { revision?: number } +type ClientPartWithRevision = ClientPart export interface RecordDisplayData { orderedParts: ClientPartWithRevision[] @@ -30,7 +30,7 @@ export function buildRecordDisplayData(instanceId: string, record: MessageRecord for (const partId of record.partIds) { const entry = record.parts[partId] if (!entry?.data) continue - orderedParts.push({ ...(entry.data as ClientPart), revision: entry.revision }) + orderedParts.push(entry.data as ClientPartWithRevision) } const data: RecordDisplayData = { orderedParts } diff --git a/packages/ui/src/stores/message-v2/types.ts b/packages/ui/src/stores/message-v2/types.ts index 581a896e..44b2bb41 100644 --- a/packages/ui/src/stores/message-v2/types.ts +++ b/packages/ui/src/stores/message-v2/types.ts @@ -51,6 +51,8 @@ export interface PermissionEntry { permission: PermissionRequestLike messageId?: string partId?: string + sessionId?: string + callId?: string enqueuedAt: number } @@ -64,6 +66,8 @@ export interface QuestionEntry { request: QuestionRequest messageId?: string partId?: string + sessionId?: string + callId?: string enqueuedAt: number } diff --git a/packages/ui/src/stores/session-actions.ts b/packages/ui/src/stores/session-actions.ts index 3cdf0321..f7eb470a 100644 --- a/packages/ui/src/stores/session-actions.ts +++ b/packages/ui/src/stores/session-actions.ts @@ -11,6 +11,7 @@ import { removeMessagePartV2, removeMessageV2 } from "./message-v2/bridge" import { getLogger } from "../lib/logger" import { requestData } from "../lib/opencode-api" import { clearConversationPlaybackForSession } from "./conversation-speech" +import { clearAssistantStreamSession } from "./assistant-stream" const log = getLogger("actions") @@ -168,21 +169,7 @@ async function sendMessage( clearConversationPlaybackForSession(instanceId, sessionId) - store.upsertMessage({ - id: messageId, - sessionId, - role: "user", - status: "sending", - parts: optimisticParts, - createdAt, - updatedAt: createdAt, - isEphemeral: true, - }) - - withSession(instanceId, sessionId, () => { - /* trigger reactivity for legacy session data */ - }) - + // Build request body BEFORE any store mutations to avoid blocking the thread. const requestBody = { parts: requestParts, ...(session.agent && { agent: session.agent }), @@ -201,14 +188,36 @@ async function sendMessage( })()), } - log.info("sendMessage", { - instanceId, - sessionId, - requestBody, - }) - try { - log.info("session.promptAsync", { instanceId, sessionId, requestBody }) + // PERF: Yield to the event loop so the keydown handler returns immediately. + // The upsertMessage batch flush triggers a synchronous reactive cascade + // (SolidJS effects + DOM updates) that takes ~4.7s in WebView2. Moving it + // to a timer task prevents the browser from reporting a multi-second + // "keydown handler" violation and frees the input thread sooner. + // The HTTP fetch starts AFTER upsertMessage, so SSE events cannot race + // ahead of the optimistic insert. + await new Promise((resolve) => setTimeout(resolve, 0)) + + // Insert optimistic message BEFORE the fetch so SSE events can find + // and replace the synthetic record when the server responds. + store.upsertMessage({ + id: messageId, + sessionId, + role: "user", + status: "sending", + parts: optimisticParts, + createdAt, + updatedAt: createdAt, + isEphemeral: true, + }) + + withSession(instanceId, sessionId, () => { + /* trigger reactivity for legacy session data */ + }) + + // PERF: Fire the HTTP fetch AFTER the optimistic insert. + // The requestBody was already prepared above the store mutation + // to minimize synchronous work before the insert. await requestData( client.session.promptAsync({ sessionID: sessionId, @@ -216,6 +225,24 @@ async function sendMessage( }), "session.promptAsync", ) + + // Transition the optimistic message from "sending" to "sent" so it stays + // visible until the real SSE events arrive and replace it. + const pendingId = store.getPendingSyntheticMessageId(sessionId, "user") + if (pendingId) { + const record = store.getMessage(pendingId) + if (record?.isEphemeral && record.status === "sending") { + store.upsertMessage({ + id: pendingId, + sessionId, + role: "user", + status: "sent", + createdAt: record.createdAt, + updatedAt: Date.now(), + isEphemeral: true, + }) + } + } } catch (error) { log.error("Failed to send prompt", error) throw error @@ -319,6 +346,17 @@ async function abortSession(instanceId: string, sessionId: string): Promise { + session.status = "idle" + }) + + // Also clear any ephemeral stream preview so it doesn't linger until + // the SSE session.idle event arrives. + clearAssistantStreamSession(instanceId, sessionId) } catch (error) { log.error("Failed to abort session", error) throw error diff --git a/packages/ui/src/stores/session-events.ts b/packages/ui/src/stores/session-events.ts index 713c95d1..1a68c5b0 100644 --- a/packages/ui/src/stores/session-events.ts +++ b/packages/ui/src/stores/session-events.ts @@ -1,3 +1,4 @@ +import { batch } from "solid-js" import type { MessageInfo, MessagePartRemovedEvent, @@ -51,6 +52,15 @@ import { ensureSessionParentExpanded, sessions, setSessions, syncInstanceSession import { normalizeMessagePart } from "./message-v2/normalizers" import { updateSessionInfo } from "./message-v2/session-info" import { tGlobal } from "../lib/i18n" +import { handleConversationAssistantPartUpdated } from "./conversation-speech" +import { + appendAssistantStreamChunk, + clearAssistantStreamMessage, + clearAssistantStreamPart, + clearAssistantStreamSession, + getAssistantStreamPreviewText, +} from "./assistant-stream" +import type { AssistantStreamChunkEvent } from "../lib/event-transport-contract" import { loadMessages } from "./session-api" import { getOrCreateWorktreeClient, getRootClient, getWorktreeSlugForDirectory, getWorktreeSlugForSession } from "./worktrees" @@ -70,7 +80,6 @@ import { } from "./message-v2/bridge" import { messageStoreBus } from "./message-v2/bus" import type { InstanceMessageStore } from "./message-v2/instance-store" -import { handleConversationAssistantPartUpdated } from "./conversation-speech" const log = getLogger("sse") const pendingSessionFetches = new Map>() @@ -84,6 +93,147 @@ function isSameRetryState(left: SessionRetryState | null | undefined, right: Ses return a.attempt === b.attempt && a.message === b.message && a.next === b.next } +const pendingSessionInfoUpdates = new Set() +let sessionInfoFlushQueued = false +const pendingPartDeltas = new Map() +let partDeltaFlushTask: number | null = null + +function flushPendingSessionInfoUpdates() { + sessionInfoFlushQueued = false + if (pendingSessionInfoUpdates.size === 0) { + return + } + + const pending = Array.from(pendingSessionInfoUpdates) + pendingSessionInfoUpdates.clear() + + batch(() => { + for (const entry of pending) { + const [instanceId, sessionId] = entry.split(":") + if (!instanceId || !sessionId) continue + updateSessionInfo(instanceId, sessionId) + } + }) +} + +function scheduleSessionInfoUpdate(instanceId: string, sessionId: string) { + if (!instanceId || !sessionId) { + return + } + + pendingSessionInfoUpdates.add(`${instanceId}:${sessionId}`) + if (sessionInfoFlushQueued) { + return + } + + sessionInfoFlushQueued = true + queueMicrotask(flushPendingSessionInfoUpdates) +} + +function makeQueuedPartDeltaKey(instanceId: string, messageId: string, partId: string, field: string) { + return `${instanceId}:${messageId}:${partId}:${field}` +} + +function clearScheduledPartDeltaFlush() { + if (partDeltaFlushTask === null) { + return + } + + window.clearTimeout(partDeltaFlushTask) + partDeltaFlushTask = null +} + +function flushQueuedPartDeltas( + predicate?: (entry: { instanceId: string; messageId: string; partId: string; field: string; delta: string }) => boolean, +) { + if (pendingPartDeltas.size === 0) { + return + } + + const entries: Array<{ instanceId: string; messageId: string; partId: string; field: string; delta: string }> = [] + + for (const [key, entry] of pendingPartDeltas) { + if (predicate && !predicate(entry)) { + continue + } + entries.push(entry) + pendingPartDeltas.delete(key) + } + + if (entries.length === 0) { + return + } + + if (pendingPartDeltas.size === 0) { + clearScheduledPartDeltaFlush() + } + + batch(() => { + for (const entry of entries) { + applyPartDeltaV2(entry.instanceId, { + messageId: entry.messageId, + partId: entry.partId, + field: entry.field, + delta: entry.delta, + }) + } + }) +} + +function flushQueuedPartDeltasForMessage(instanceId: string, messageId: string) { + flushQueuedPartDeltas((entry) => entry.instanceId === instanceId && entry.messageId === messageId) +} + +function flushQueuedPartDeltasForPart(instanceId: string, messageId: string, partId: string) { + flushQueuedPartDeltas((entry) => entry.instanceId === instanceId && entry.messageId === messageId && entry.partId === partId) +} + +/** Cadence for flushing coalesced deltas to the store (ms). + * 48ms (~3 frames) keeps heavy store-driven re-renders (Markdown, block + * layout) throttled while still delivering timely updates. */ +const DELTA_FLUSH_INTERVAL_MS = 48 + +function scheduleQueuedPartDeltaFlush() { + if (partDeltaFlushTask !== null) { + return + } + + partDeltaFlushTask = window.setTimeout(() => { + partDeltaFlushTask = null + flushQueuedPartDeltas() + }, DELTA_FLUSH_INTERVAL_MS) +} + +function queueMessagePartDelta(instanceId: string, event: MessagePartDeltaEvent) { + const props = event.properties + if (!props) return + + const { messageID, partID, field, delta } = props + if (!messageID || !partID || !field || typeof delta !== "string" || delta.length === 0) return + + const key = makeQueuedPartDeltaKey(instanceId, messageID, partID, field) + const existing = pendingPartDeltas.get(key) + if (existing) { + existing.delta += delta + } else { + pendingPartDeltas.set(key, { + instanceId, + messageId: messageID, + partId: partID, + field, + delta, + }) + } + + scheduleQueuedPartDeltaFlush() +} + function shouldSendOsNotification(kind: "needsInput" | "idle"): boolean { if (typeof document === "undefined") return false const pref = preferences() @@ -155,7 +305,10 @@ function applySessionStatus(instanceId: string, sessionId: string, status: Sessi const nextRetry = retry ?? null if (current === status && isSameRetryState(session.retry, nextRetry)) return false - if (current === "compacting" && status !== "compacting") { + // While compacting, only allow transitions to "idle" (abort / session end) + // or staying in "compacting". Block "compacting -> working" so the + // compaction UI isn't prematurely replaced by the normal busy indicator. + if (current === "compacting" && status !== "compacting" && status !== "idle") { return false } @@ -225,8 +378,8 @@ async function fetchSessionInfo(instanceId: string, sessionId: string, directory ...fetched, agent: existing?.agent ?? fetched.agent, model: existing?.model ?? fetched.model, - status: existing?.status === "compacting" ? "compacting" : fetched.status, - retry: existing?.status === "compacting" ? null : fetched.retry, + status: existing?.status === "compacting" && fetched.status !== "idle" ? "compacting" : fetched.status, + retry: existing?.status === "compacting" && fetched.status !== "idle" ? null : fetched.retry, pendingPermission: existing?.pendingPermission ?? fetched.pendingPermission, pendingQuestion: existing?.pendingQuestion ?? false, } @@ -297,17 +450,7 @@ function findPendingSyntheticMessageId( sessionId: string, role: MessageRole, ): string | undefined { - const messageIds = store.getSessionMessageIds(sessionId) - for (const messageId of messageIds) { - const record = store.getMessage(messageId) - if (!record) continue - if (record.sessionId !== sessionId) continue - if (record.role !== role) continue - if (record.status !== "sending") continue - if (!record.isEphemeral) continue - return record.id - } - return undefined + return store.getPendingSyntheticMessageId(sessionId, role) } function handleMessageUpdate(instanceId: string, event: MessageUpdateEvent | MessagePartUpdatedEvent): void { @@ -326,6 +469,11 @@ function handleMessageUpdate(instanceId: string, event: MessageUpdateEvent | Mes const sessionId = typeof part.sessionID === "string" ? part.sessionID : fallbackSessionId const messageId = typeof part.messageID === "string" ? part.messageID : fallbackMessageId if (!sessionId || !messageId) return + if (typeof part.id === "string" && part.id.length > 0) { + flushQueuedPartDeltasForPart(instanceId, messageId, part.id) + } else { + flushQueuedPartDeltasForMessage(instanceId, messageId) + } if (part.type === "compaction") { ensureSessionStatus(instanceId, sessionId, "compacting", (event as any)?.directory) } @@ -333,42 +481,98 @@ function handleMessageUpdate(instanceId: string, event: MessageUpdateEvent | Mes const store = messageStoreBus.getOrCreate(instanceId) const role: MessageRole = resolveMessageRole(messageInfo) const createdAt = typeof messageInfo?.time?.created === "number" ? messageInfo.time.created : Date.now() + const previousText = + role === "assistant" && part.type === "text" + ? (() => { + const existingRecord = store.getMessage(messageId) + const existingPart = existingRecord?.parts[part.id]?.data as { text?: unknown } | undefined + return typeof existingPart?.text === "string" ? existingPart.text : "" + })() + : "" + const previousPreviewText = + role === "assistant" && part.type === "text" + ? getAssistantStreamPreviewText(instanceId, sessionId, messageId, part.id) ?? "" + : "" + const previousVisibleText = + previousPreviewText.length > previousText.length ? previousPreviewText : previousText + + + batch(() => { + let record = store.getMessage(messageId) + if (!record) { + const pendingId = findPendingSyntheticMessageId(store, sessionId, role) + if (pendingId && pendingId !== messageId) { + // PERF: Never clear parts during ID swap. The synthetic (optimistic) parts + // stay visible until the server parts arrive via applyPartUpdateV2. + // Clearing them causes a visible flash in WebView2 because the ID change + // forces a component re-mount, and empty parts render as blank. + replaceMessageIdV2(instanceId, pendingId, messageId, { clearParts: false }) + record = store.getMessage(messageId) + } + } - - let record = store.getMessage(messageId) - if (!record) { - const pendingId = findPendingSyntheticMessageId(store, sessionId, role) - if (pendingId && pendingId !== messageId) { - replaceMessageIdV2(instanceId, pendingId, messageId, { clearParts: role === "user" }) - record = store.getMessage(messageId) + if (!record) { + store.upsertMessage({ + id: messageId, + sessionId, + role, + status: "streaming", + createdAt, + updatedAt: createdAt, + isEphemeral: true, + }) } - } - if (!record) { - store.upsertMessage({ - id: messageId, - sessionId, - role, - status: "streaming", - createdAt, - updatedAt: createdAt, - isEphemeral: true, - }) - } + if (messageInfo) { + upsertMessageInfoV2(instanceId, messageInfo, { status: "streaming" }) + } - if (messageInfo) { - upsertMessageInfoV2(instanceId, messageInfo, { status: "streaming" }) - } - - applyPartUpdateV2(instanceId, { ...part, sessionID: sessionId, messageID: messageId }) - handleConversationAssistantPartUpdated(instanceId, { ...part, sessionID: sessionId, messageID: messageId }, messageInfo) + applyPartUpdateV2(instanceId, { ...part, sessionID: sessionId, messageID: messageId }) + handleConversationAssistantPartUpdated(instanceId, { ...part, sessionID: sessionId, messageID: messageId }, messageInfo) + + if ( + role === "assistant" && + part.type === "text" && + typeof part.text === "string" && + part.text.length > previousVisibleText.length && + part.text.startsWith(previousVisibleText) + ) { + appendAssistantStreamChunk(instanceId, { + type: "assistant.stream.chunk", + properties: { + sessionID: sessionId, + messageID: messageId, + partID: part.id, + field: "text", + delta: part.text.slice(previousVisibleText.length), + }, + }) + } else if ( + role === "assistant" && + part.type === "text" && + typeof part.text === "string" && + previousPreviewText.length > 0 && + part.text.length >= previousPreviewText.length && + part.text.startsWith(previousPreviewText) + ) { + clearAssistantStreamPart(instanceId, sessionId, messageId, part.id) + } else if ( + role === "assistant" && + part.type === "text" && + previousPreviewText.length > 0 + ) { + // Server sent a correction/replacement that diverges from the preview. + // Clear stale preview so the canonical text takes over. + clearAssistantStreamPart(instanceId, sessionId, messageId, part.id) + } - if (part.type === "tool" && part.tool === "question") { - // Questions can arrive before their tool part exists; re-link now. - reconcilePendingQuestionsV2(instanceId, sessionId) - } + if (part.type === "tool" && part.tool === "question") { + // Questions can arrive before their tool part exists; re-link now. + reconcilePendingQuestionsV2(instanceId, sessionId) + } - updateSessionInfo(instanceId, sessionId) + scheduleSessionInfoUpdate(instanceId, sessionId) + }) } else if (event.type === "message.updated") { const info = event.properties?.info if (!info) return @@ -376,6 +580,12 @@ function handleMessageUpdate(instanceId: string, event: MessageUpdateEvent | Mes const sessionId = typeof info.sessionID === "string" ? info.sessionID : undefined const messageId = typeof info.id === "string" ? info.id : undefined if (!sessionId || !messageId) return + flushQueuedPartDeltasForMessage(instanceId, messageId) + // Defer preview clear to after the current synchronous batch completes. + // Using queueMicrotask (not requestAnimationFrame) so the clear runs within + // the same JS task — avoiding a one-frame flicker where preview text could + // reappear between the rAF callback and the next paint. + queueMicrotask(() => clearAssistantStreamMessage(instanceId, sessionId, messageId)) const timeInfo = (info.time ?? {}) as { created?: number; updated?: number; end?: number } const nextUpdated = @@ -387,52 +597,101 @@ function handleMessageUpdate(instanceId: string, event: MessageUpdateEvent | Mes ? timeInfo.created : Date.now() - withSession(instanceId, sessionId, (session) => { - const currentUpdated = session.time?.updated ?? 0 - if (nextUpdated <= currentUpdated) return false - session.time = { ...(session.time ?? {}), updated: nextUpdated } - }) + batch(() => { + withSession(instanceId, sessionId, (session) => { + const currentUpdated = session.time?.updated ?? 0 + if (nextUpdated <= currentUpdated) return false + session.time = { ...(session.time ?? {}), updated: nextUpdated } + }) - const store = messageStoreBus.getOrCreate(instanceId) + const store = messageStoreBus.getOrCreate(instanceId) - const role: MessageRole = info.role === "user" ? "user" : "assistant" - const hasError = Boolean((info as any).error) - const status: MessageStatus = hasError ? "error" : "complete" + const role: MessageRole = info.role === "user" ? "user" : "assistant" + const hasError = Boolean((info as any).error) + const status: MessageStatus = hasError ? "error" : "complete" - let record = store.getMessage(messageId) - if (!record) { - const pendingId = findPendingSyntheticMessageId(store, sessionId, role) - if (pendingId && pendingId !== messageId) { - replaceMessageIdV2(instanceId, pendingId, messageId, { clearParts: role === "user" }) - record = store.getMessage(messageId) + let record = store.getMessage(messageId) + if (!record) { + const pendingId = findPendingSyntheticMessageId(store, sessionId, role) + if (pendingId && pendingId !== messageId) { + replaceMessageIdV2(instanceId, pendingId, messageId, { clearParts: false }) + record = store.getMessage(messageId) + } } - } - if (!record) { - const createdAt = info.time?.created ?? Date.now() - const endAt = (info.time as { end?: number } | undefined)?.end - store.upsertMessage({ - id: messageId, - sessionId, - role, - status, - createdAt, - updatedAt: endAt ?? createdAt, - }) - } + if (!record) { + const createdAt = info.time?.created ?? Date.now() + const endAt = (info.time as { end?: number } | undefined)?.end + store.upsertMessage({ + id: messageId, + sessionId, + role, + status, + createdAt, + updatedAt: endAt ?? createdAt, + }) + } - upsertMessageInfoV2(instanceId, info, { status, bumpRevision: true }) + upsertMessageInfoV2(instanceId, info, { status, bumpRevision: true, isEphemeral: false }) - updateSessionInfo(instanceId, sessionId) + scheduleSessionInfoUpdate(instanceId, sessionId) + }) } } function handleMessagePartDelta(instanceId: string, event: MessagePartDeltaEvent): void { + queueMessagePartDelta(instanceId, event) +} + +function handleAssistantStreamChunk(instanceId: string, event: AssistantStreamChunkEvent): void { const props = event.properties - if (!props) return - const { messageID, partID, field, delta } = props - if (!messageID || !partID || !field || typeof delta !== "string") return - applyPartDeltaV2(instanceId, { messageId: messageID, partId: partID, field, delta }) + if (!props?.sessionID || !props.messageID || !props.partID) { + return + } + + // Guard: ignore chunks for sessions that don't exist or are already idle + // (a stale/delayed chunk could otherwise create ghost ephemeral messages). + const instanceSessions = sessions().get(instanceId) + const session = instanceSessions?.get(props.sessionID) + if (!session || session.status === "idle") { + return + } + + const store = messageStoreBus.getOrCreate(instanceId) + let record = store.getMessage(props.messageID) + + // When the message/part doesn't exist yet, create lightweight ephemeral + // placeholders. Wrap both mutations in a batch so SolidJS flushes them + // atomically — avoiding separate reactive cascades for each mutation. + if (!record || !record.parts[props.partID]) { + batch(() => { + if (!record) { + store.upsertMessage({ + id: props.messageID, + sessionId: props.sessionID, + role: "assistant", + status: "streaming", + createdAt: Date.now(), + updatedAt: Date.now(), + isEphemeral: true, + bumpRevision: false, + }) + record = store.getMessage(props.messageID) + } + + if (record && !record.parts[props.partID]) { + applyPartUpdateV2(instanceId, { + id: props.partID, + type: "text", + text: "", + sessionID: props.sessionID, + messageID: props.messageID, + } as any) + } + }) + } + + appendAssistantStreamChunk(instanceId, event) } function handleSessionUpdate(instanceId: string, event: EventSessionUpdated): void { @@ -555,6 +814,28 @@ function handleSessionIdle(instanceId: string, event: EventSessionIdle): void { } ensureSessionStatus(instanceId, sessionId, "idle", (event as any)?.directory) + + // Clean up any lingering assistant-stream preview entries for this session. + // On Tauri this also happens on session switch (App.tsx), but on browser + // hosts this is the only cleanup path. + clearAssistantStreamSession(instanceId, sessionId) + + // Safety net: if any ephemeral "sending"/"sent" messages still exist for + // this session they are orphans (the real server messages never arrived or + // replaced them). Reload the full message list so the UI reflects reality. + const store = messageStoreBus.getOrCreate(instanceId) + const messageIds = store.getSessionMessageIds(sessionId) ?? [] + const hasOrphanedEphemeral = messageIds.some((id) => { + const record = store.getMessage(id) + return record?.isEphemeral && (record.status === "sending" || record.status === "sent") + }) + if (hasOrphanedEphemeral) { + log.info(`[SSE] Session idle with orphaned ephemeral messages, reloading: ${sessionId}`) + loadMessages(instanceId, sessionId, true).catch((error) => + log.error("Failed to reload messages after idle with orphaned ephemerals", error), + ) + } + log.info(`[SSE] Session idle: ${sessionId}`) } @@ -646,8 +927,10 @@ function handleMessageRemoved(instanceId: string, event: MessageRemovedEvent): v if (!sessionID || !messageID) return log.info(`[SSE] Message removed from session ${sessionID}`, { messageID }) + flushQueuedPartDeltasForMessage(instanceId, messageID) + clearAssistantStreamMessage(instanceId, sessionID, messageID) removeMessageV2(instanceId, messageID) - updateSessionInfo(instanceId, sessionID) + scheduleSessionInfoUpdate(instanceId, sessionID) } function handleMessagePartRemoved(instanceId: string, event: MessagePartRemovedEvent): void { @@ -655,8 +938,10 @@ function handleMessagePartRemoved(instanceId: string, event: MessagePartRemovedE if (!sessionID || !messageID || !partID) return log.info(`[SSE] Message part removed from session ${sessionID}`, { messageID, partID }) + flushQueuedPartDeltasForPart(instanceId, messageID, partID) + clearAssistantStreamPart(instanceId, sessionID, messageID, partID) removeMessagePartV2(instanceId, messageID, partID) - updateSessionInfo(instanceId, sessionID) + scheduleSessionInfoUpdate(instanceId, sessionID) } function handleTuiToast(_instanceId: string, event: TuiToastEvent): void { @@ -736,6 +1021,7 @@ function handleQuestionAnswered( } export { + handleAssistantStreamChunk, handleMessagePartRemoved, handleMessageRemoved, handleMessagePartDelta, diff --git a/packages/ui/src/stores/session-state.ts b/packages/ui/src/stores/session-state.ts index c4703e0b..6c5d1cb2 100644 --- a/packages/ui/src/stores/session-state.ts +++ b/packages/ui/src/stores/session-state.ts @@ -589,8 +589,9 @@ function setActiveSessionFromList(instanceId: string, sessionId: string): void { function isSessionBusy(instanceId: string, sessionId: string): boolean { const instanceSessions = sessions().get(instanceId) if (!instanceSessions) return false - if (!instanceSessions.has(sessionId)) return false - return true + const session = instanceSessions.get(sessionId) + if (!session) return false + return session.status === "working" || session.status === "compacting" } function isSessionMessagesLoading(instanceId: string, sessionId: string): boolean { diff --git a/packages/ui/src/stores/sessions.ts b/packages/ui/src/stores/sessions.ts index 8b5e9a3e..5d3de73b 100644 --- a/packages/ui/src/stores/sessions.ts +++ b/packages/ui/src/stores/sessions.ts @@ -56,6 +56,7 @@ import { updateSessionModel, } from "./session-actions" import { + handleAssistantStreamChunk, handleMessagePartRemoved, handleMessageRemoved, handleMessagePartDelta, @@ -76,6 +77,7 @@ import { sseManager.onMessageUpdate = handleMessageUpdate sseManager.onMessagePartUpdated = handleMessageUpdate sseManager.onMessagePartDelta = handleMessagePartDelta +sseManager.onAssistantStreamChunk = handleAssistantStreamChunk sseManager.onMessageRemoved = handleMessageRemoved sseManager.onMessagePartRemoved = handleMessagePartRemoved sseManager.onSessionUpdate = handleSessionUpdate diff --git a/packages/ui/src/styles/messaging/message-base.css b/packages/ui/src/styles/messaging/message-base.css index 5b290f89..0953db4c 100644 --- a/packages/ui/src/styles/messaging/message-base.css +++ b/packages/ui/src/styles/messaging/message-base.css @@ -137,6 +137,10 @@ flex-direction: column; gap: 2px; margin-bottom: 2px; + /* Layout/style containment per message block — isolates reflows. + Avoid `paint` (clips overflow like delete-hover overlay) and + `size` (collapses height when children size the block). */ + contain: layout style; } .message-step-start { diff --git a/packages/ui/src/styles/messaging/message-timeline.css b/packages/ui/src/styles/messaging/message-timeline.css index 78aeb927..3adde835 100644 --- a/packages/ui/src/styles/messaging/message-timeline.css +++ b/packages/ui/src/styles/messaging/message-timeline.css @@ -9,6 +9,7 @@ /* Isolate stacking context so sidebar z-indices don't compete with Portals (Command Palette, modals) that live at the body level. */ isolation: isolate; + contain: layout style; } .message-layout--with-timeline { diff --git a/packages/ui/src/styles/messaging/virtual-follow-list.css b/packages/ui/src/styles/messaging/virtual-follow-list.css index 9392eb6a..c5fd9965 100644 --- a/packages/ui/src/styles/messaging/virtual-follow-list.css +++ b/packages/ui/src/styles/messaging/virtual-follow-list.css @@ -5,6 +5,7 @@ min-height: 0; position: relative; width: 100%; + contain: layout style; } .message-stream { @@ -24,7 +25,7 @@ /* Prevent browser scroll anchoring fighting our virtualization compensation. */ overflow-anchor: none; - + /* Scrollbar styling */ scrollbar-gutter: stable; }