diff --git a/CHANGELOG.md b/CHANGELOG.md index b0041448..fffe104a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,6 +22,19 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - **Fair comparison benchmark** (`tests/graph_bench_compare.rs`): Moon 2.4x FalkorDB on Cypher MATCH, 19x on native 1-hop, 23x on population. - **New dependencies**: `slotmap` 1.x (generational indices), `boomphf` 0.6 (MPH), `logos` 0.14 (Cypher lexer, optional). +### Added — Client Connection Security Hardening (2026-04-10) + +- **`--maxclients` (P0):** Connection limit with atomic CAS rejection (default 10000, 0=unlimited). Returns `-ERR max number of clients reached` when exceeded. +- **`--timeout` (P0):** Client idle timeout in seconds (default 0=disabled). Disconnects idle clients via `tokio::time::timeout` / `monoio::select!`. +- **`--tcp-keepalive` (P0):** TCP keepalive interval (default 300s, 0=disabled). Sets `SO_KEEPALIVE` + `TCP_KEEPIDLE` on accepted sockets via `socket2`. +- **AUTH rate limiting (P0):** Per-IP exponential backoff on AUTH failures (100ms base, 10s cap, 60s auto-reset). New module `src/auth_ratelimit.rs`. +- **CLIENT LIST / INFO / KILL (P1):** Global client registry with Drop-guard deregister. Redis-compatible output format. Kill by ID/ADDR/USER. New module `src/client_registry.rs`. +- **CLIENT PAUSE / UNPAUSE (P1):** Server-wide pause with ALL/WRITE modes and auto-expiry. New module `src/client_pause.rs`. +- **CLIENT NO-EVICT / NO-TOUCH (P1):** Accepted stubs for Redis compatibility. +- **ACL GENPASS (P1):** Cryptographically secure random password generation (1-4096 bits, hex output). +- **CONFIG GET/SET** support for `maxclients`, `timeout`, `tcp-keepalive` (runtime-mutable). +- **Monoio connection tracking:** Added missing `record_connection_opened` / `record_connection_closed` for accurate `connected_clients` metric. + ### Fixed — Deep Review Findings (2026-04-11) - **DoS protection**: `execute_profile` and `execute_mut` Cypher paths now enforce MAX_HOPS_LIMIT=20 and MAX_RESULT_ROWS=100K (were unbounded). diff --git a/deny.toml b/deny.toml index c099d593..693dd27a 100644 --- a/deny.toml +++ b/deny.toml @@ -1,17 +1,20 @@ -# cargo-deny configuration for Moon. +# cargo-deny configuration for Moon (v2 format). # Run: cargo deny check # CI: .github/workflows/ci.yml safety-audit job +[graph] +targets = [] +all-features = false +no-default-features = false + [advisories] +version = 2 db-path = "~/.cargo/advisory-db" db-urls = ["https://github.com/rustsec/advisory-db"] -vulnerability = "deny" -unmaintained = "warn" -yanked = "warn" -notice = "warn" +ignore = [] [licenses] -unlicensed = "deny" +version = 2 allow = [ "MIT", "Apache-2.0", @@ -26,7 +29,6 @@ allow = [ "CC0-1.0", "0BSD", ] -copyleft = "deny" [bans] multiple-versions = "warn" diff --git a/src/admin/metrics_setup.rs b/src/admin/metrics_setup.rs index 503e6464..97e06996 100644 --- a/src/admin/metrics_setup.rs +++ b/src/admin/metrics_setup.rs @@ -365,6 +365,42 @@ pub fn connected_clients() -> u64 { CONNECTED_CLIENTS.load(Ordering::Relaxed) } +/// Try to open a connection if under the maxclients limit. +/// Returns true if the connection was accepted, false if at limit. +/// When maxclients is 0, the limit is disabled (unlimited). +#[inline] +pub fn try_accept_connection(maxclients: usize) -> bool { + if maxclients == 0 { + record_connection_opened(); + return true; + } + // CAS loop: only increment if under limit. + // AcqRel on success ensures the counter increment is visible to other cores + // before the connection handler runs (important on ARM/weak-memory archs). + let mut current = CONNECTED_CLIENTS.load(Ordering::Acquire); + loop { + if current >= maxclients as u64 { + return false; + } + match CONNECTED_CLIENTS.compare_exchange_weak( + current, + current + 1, + Ordering::AcqRel, + Ordering::Acquire, + ) { + Ok(_) => { + TOTAL_CONNECTIONS.fetch_add(1, Ordering::Relaxed); + if METRICS_INITIALIZED.load(Ordering::Relaxed) { + counter!("moon_connections_total").increment(1); + gauge!("moon_connected_clients").increment(1.0); + } + return true; + } + Err(actual) => current = actual, + } + } +} + // ── Keyspace metrics ──────────────────────────────────────────────────── /// Record keyspace hit/miss. diff --git a/src/auth_ratelimit.rs b/src/auth_ratelimit.rs new file mode 100644 index 00000000..9eec0ac8 --- /dev/null +++ b/src/auth_ratelimit.rs @@ -0,0 +1,132 @@ +//! Per-IP AUTH failure rate limiting. +//! +//! Tracks failed AUTH attempts per client IP and enforces exponential backoff +//! delays to prevent brute-force attacks. Successful AUTH has zero overhead. +//! +//! Design: +//! - Global `parking_lot::Mutex` (not on hot path — only touched on AUTH) +//! - Exponential backoff: 100ms * 2^(failures-1), capped at 10s +//! - Auto-reset after 60s of inactivity per IP +//! - Periodic cleanup of stale entries via `cleanup_stale()` + +use parking_lot::Mutex; +use std::collections::HashMap; +use std::net::IpAddr; +use std::sync::LazyLock; +use std::time::Instant; + +/// Base delay for first failure (100ms). +const BASE_DELAY_MS: u64 = 100; + +/// Maximum delay cap (10 seconds). +const MAX_DELAY_MS: u64 = 10_000; + +/// Entries older than this are pruned (60 seconds). +const STALE_THRESHOLD_SECS: u64 = 60; + +/// Maximum entries before forced cleanup. +const MAX_ENTRIES: usize = 10_000; + +struct FailureRecord { + count: u32, + last_failure: Instant, +} + +static RATE_LIMITER: LazyLock>> = + LazyLock::new(|| Mutex::new(HashMap::new())); + +/// Record a failed AUTH attempt for the given IP. +/// Returns the delay in milliseconds that should be applied before +/// sending the error response. +pub fn record_failure(ip: IpAddr) -> u64 { + let mut map = RATE_LIMITER.lock(); + + // Periodic cleanup when map grows too large + if map.len() >= MAX_ENTRIES { + let cutoff = Instant::now() - std::time::Duration::from_secs(STALE_THRESHOLD_SECS); + map.retain(|_, r| r.last_failure > cutoff); + } + + let now = Instant::now(); + let record = map.entry(ip).or_insert(FailureRecord { + count: 0, + last_failure: now, + }); + + // Reset if stale (no failures for STALE_THRESHOLD_SECS) + if now.duration_since(record.last_failure).as_secs() >= STALE_THRESHOLD_SECS { + record.count = 0; + } + + record.count = record.count.saturating_add(1); + record.last_failure = now; + + // Exponential backoff: 100ms * 2^(count-1), capped at 10s + let delay = BASE_DELAY_MS.saturating_mul( + 1u64.checked_shl(record.count.saturating_sub(1)) + .unwrap_or(u64::MAX), + ); + delay.min(MAX_DELAY_MS) +} + +/// Clear failure record on successful AUTH (reset for this IP). +pub fn record_success(ip: IpAddr) { + let mut map = RATE_LIMITER.lock(); + map.remove(&ip); +} + +#[cfg(test)] +mod tests { + use super::*; + use std::net::Ipv4Addr; + + #[test] + fn test_first_failure_returns_base_delay() { + let ip = IpAddr::V4(Ipv4Addr::new(192, 168, 1, 100)); + // Clean up from any prior test state + record_success(ip); + let delay = record_failure(ip); + assert_eq!(delay, 100); + // Cleanup + record_success(ip); + } + + #[test] + fn test_exponential_backoff() { + let ip = IpAddr::V4(Ipv4Addr::new(192, 168, 1, 101)); + record_success(ip); + assert_eq!(record_failure(ip), 100); // 100 * 2^0 + assert_eq!(record_failure(ip), 200); // 100 * 2^1 + assert_eq!(record_failure(ip), 400); // 100 * 2^2 + assert_eq!(record_failure(ip), 800); // 100 * 2^3 + assert_eq!(record_failure(ip), 1600); // 100 * 2^4 + record_success(ip); + } + + #[test] + fn test_max_delay_cap() { + let ip = IpAddr::V4(Ipv4Addr::new(192, 168, 1, 102)); + record_success(ip); + // 7 failures: 100, 200, 400, 800, 1600, 3200, 6400 + for _ in 0..7 { + record_failure(ip); + } + // 8th failure should be capped at 10000 + let delay = record_failure(ip); + assert!(delay <= MAX_DELAY_MS); + record_success(ip); + } + + #[test] + fn test_success_clears_record() { + let ip = IpAddr::V4(Ipv4Addr::new(192, 168, 1, 103)); + record_success(ip); + record_failure(ip); + record_failure(ip); + record_success(ip); + // After success, next failure should be back to base + let delay = record_failure(ip); + assert_eq!(delay, 100); + record_success(ip); + } +} diff --git a/src/client_pause.rs b/src/client_pause.rs new file mode 100644 index 00000000..012808ec --- /dev/null +++ b/src/client_pause.rs @@ -0,0 +1,104 @@ +//! CLIENT PAUSE / UNPAUSE global state. +//! +//! When paused, command processing is delayed for all clients until the pause +//! expires or CLIENT UNPAUSE is called. Supports two modes: +//! - ALL: pause all commands +//! - WRITE: pause only write commands (reads still served) + +use parking_lot::RwLock; +use std::sync::LazyLock; +use std::time::Instant; + +#[derive(Clone, Copy, PartialEq, Eq)] +pub enum PauseMode { + All, + Write, +} + +struct PauseState { + active: bool, + mode: PauseMode, + until: Instant, +} + +static PAUSE: LazyLock> = LazyLock::new(|| { + RwLock::new(PauseState { + active: false, + mode: PauseMode::All, + until: Instant::now(), + }) +}); + +/// Activate pause for the given duration and mode. +pub fn pause(duration_ms: u64, mode: PauseMode) { + let mut state = PAUSE.write(); + state.active = true; + state.mode = mode; + state.until = Instant::now() + std::time::Duration::from_millis(duration_ms); +} + +/// Deactivate pause immediately. +pub fn unpause() { + let mut state = PAUSE.write(); + state.active = false; +} + +/// Check if the server is currently paused. Returns the remaining duration +/// if paused, or None if not paused. Auto-expires when the deadline passes. +pub fn check_pause(is_write: bool) -> Option { + let state = PAUSE.read(); + if !state.active { + return None; + } + let now = Instant::now(); + if now >= state.until { + // Expired — will be cleaned up on next write access + return None; + } + // In WRITE mode, only pause write commands + if state.mode == PauseMode::Write && !is_write { + return None; + } + Some(state.until - now) +} + +/// Clean up expired pause state (called periodically from handlers). +/// Takes a write lock to avoid TOCTOU between expiry check and clear. +pub fn expire_if_needed() { + let mut state = PAUSE.write(); + if state.active && Instant::now() >= state.until { + state.active = false; + } +} + +#[cfg(test)] +mod tests { + use super::*; + + // Tests share global state — run sequentially via a mutex + static TEST_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(()); + + #[test] + fn test_pause_and_check() { + let _lock = TEST_LOCK.lock(); + unpause(); + assert!(check_pause(true).is_none()); + + pause(5000, PauseMode::All); + assert!(check_pause(true).is_some()); + assert!(check_pause(false).is_some()); + + unpause(); + assert!(check_pause(true).is_none()); + } + + #[test] + fn test_write_mode_allows_reads() { + let _lock = TEST_LOCK.lock(); + unpause(); + pause(5000, PauseMode::Write); + assert!(check_pause(true).is_some()); // write blocked + assert!(check_pause(false).is_none()); // read allowed + unpause(); + } +} diff --git a/src/client_registry.rs b/src/client_registry.rs new file mode 100644 index 00000000..0d71ec57 --- /dev/null +++ b/src/client_registry.rs @@ -0,0 +1,280 @@ +//! Global client connection registry for CLIENT LIST/INFO/KILL. +//! +//! Every connection registers on accept and deregisters on close. +//! The registry is a global `parking_lot::RwLock` — not on +//! the command hot path (only touched on connect/disconnect and CLIENT commands). + +use parking_lot::RwLock; +use std::collections::HashMap; +use std::sync::LazyLock; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::time::Instant; + +/// Global client registry. +static REGISTRY: LazyLock>> = + LazyLock::new(|| RwLock::new(HashMap::new())); + +/// Information about a connected client. +pub struct ClientEntry { + pub id: u64, + pub addr: String, + pub name: Option, + pub user: String, + pub db: usize, + pub shard: usize, + pub flags: ClientFlags, + pub connected_at: Instant, + pub last_cmd_at: Instant, + /// Set by CLIENT KILL — handler checks this and closes the connection. + pub kill_flag: AtomicBool, +} + +/// Client connection flags (matches Redis CLIENT LIST flag characters). +#[derive(Clone, Copy, Default)] +pub struct ClientFlags { + pub subscriber: bool, + pub in_multi: bool, + pub blocked: bool, +} + +impl ClientFlags { + /// Format as Redis-compatible flag string (e.g., "N", "S", "x"). + pub fn to_flag_str(self) -> &'static str { + if self.subscriber { + "S" + } else if self.in_multi { + "x" + } else if self.blocked { + "b" + } else { + "N" + } + } +} + +/// Register a new client connection. +pub fn register(id: u64, addr: String, user: String, shard: usize) { + let now = Instant::now(); + let entry = ClientEntry { + id, + addr, + name: None, + user, + db: 0, + shard, + flags: ClientFlags::default(), + connected_at: now, + last_cmd_at: now, + kill_flag: AtomicBool::new(false), + }; + REGISTRY.write().insert(id, entry); +} + +/// Deregister a client connection. +pub fn deregister(id: u64) { + REGISTRY.write().remove(&id); +} + +/// Update mutable fields for a client (called periodically or on state change). +pub fn update(id: u64, f: F) { + if let Some(entry) = REGISTRY.write().get_mut(&id) { + f(entry); + } +} + +/// Check if a client has been marked for killing. +pub fn is_killed(id: u64) -> bool { + REGISTRY + .read() + .get(&id) + .is_some_and(|e| e.kill_flag.load(Ordering::Relaxed)) +} + +/// Format all clients as a CLIENT LIST string. +/// +/// Each line: `id=N addr=... fd=0 name=... db=N ...` +/// Returns the full response string. +pub fn client_list() -> String { + let registry = REGISTRY.read(); + let now = Instant::now(); + let mut result = String::with_capacity(registry.len() * 128); + for entry in registry.values() { + format_client_line(&mut result, entry, now); + } + // Remove trailing newline if present + if result.ends_with('\n') { + result.pop(); + } + result +} + +/// Format a single client's info (for CLIENT INFO). +pub fn client_info(id: u64) -> Option { + let registry = REGISTRY.read(); + let now = Instant::now(); + registry.get(&id).map(|entry| { + let mut result = String::with_capacity(128); + format_client_line(&mut result, entry, now); + if result.ends_with('\n') { + result.pop(); + } + result + }) +} + +/// Kill clients matching the given filter. Returns count of killed clients. +pub fn kill_clients(filter: &KillFilter) -> u64 { + let registry = REGISTRY.read(); + let mut count = 0u64; + for entry in registry.values() { + let matches = match filter { + KillFilter::Id(target_id) => entry.id == *target_id, + KillFilter::Addr(addr) => entry.addr == *addr, + KillFilter::User(user) => entry.user == *user, + }; + if matches { + entry.kill_flag.store(true, Ordering::Relaxed); + count += 1; + } + } + count +} + +/// Filter for CLIENT KILL. +pub enum KillFilter { + Id(u64), + Addr(String), + User(String), +} + +/// Parse CLIENT KILL arguments into a KillFilter. +/// +/// Supports both the legacy form (`CLIENT KILL addr:port`) and the modern +/// filter form (`CLIENT KILL ID id`, `CLIENT KILL ADDR addr`, `CLIENT KILL USER user`). +pub fn parse_kill_args(args: &[&[u8]]) -> Option { + if args.is_empty() { + return None; + } + // Legacy single-arg form: CLIENT KILL addr:port + if args.len() == 1 { + let addr = std::str::from_utf8(args[0]).ok()?; + return Some(KillFilter::Addr(addr.to_string())); + } + // Modern filter form: CLIENT KILL ID|ADDR|USER value + let mut i = 0; + while i + 1 < args.len() { + let key = args[i]; + let val = args[i + 1]; + if key.eq_ignore_ascii_case(b"ID") { + let id_str = std::str::from_utf8(val).ok()?; + let id = id_str.parse::().ok()?; + return Some(KillFilter::Id(id)); + } else if key.eq_ignore_ascii_case(b"ADDR") { + let addr = std::str::from_utf8(val).ok()?; + return Some(KillFilter::Addr(addr.to_string())); + } else if key.eq_ignore_ascii_case(b"USER") { + let user = std::str::from_utf8(val).ok()?; + return Some(KillFilter::User(user.to_string())); + } + i += 2; + } + None +} + +fn format_client_line(buf: &mut String, entry: &ClientEntry, now: Instant) { + use std::fmt::Write; + let age = now.duration_since(entry.connected_at).as_secs(); + let idle = now.duration_since(entry.last_cmd_at).as_secs(); + let name = entry.name.as_deref().unwrap_or(""); + let flags = entry.flags.to_flag_str(); + let _ = writeln!( + buf, + "id={} addr={} fd=0 name={} db={} sub=0 psub=0 ssub=0 multi=-1 \ + watch=0 qbuf=0 qbuf-free=0 argv-mem=0 tot-mem=0 net-i=0 net-o=0 \ + age={} idle={} flags={} user={}", + entry.id, entry.addr, name, entry.db, age, idle, flags, entry.user, + ); +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_register_and_list() { + let id = 999_000; + register(id, "127.0.0.1:12345".into(), "default".into(), 0); + let list = client_list(); + assert!(list.contains("id=999000")); + assert!(list.contains("addr=127.0.0.1:12345")); + assert!(list.contains("user=default")); + deregister(id); + let list = client_list(); + assert!(!list.contains("id=999000")); + } + + #[test] + fn test_client_info() { + let id = 999_001; + register(id, "10.0.0.1:5000".into(), "alice".into(), 1); + let info = client_info(id); + assert!(info.is_some()); + assert!(info.as_ref().is_some_and(|s| s.contains("user=alice"))); + deregister(id); + assert!(client_info(id).is_none()); + } + + #[test] + fn test_kill_by_id() { + let id = 999_002; + register(id, "10.0.0.2:6000".into(), "bob".into(), 0); + assert!(!is_killed(id)); + let count = kill_clients(&KillFilter::Id(id)); + assert_eq!(count, 1); + assert!(is_killed(id)); + deregister(id); + } + + #[test] + fn test_kill_by_user() { + let id1 = 999_010; + let id2 = 999_011; + register(id1, "10.0.0.3:7000".into(), "eve".into(), 0); + register(id2, "10.0.0.4:7001".into(), "eve".into(), 1); + let count = kill_clients(&KillFilter::User("eve".into())); + assert_eq!(count, 2); + assert!(is_killed(id1)); + assert!(is_killed(id2)); + deregister(id1); + deregister(id2); + } + + #[test] + fn test_update() { + let id = 999_003; + register(id, "10.0.0.5:8000".into(), "default".into(), 0); + update(id, |e| { + e.name = Some("myconn".into()); + e.db = 3; + }); + let info = client_info(id).unwrap(); + assert!(info.contains("name=myconn")); + assert!(info.contains("db=3")); + deregister(id); + } + + #[test] + fn test_parse_kill_args() { + let args: Vec<&[u8]> = vec![b"ID", b"42"]; + let filter = parse_kill_args(&args).unwrap(); + assert!(matches!(filter, KillFilter::Id(42))); + + let args: Vec<&[u8]> = vec![b"ADDR", b"127.0.0.1:6379"]; + let filter = parse_kill_args(&args).unwrap(); + assert!(matches!(filter, KillFilter::Addr(a) if a == "127.0.0.1:6379")); + + let args: Vec<&[u8]> = vec![b"USER", b"alice"]; + let filter = parse_kill_args(&args).unwrap(); + assert!(matches!(filter, KillFilter::User(u) if u == "alice")); + } +} diff --git a/src/command/acl.rs b/src/command/acl.rs index 19ee4eff..7f59a1bf 100644 --- a/src/command/acl.rs +++ b/src/command/acl.rs @@ -329,6 +329,29 @@ pub fn handle_acl( } } + "GENPASS" => { + // ACL GENPASS [bits] — generate cryptographically secure random password + let bits: usize = if let Some(arg) = args.first() { + match extract_str(arg).and_then(|s| s.parse().ok()) { + Some(b) if b > 0 && b <= 4096 => b, + _ => { + return Frame::Error(Bytes::from_static( + b"ERR ACL GENPASS argument must be a positive integer up to 4096", + )); + } + } + } else { + 256 // default: 256 bits = 64 hex chars + }; + let byte_count = (bits + 7) / 8; + let mut buf = vec![0u8; byte_count]; + rand::RngExt::fill(&mut rand::rng(), &mut buf[..]); + let hex = hex::encode(&buf); + // Truncate to exact number of hex chars for the requested bits + let hex_chars = (bits + 3) / 4; // 4 bits per hex char + Frame::BulkString(Bytes::from(hex[..hex_chars].to_string())) + } + _ => Frame::Error(Bytes::from(format!( "ERR unknown subcommand '{}'. Try ACL HELP.", sub @@ -670,4 +693,51 @@ mod tests { let result = handle_acl(&args, &table, &mut log, "default", "127.0.0.1:1234", &rc); assert!(matches!(result, Frame::Error(_))); } + + #[test] + fn test_acl_genpass_default() { + let table = make_acl_table(); + let mut log = AclLog::new(128); + let rc = make_runtime_config(); + let args = vec![Frame::BulkString(Bytes::from_static(b"GENPASS"))]; + let result = handle_acl(&args, &table, &mut log, "default", "127.0.0.1:1234", &rc); + match result { + Frame::BulkString(b) => { + assert_eq!(b.len(), 64); // 256 bits = 64 hex chars + assert!(b.iter().all(|&c| c.is_ascii_hexdigit())); + } + _ => panic!("Expected BulkString, got {:?}", result), + } + } + + #[test] + fn test_acl_genpass_custom_bits() { + let table = make_acl_table(); + let mut log = AclLog::new(128); + let rc = make_runtime_config(); + let args = vec![ + Frame::BulkString(Bytes::from_static(b"GENPASS")), + Frame::BulkString(Bytes::from_static(b"128")), + ]; + let result = handle_acl(&args, &table, &mut log, "default", "127.0.0.1:1234", &rc); + match result { + Frame::BulkString(b) => { + assert_eq!(b.len(), 32); // 128 bits = 32 hex chars + } + _ => panic!("Expected BulkString"), + } + } + + #[test] + fn test_acl_genpass_invalid_bits() { + let table = make_acl_table(); + let mut log = AclLog::new(128); + let rc = make_runtime_config(); + let args = vec![ + Frame::BulkString(Bytes::from_static(b"GENPASS")), + Frame::BulkString(Bytes::from_static(b"0")), + ]; + let result = handle_acl(&args, &table, &mut log, "default", "127.0.0.1:1234", &rc); + assert!(matches!(result, Frame::Error(_))); + } } diff --git a/src/command/config.rs b/src/command/config.rs index a9c575fc..3a679618 100644 --- a/src/command/config.rs +++ b/src/command/config.rs @@ -51,6 +51,9 @@ pub fn config_get( runtime_config.protected_mode.clone(), ), (b"acllog-max-len", runtime_config.acllog_max_len.to_string()), + (b"maxclients", runtime_config.maxclients.to_string()), + (b"timeout", runtime_config.timeout.to_string()), + (b"tcp-keepalive", runtime_config.tcp_keepalive.to_string()), ]; let mut result = Vec::new(); @@ -171,6 +174,33 @@ pub fn config_set(runtime_config: &mut RuntimeConfig, args: &[Frame]) -> Frame { ))); } }, + "maxclients" => match value_str.parse::() { + Ok(v) => runtime_config.maxclients = v, + Err(_) => { + return Frame::Error(Bytes::from(format!( + "ERR Invalid argument '{}' for CONFIG SET 'maxclients'", + value_str + ))); + } + }, + "timeout" => match value_str.parse::() { + Ok(v) => runtime_config.timeout = v, + Err(_) => { + return Frame::Error(Bytes::from(format!( + "ERR Invalid argument '{}' for CONFIG SET 'timeout'", + value_str + ))); + } + }, + "tcp-keepalive" => match value_str.parse::() { + Ok(v) => runtime_config.tcp_keepalive = v, + Err(_) => { + return Frame::Error(Bytes::from(format!( + "ERR Invalid argument '{}' for CONFIG SET 'tcp-keepalive'", + value_str + ))); + } + }, _ => { return Frame::Error(Bytes::from(format!( "ERR Unsupported CONFIG parameter: {}", diff --git a/src/config.rs b/src/config.rs index bba0516a..89514f41 100644 --- a/src/config.rs +++ b/src/config.rs @@ -94,6 +94,18 @@ pub struct ServerConfig { #[arg(long, default_value = "yes")] pub protected_mode: String, + /// Maximum number of simultaneous client connections (0 = unlimited) + #[arg(long, default_value_t = 10000)] + pub maxclients: usize, + + /// Close connections idle for more than N seconds (0 = disabled) + #[arg(long, default_value_t = 0)] + pub timeout: u64, + + /// TCP keepalive interval in seconds (0 = disabled). Sets SO_KEEPALIVE on accepted sockets. + #[arg(long = "tcp-keepalive", default_value_t = 300)] + pub tcp_keepalive: u64, + /// Maximum number of entries in the ACL log #[arg(long, default_value_t = 128)] pub acllog_max_len: usize, @@ -286,6 +298,9 @@ impl ServerConfig { requirepass: self.requirepass.clone(), protected_mode: self.protected_mode.clone(), acllog_max_len: self.acllog_max_len, + maxclients: self.maxclients, + timeout: self.timeout, + tcp_keepalive: self.tcp_keepalive, } } } @@ -321,6 +336,12 @@ pub struct RuntimeConfig { pub protected_mode: String, /// Maximum number of entries in the ACL log (mutable via CONFIG SET). pub acllog_max_len: usize, + /// Maximum number of simultaneous client connections (0 = unlimited). + pub maxclients: usize, + /// Close connections idle for more than N seconds (0 = disabled). + pub timeout: u64, + /// TCP keepalive interval in seconds (0 = disabled). + pub tcp_keepalive: u64, } impl Default for RuntimeConfig { @@ -339,6 +360,9 @@ impl Default for RuntimeConfig { requirepass: None, protected_mode: "yes".to_string(), acllog_max_len: 128, + maxclients: 10000, + timeout: 0, + tcp_keepalive: 300, } } } diff --git a/src/lib.rs b/src/lib.rs index d66cb9fc..c4f0a2c9 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -39,7 +39,10 @@ pub mod acl; pub mod admin; +pub mod auth_ratelimit; pub mod blocking; +pub mod client_pause; +pub mod client_registry; pub mod cluster; pub mod command; pub mod config; diff --git a/src/server/conn/handler_monoio.rs b/src/server/conn/handler_monoio.rs index 9dd9a509..95bfd22a 100644 --- a/src/server/conn/handler_monoio.rs +++ b/src/server/conn/handler_monoio.rs @@ -81,7 +81,8 @@ pub(crate) async fn handle_connection_sharded_monoio< ) -> (MonoioHandlerResult, Option) { use monoio::io::AsyncWriteRentExt; - crate::admin::metrics_setup::record_connection_opened(); + // NOTE: do NOT call record_connection_opened() here — the caller + // (conn_accept.rs) already increments via try_accept_connection(). let mut read_buf = if initial_read_buf.is_empty() { BytesMut::with_capacity(8192) @@ -104,6 +105,21 @@ pub(crate) async fn handle_connection_sharded_monoio< ); let db_count = ctx.shard_databases.db_count(); + // Register in global client registry for CLIENT LIST/INFO/KILL. + crate::client_registry::register( + client_id, + peer_addr.clone(), + conn.current_user.clone(), + ctx.shard_id, + ); + struct RegistryGuard(u64); + impl Drop for RegistryGuard { + fn drop(&mut self) { + crate::client_registry::deregister(self.0); + } + } + let _registry_guard = RegistryGuard(client_id); + // Functions API registry (per-connection, lazy init) — kept as local because Rc> is !Send let func_registry = Rc::new(RefCell::new(crate::scripting::FunctionRegistry::new())); @@ -111,6 +127,14 @@ pub(crate) async fn handle_connection_sharded_monoio< // Monoio's ownership I/O takes ownership and returns the buffer, so we reassign. let mut tmp_buf = vec![0u8; 8192]; + // Client idle timeout: 0 = disabled (read once, avoid lock on hot path) + let idle_timeout_secs = ctx.runtime_config.read().timeout; + let idle_timeout = if idle_timeout_secs > 0 { + Some(std::time::Duration::from_secs(idle_timeout_secs)) + } else { + None + }; + // Pre-allocate batch containers outside the loop to avoid per-batch heap allocation. // These are cleared and reused each iteration instead of being recreated. let mut responses: Vec = Vec::with_capacity(64); @@ -125,6 +149,11 @@ pub(crate) async fn handle_connection_sharded_monoio< let mut frames: Vec = Vec::with_capacity(64); loop { + // Check if CLIENT KILL targeted this connection + if crate::client_registry::is_killed(client_id) { + break; + } + // Subscriber mode: bidirectional select on client commands + published messages if conn.subscription_count > 0 { #[allow(clippy::unwrap_used)] @@ -387,18 +416,36 @@ pub(crate) async fn handle_connection_sharded_monoio< if tmp_buf.len() < 8192 { tmp_buf.resize(8192, 0); } - let (result, returned_buf) = stream.read(tmp_buf).await; - tmp_buf = returned_buf; - match result { - Ok(0) => { - // Client half-closed — break out of loop. - // Stream drop (end of function) triggers monoio's cleanup. - break; + if let Some(dur) = idle_timeout { + // Timeout-aware read: select between read and sleep. + // monoio::select! drops the losing future, so tmp_buf ownership transfers. + // We allocate a fresh buffer when timeout is enabled (safety feature, not hot path). + let timeout_buf = std::mem::take(&mut tmp_buf); + monoio::select! { + read_result = stream.read(timeout_buf) => { + let (result, returned_buf) = read_result; + tmp_buf = returned_buf; + match result { + Ok(0) => break, + Ok(n) => { read_buf.extend_from_slice(&tmp_buf[..n]); } + Err(_) => break, + } + } + _ = monoio::time::sleep(dur) => { + tracing::debug!("Connection {} idle timeout ({}s)", client_id, idle_timeout_secs); + break; + } } - Ok(n) => { - read_buf.extend_from_slice(&tmp_buf[..n]); + } else { + let (result, returned_buf) = stream.read(tmp_buf).await; + tmp_buf = returned_buf; + match result { + Ok(0) => break, + Ok(n) => { + read_buf.extend_from_slice(&tmp_buf[..n]); + } + Err(_) => break, } - Err(_) => break, } // Inline dispatch: handle GET/SET directly from raw bytes without Frame @@ -451,6 +498,12 @@ pub(crate) async fn handle_connection_sharded_monoio< continue; } + // CLIENT PAUSE: delay processing if server is paused + crate::client_pause::expire_if_needed(); + if let Some(remaining) = crate::client_pause::check_pause(true) { + monoio::time::sleep(remaining).await; + } + // Process frames with shard routing, cross-shard dispatch, and AOF logging // Note: do NOT clear write_buf -- it may contain responses from inline dispatch. // The inline path appends directly; the normal path appends via encode_frame below. @@ -470,6 +523,8 @@ pub(crate) async fn handle_connection_sharded_monoio< guard.refresh_now_from_cache(&ctx.cached_clock); } + let mut auth_delay_ms: u64 = 0; + for frame in frames.drain(..) { // --- AUTH gate --- if !conn.authenticated { @@ -479,7 +534,13 @@ pub(crate) async fn handle_connection_sharded_monoio< if let Some(uname) = opt_user { conn.authenticated = true; conn.current_user = uname; + if let Ok(addr) = peer_addr.parse::() { + crate::auth_ratelimit::record_success(addr.ip()); + } } else { + if let Ok(addr) = peer_addr.parse::() { + auth_delay_ms += crate::auth_ratelimit::record_failure(addr.ip()); + } conn.acl_log.push(crate::acl::AclLogEntry { reason: "auth".to_string(), object: "AUTH".to_string(), @@ -509,8 +570,18 @@ pub(crate) async fn handle_connection_sharded_monoio< if let Some(name) = new_name { conn.client_name = Some(name); } - if let Some(uname) = opt_user { - conn.current_user = uname; + if let Some(ref uname) = opt_user { + conn.current_user = uname.clone(); + } + // HELLO AUTH rate limiting + if matches!(&response, Frame::Error(_)) { + if let Ok(addr) = peer_addr.parse::() { + auth_delay_ms += crate::auth_ratelimit::record_failure(addr.ip()); + } + } else if opt_user.is_some() { + if let Ok(addr) = peer_addr.parse::() { + crate::auth_ratelimit::record_success(addr.ip()); + } } responses.push(response); continue; @@ -690,6 +761,11 @@ pub(crate) async fn handle_connection_sharded_monoio< let (response, opt_user) = conn_cmd::auth_acl(cmd_args, &ctx.acl_table); if let Some(uname) = opt_user { conn.current_user = uname; + if let Ok(addr) = peer_addr.parse::() { + crate::auth_ratelimit::record_success(addr.ip()); + } + } else if let Ok(addr) = peer_addr.parse::() { + auth_delay_ms += crate::auth_ratelimit::record_failure(addr.ip()); } responses.push(response); continue; @@ -710,8 +786,17 @@ pub(crate) async fn handle_connection_sharded_monoio< if let Some(name) = new_name { conn.client_name = Some(name); } - if let Some(uname) = opt_user { - conn.current_user = uname; + if let Some(ref uname) = opt_user { + conn.current_user = uname.clone(); + } + if matches!(&response, Frame::Error(_)) { + if let Ok(addr) = peer_addr.parse::() { + auth_delay_ms += crate::auth_ratelimit::record_failure(addr.ip()); + } + } else if opt_user.is_some() { + if let Ok(addr) = peer_addr.parse::() { + crate::auth_ratelimit::record_success(addr.ip()); + } } responses.push(response); continue; @@ -842,6 +927,13 @@ pub(crate) async fn handle_connection_sharded_monoio< ))); } else { conn.client_name = extract_bytes(&cmd_args[1]); + let name_str = conn + .client_name + .as_ref() + .map(|b| String::from_utf8_lossy(b).to_string()); + crate::client_registry::update(client_id, |e| { + e.name = name_str; + }); responses.push(Frame::SimpleString(Bytes::from_static(b"OK"))); } continue; @@ -898,18 +990,11 @@ pub(crate) async fn handle_connection_sharded_monoio< } } } - // Unknown CLIENT subcommand - responses.push(Frame::Error(Bytes::from(format!( - "ERR unknown subcommand '{}'", - String::from_utf8_lossy(&sub_bytes) - )))); - continue; + // Admin CLIENT subcommands (LIST, INFO, KILL, PAUSE, UNPAUSE, + // NO-EVICT, NO-TOUCH) fall through to the ACL gate below. } } - responses.push(Frame::Error(Bytes::from_static( - b"ERR wrong number of arguments for 'client' command", - ))); - continue; + // Fall through — admin subcommands handled after ACL check. } // --- PUBLISH: local delivery + cross-shard fan-out --- @@ -1218,6 +1303,123 @@ pub(crate) async fn handle_connection_sharded_monoio< } } + // --- CLIENT admin subcommands (LIST, INFO, KILL, PAUSE, UNPAUSE) --- + // Placed AFTER ACL check so restricted users cannot access admin ops. + if cmd.eq_ignore_ascii_case(b"CLIENT") { + if let Some(sub) = cmd_args.first() { + if let Some(sub_bytes) = extract_bytes(sub) { + if sub_bytes.eq_ignore_ascii_case(b"LIST") { + crate::client_registry::update(client_id, |e| { + e.db = conn.selected_db; + e.last_cmd_at = std::time::Instant::now(); + e.flags = crate::client_registry::ClientFlags { + subscriber: conn.subscription_count > 0, + in_multi: conn.in_multi, + blocked: false, + }; + }); + let list = crate::client_registry::client_list(); + responses.push(Frame::BulkString(Bytes::from(list))); + continue; + } + if sub_bytes.eq_ignore_ascii_case(b"INFO") { + crate::client_registry::update(client_id, |e| { + e.db = conn.selected_db; + e.last_cmd_at = std::time::Instant::now(); + }); + let info = + crate::client_registry::client_info(client_id).unwrap_or_default(); + responses.push(Frame::BulkString(Bytes::from(info))); + continue; + } + if sub_bytes.eq_ignore_ascii_case(b"KILL") { + let raw_args: Vec<&[u8]> = cmd_args[1..] + .iter() + .filter_map(|f| match f { + Frame::BulkString(b) => Some(b.as_ref()), + Frame::SimpleString(b) => Some(b.as_ref()), + _ => None, + }) + .collect(); + match crate::client_registry::parse_kill_args(&raw_args) { + Some(filter) => { + let count = crate::client_registry::kill_clients(&filter); + responses.push(Frame::Integer(count as i64)); + } + None => { + responses.push(Frame::Error(Bytes::from_static( + b"ERR syntax error. Usage: CLIENT KILL [ID id] [ADDR addr] [USER user]", + ))); + } + } + continue; + } + if sub_bytes.eq_ignore_ascii_case(b"PAUSE") { + if cmd_args.len() < 2 { + responses.push(Frame::Error(Bytes::from_static( + b"ERR wrong number of arguments for 'CLIENT PAUSE' command", + ))); + } else { + let timeout_bytes = match &cmd_args[1] { + Frame::BulkString(b) => Some(b.as_ref()), + Frame::SimpleString(b) => Some(b.as_ref()), + _ => None, + }; + match timeout_bytes + .and_then(|b| std::str::from_utf8(b).ok()) + .and_then(|s| s.parse::().ok()) + { + Some(ms) => { + let mode = if cmd_args.len() > 2 { + match &cmd_args[2] { + Frame::BulkString(b) | Frame::SimpleString(b) + if b.eq_ignore_ascii_case(b"WRITE") => + { + crate::client_pause::PauseMode::Write + } + _ => crate::client_pause::PauseMode::All, + } + } else { + crate::client_pause::PauseMode::All + }; + crate::client_pause::pause(ms, mode); + responses + .push(Frame::SimpleString(Bytes::from_static(b"OK"))); + } + None => { + responses.push(Frame::Error(Bytes::from_static( + b"ERR timeout is not a valid integer or out of range", + ))); + } + } + } + continue; + } + if sub_bytes.eq_ignore_ascii_case(b"UNPAUSE") { + crate::client_pause::unpause(); + responses.push(Frame::SimpleString(Bytes::from_static(b"OK"))); + continue; + } + if sub_bytes.eq_ignore_ascii_case(b"NO-EVICT") + || sub_bytes.eq_ignore_ascii_case(b"NO-TOUCH") + { + responses.push(Frame::SimpleString(Bytes::from_static(b"OK"))); + continue; + } + // Unknown CLIENT subcommand + responses.push(Frame::Error(Bytes::from(format!( + "ERR unknown subcommand '{}'", + String::from_utf8_lossy(&sub_bytes) + )))); + continue; + } + } + responses.push(Frame::Error(Bytes::from_static( + b"ERR wrong number of arguments for 'client' command", + ))); + continue; + } + // --- Functions API: FUNCTION/FCALL/FCALL_RO --- // Placed AFTER ACL check. Respects MULTI queue — if conn.in_multi, // fall through to the MULTI queue gate instead of executing. @@ -2004,6 +2206,11 @@ pub(crate) async fn handle_connection_sharded_monoio< } } + // AUTH rate limiting: delay response to slow down brute-force attacks + if auth_delay_ms > 0 { + monoio::time::sleep(std::time::Duration::from_millis(auth_delay_ms)).await; + } + // Serialize all responses into write_buf, then do ONE write_all syscall. for response in &responses { codec.encode_frame(response, &mut write_buf); @@ -2019,6 +2226,17 @@ pub(crate) async fn handle_connection_sharded_monoio< } } + // Update registry with current state after each batch + crate::client_registry::update(client_id, |e| { + e.db = conn.selected_db; + e.last_cmd_at = std::time::Instant::now(); + e.flags = crate::client_registry::ClientFlags { + subscriber: conn.subscription_count > 0, + in_multi: conn.in_multi, + blocked: false, + }; + }); + // Check if migration was triggered during frame processing. // All responses for the current batch have been written, so the // client sees no interruption -- TCP socket stays open. diff --git a/src/server/conn/handler_sharded.rs b/src/server/conn/handler_sharded.rs index d514e57e..fa29ed38 100644 --- a/src/server/conn/handler_sharded.rs +++ b/src/server/conn/handler_sharded.rs @@ -67,12 +67,19 @@ use super::{ /// connection level same as the non-sharded handler. #[tracing::instrument(skip_all, level = "debug")] pub(crate) async fn handle_connection_sharded( - stream: TcpStream, + mut stream: TcpStream, ctx: &super::core::ConnectionContext, shutdown: CancellationToken, client_id: u64, ) { - crate::admin::metrics_setup::record_connection_opened(); + let maxclients = ctx.runtime_config.read().maxclients; + if !crate::admin::metrics_setup::try_accept_connection(maxclients) { + use tokio::io::AsyncWriteExt; + let _ = stream + .write_all(b"-ERR max number of clients reached\r\n") + .await; + return; + } let peer_addr = stream .peer_addr() .map(|a| a.to_string()) @@ -215,6 +222,22 @@ pub(crate) async fn handle_connection_sharded_inner< migrated_state, ); + // Register in global client registry for CLIENT LIST/INFO/KILL. + // RegistryGuard ensures deregister on all exit paths (including early returns). + crate::client_registry::register( + client_id, + peer_addr.clone(), + conn.current_user.clone(), + ctx.shard_id, + ); + struct RegistryGuard(u64); + impl Drop for RegistryGuard { + fn drop(&mut self) { + crate::client_registry::deregister(self.0); + } + } + let _registry_guard = RegistryGuard(client_id); + use crate::pubsub::{self, subscriber::Subscriber}; // Functions API registry (per-shard, lazy init) — kept as local because Rc> is !Send @@ -229,8 +252,21 @@ pub(crate) async fn handle_connection_sharded_inner< // Pre-allocated response slots for zero-allocation cross-shard dispatch. let response_pool = ResponseSlotPool::new(ctx.num_shards, ctx.shard_id); + // Client idle timeout: 0 = disabled (read once, avoid lock on hot path) + let idle_timeout_secs = ctx.runtime_config.read().timeout; + let idle_timeout = if idle_timeout_secs > 0 { + Some(std::time::Duration::from_secs(idle_timeout_secs)) + } else { + None + }; + let mut break_outer = false; loop { + // Check if CLIENT KILL targeted this connection + if crate::client_registry::is_killed(client_id) { + break; + } + // --- Subscriber mode: bidirectional select on client commands + published messages --- if conn.subscription_count > 0 { #[allow(clippy::unwrap_used)] @@ -437,10 +473,23 @@ pub(crate) async fn handle_connection_sharded_inner< continue; } tokio::select! { - result = stream.read_buf(&mut read_buf) => { + result = async { + if let Some(dur) = idle_timeout { + match tokio::time::timeout(dur, stream.read_buf(&mut read_buf)).await { + Ok(r) => r, + Err(_) => Err(std::io::Error::new(std::io::ErrorKind::TimedOut, "idle timeout")), + } + } else { + stream.read_buf(&mut read_buf).await + } + } => { match result { Ok(0) => break, // connection closed Ok(_) => {} + Err(e) if e.kind() == std::io::ErrorKind::TimedOut => { + tracing::debug!("Connection {} idle timeout ({}s)", client_id, idle_timeout_secs); + break; + } Err(_) => break, } @@ -461,6 +510,13 @@ pub(crate) async fn handle_connection_sharded_inner< if break_outer { break; } if batch.is_empty() { continue; } + // CLIENT PAUSE: delay processing if server is paused + // Check with is_write=true (conservative — pauses all batches in ALL mode) + crate::client_pause::expire_if_needed(); + if let Some(remaining) = crate::client_pause::check_pause(true) { + tokio::time::sleep(remaining).await; + } + let mut responses: Vec = Vec::with_capacity(batch.len()); let mut should_quit = false; let mut remote_groups: HashMap, Option, Bytes, usize)>> = HashMap::with_capacity(ctx.num_shards); @@ -468,6 +524,9 @@ pub(crate) async fn handle_connection_sharded_inner< // Key: target shard ID -> Vec of (response_index, channel, message) let mut publish_batches: HashMap> = HashMap::new(); + // Track if AUTH rate limiting delay is needed (applied after batch response) + let mut auth_delay_ms: u64 = 0; + for frame in batch { // --- AUTH gate --- if !conn.authenticated { @@ -477,7 +536,13 @@ pub(crate) async fn handle_connection_sharded_inner< if let Some(uname) = opt_user { conn.authenticated = true; conn.current_user = uname; + if let Ok(addr) = peer_addr.parse::() { + crate::auth_ratelimit::record_success(addr.ip()); + } } else { + if let Ok(addr) = peer_addr.parse::() { + auth_delay_ms += crate::auth_ratelimit::record_failure(addr.ip()); + } conn.acl_log.push(crate::acl::AclLogEntry { reason: "auth".to_string(), object: "AUTH".to_string(), @@ -506,8 +571,18 @@ pub(crate) async fn handle_connection_sharded_inner< if let Some(name) = new_name { conn.client_name = Some(name); } - if let Some(uname) = opt_user { - conn.current_user = uname; + if let Some(ref uname) = opt_user { + conn.current_user = uname.clone(); + } + // HELLO AUTH rate limiting (same as AUTH gate) + if matches!(&response, Frame::Error(_)) { + if let Ok(addr) = peer_addr.parse::() { + auth_delay_ms += crate::auth_ratelimit::record_failure(addr.ip()); + } + } else if opt_user.is_some() { + if let Ok(addr) = peer_addr.parse::() { + crate::auth_ratelimit::record_success(addr.ip()); + } } responses.push(response); continue; @@ -655,7 +730,14 @@ pub(crate) async fn handle_connection_sharded_inner< // --- AUTH (already conn.authenticated) --- if cmd.eq_ignore_ascii_case(b"AUTH") { let (response, opt_user) = conn_cmd::auth_acl(cmd_args, &ctx.acl_table); - if let Some(uname) = opt_user { conn.current_user = uname; } + if let Some(uname) = opt_user { + conn.current_user = uname; + if let Ok(addr) = peer_addr.parse::() { + crate::auth_ratelimit::record_success(addr.ip()); + } + } else if let Ok(addr) = peer_addr.parse::() { + auth_delay_ms += crate::auth_ratelimit::record_failure(addr.ip()); + } responses.push(response); continue; } @@ -667,7 +749,16 @@ pub(crate) async fn handle_connection_sharded_inner< ); if !matches!(&response, Frame::Error(_)) { conn.protocol_version = new_proto; } if let Some(name) = new_name { conn.client_name = Some(name); } - if let Some(uname) = opt_user { conn.current_user = uname; } + if let Some(ref uname) = opt_user { conn.current_user = uname.clone(); } + if matches!(&response, Frame::Error(_)) { + if let Ok(addr) = peer_addr.parse::() { + auth_delay_ms += crate::auth_ratelimit::record_failure(addr.ip()); + } + } else if opt_user.is_some() { + if let Ok(addr) = peer_addr.parse::() { + crate::auth_ratelimit::record_success(addr.ip()); + } + } responses.push(response); continue; } @@ -860,6 +951,8 @@ pub(crate) async fn handle_connection_sharded_inner< ))); } else { conn.client_name = extract_bytes(&cmd_args[1]); + let name_str = conn.client_name.as_ref().map(|b| String::from_utf8_lossy(b).to_string()); + crate::client_registry::update(client_id, |e| { e.name = name_str; }); responses.push(Frame::SimpleString(Bytes::from_static(b"OK"))); } continue; @@ -905,6 +998,94 @@ pub(crate) async fn handle_connection_sharded_inner< Err(err_frame) => { responses.push(err_frame); continue; } } } + if sub_bytes.eq_ignore_ascii_case(b"LIST") { + // Update our own entry before listing + crate::client_registry::update(client_id, |e| { + e.db = conn.selected_db; + e.last_cmd_at = std::time::Instant::now(); + e.flags = crate::client_registry::ClientFlags { + subscriber: conn.subscription_count > 0, + in_multi: conn.in_multi, + blocked: false, + }; + }); + let list = crate::client_registry::client_list(); + responses.push(Frame::BulkString(Bytes::from(list))); + continue; + } + if sub_bytes.eq_ignore_ascii_case(b"INFO") { + crate::client_registry::update(client_id, |e| { + e.db = conn.selected_db; + e.last_cmd_at = std::time::Instant::now(); + }); + let info = crate::client_registry::client_info(client_id) + .unwrap_or_default(); + responses.push(Frame::BulkString(Bytes::from(info))); + continue; + } + if sub_bytes.eq_ignore_ascii_case(b"KILL") { + let raw_args: Vec<&[u8]> = cmd_args[1..].iter() + .filter_map(|f| match f { + Frame::BulkString(b) => Some(b.as_ref()), + Frame::SimpleString(b) => Some(b.as_ref()), + _ => None, + }) + .collect(); + match crate::client_registry::parse_kill_args(&raw_args) { + Some(filter) => { + let count = crate::client_registry::kill_clients(&filter); + responses.push(Frame::Integer(count as i64)); + } + None => { + responses.push(Frame::Error(Bytes::from_static( + b"ERR syntax error. Usage: CLIENT KILL [ID id] [ADDR addr] [USER user]", + ))); + } + } + continue; + } + if sub_bytes.eq_ignore_ascii_case(b"PAUSE") { + if cmd_args.len() < 2 { + responses.push(Frame::Error(Bytes::from_static( + b"ERR wrong number of arguments for 'CLIENT PAUSE' command", + ))); + } else { + let timeout_bytes = match &cmd_args[1] { + Frame::BulkString(b) => Some(b.as_ref()), + Frame::SimpleString(b) => Some(b.as_ref()), + _ => None, + }; + match timeout_bytes.and_then(|b| std::str::from_utf8(b).ok()).and_then(|s| s.parse::().ok()) { + Some(ms) => { + let mode = if cmd_args.len() > 2 { + match &cmd_args[2] { + Frame::BulkString(b) | Frame::SimpleString(b) if b.eq_ignore_ascii_case(b"WRITE") => crate::client_pause::PauseMode::Write, + _ => crate::client_pause::PauseMode::All, + } + } else { + crate::client_pause::PauseMode::All + }; + crate::client_pause::pause(ms, mode); + responses.push(Frame::SimpleString(Bytes::from_static(b"OK"))); + } + None => { + responses.push(Frame::Error(Bytes::from_static( + b"ERR timeout is not a valid integer or out of range", + ))); + } + } + } + continue; + } + if sub_bytes.eq_ignore_ascii_case(b"UNPAUSE") { + crate::client_pause::unpause(); + responses.push(Frame::SimpleString(Bytes::from_static(b"OK"))); + continue; + } + if sub_bytes.eq_ignore_ascii_case(b"NO-EVICT") || sub_bytes.eq_ignore_ascii_case(b"NO-TOUCH") { + responses.push(Frame::SimpleString(Bytes::from_static(b"OK"))); + continue; + } responses.push(Frame::Error(Bytes::from(format!( "ERR unknown subcommand '{}'", String::from_utf8_lossy(&sub_bytes) )))); @@ -1589,6 +1770,11 @@ pub(crate) async fn handle_connection_sharded_inner< arena.reset(); + // AUTH rate limiting: delay response to slow down brute-force attacks + if auth_delay_ms > 0 { + tokio::time::sleep(std::time::Duration::from_millis(auth_delay_ms)).await; + } + write_buf.clear(); for response in &responses { if conn.protocol_version >= 3 { @@ -1601,6 +1787,17 @@ pub(crate) async fn handle_connection_sharded_inner< return (HandlerResult::Done, None); } + // Update registry with current state after each batch + crate::client_registry::update(client_id, |e| { + e.db = conn.selected_db; + e.last_cmd_at = std::time::Instant::now(); + e.flags = crate::client_registry::ClientFlags { + subscriber: conn.subscription_count > 0, + in_multi: conn.in_multi, + blocked: false, + }; + }); + // Check if migration was triggered during frame processing. // All responses for the current batch have been written, so the // client sees no interruption -- TCP socket stays open. diff --git a/src/server/listener.rs b/src/server/listener.rs index 0f0a11aa..169191b5 100644 --- a/src/server/listener.rs +++ b/src/server/listener.rs @@ -216,6 +216,16 @@ pub async fn run_with_shutdown( } } + // Set TCP keepalive on accepted socket + if config.tcp_keepalive > 0 { + let interval = std::cmp::max(config.tcp_keepalive / 3, 1); + let ka = socket2::TcpKeepalive::new() + .with_time(std::time::Duration::from_secs(config.tcp_keepalive)) + .with_interval(std::time::Duration::from_secs(interval)); + let sock = socket2::SockRef::from(&stream); + let _ = sock.set_tcp_keepalive(&ka); + } + debug!("New connection from {}", addr); let db = db.clone(); let conn_token = token.child_token(); diff --git a/src/shard/conn_accept.rs b/src/shard/conn_accept.rs index d08a59ef..abec51f2 100644 --- a/src/shard/conn_accept.rs +++ b/src/shard/conn_accept.rs @@ -68,6 +68,27 @@ fn take_migration_read_buf(state: &mut MigratedConnectionState) -> BytesMut { std::mem::take(&mut state.read_buf_remainder) } +/// Set TCP keepalive on a raw file descriptor. +/// +/// Sets SO_KEEPALIVE and TCP_KEEPIDLE (Linux) / TCP_KEEPALIVE (macOS) to detect +/// dead connections. Called once per accepted socket. +#[cfg(unix)] +fn set_tcp_keepalive(fd: std::os::unix::io::RawFd, keepalive_secs: u64) { + if keepalive_secs == 0 { + return; + } + use std::os::unix::io::BorrowedFd; + // SAFETY: fd is a valid open socket owned by the caller. We borrow it + // for the duration of this function — SockRef does not close on drop. + let borrowed = unsafe { BorrowedFd::borrow_raw(fd) }; + let sock = socket2::SockRef::from(&borrowed); + let interval = std::cmp::max(keepalive_secs / 3, 1); + let ka = socket2::TcpKeepalive::new() + .with_time(std::time::Duration::from_secs(keepalive_secs)) + .with_interval(std::time::Duration::from_secs(interval)); + let _ = sock.set_tcp_keepalive(&ka); +} + /// Spawn a new tokio connection handler task (plain TCP or TLS). /// /// Clones all required Rc/Arc shared state and spawns via `tokio::task::spawn_local`. @@ -138,8 +159,17 @@ pub(crate) fn spawn_tokio_connection( let all_regs = all_pubsub_registries.to_vec(); let all_rsm = all_remote_sub_maps.to_vec(); let reqpass = rtcfg.read().requirepass.clone(); + let maxclients_tokio = rtcfg.read().maxclients; + let tcp_keepalive_secs = rtcfg.read().tcp_keepalive; let clk = cached_clock.clone(); + // Set TCP keepalive on accepted socket + #[cfg(unix)] + { + use std::os::unix::io::AsRawFd; + set_tcp_keepalive(tcp_stream.as_raw_fd(), tcp_keepalive_secs); + } + // Construct ConnectionContext from cloned shared state let conn_ctx = crate::server::conn::ConnectionContext::new( sdbs, @@ -178,7 +208,16 @@ pub(crate) fn spawn_tokio_connection( .peer_addr() .map(|a| a.to_string()) .unwrap_or_else(|_| "unknown".to_string()); + let maxclients = maxclients_tokio; tokio::task::spawn_local(async move { + // maxclients check for TLS connections (plain TCP checks in handle_connection_sharded) + if !crate::admin::metrics_setup::try_accept_connection(maxclients) { + tracing::warn!( + "Shard {}: TLS connection rejected: maxclients reached", + shard_id + ); + return; + } let acceptor = tokio_rustls::TlsAcceptor::from(tls_cfg); match acceptor.accept(tcp_stream).await { Ok(tls_stream) => { @@ -198,6 +237,7 @@ pub(crate) fn spawn_tokio_connection( tracing::warn!("Shard {}: TLS handshake failed: {}", shard_id, e); } } + crate::admin::metrics_setup::record_connection_closed(); }); } else { // Plain TCP connection @@ -407,6 +447,14 @@ pub(crate) fn spawn_monoio_connection( ) { use crate::server::connection::handle_connection_sharded_monoio; + // Set TCP keepalive on accepted socket before converting to monoio + #[cfg(unix)] + { + use std::os::unix::io::AsRawFd; + let keepalive_secs = runtime_config.read().tcp_keepalive; + set_tcp_keepalive(std_tcp_stream.as_raw_fd(), keepalive_secs); + } + match monoio::net::TcpStream::from_std(std_tcp_stream) { Ok(tcp_stream) => { let aff = affinity_tracker.clone(); @@ -460,10 +508,18 @@ pub(crate) fn spawn_monoio_connection( spill_fid, do_dir, ); + let maxclients = conn_ctx.runtime_config.read().maxclients; if let (true, Some(tls_swap)) = (is_tls, tls_config.as_ref()) { // Load current TLS config from ArcSwap — new connections see reloaded certs let tls_cfg = tls_swap.load_full(); monoio::spawn(async move { + if !crate::admin::metrics_setup::try_accept_connection(maxclients) { + tracing::warn!( + "Shard {}: TLS connection rejected: maxclients reached", + shard_id + ); + return; + } let acceptor = monoio_rustls::TlsAcceptor::from(tls_cfg); match acceptor.accept(tcp_stream).await { Ok(tls_stream) => { @@ -488,6 +544,7 @@ pub(crate) fn spawn_monoio_connection( ); } } + crate::admin::metrics_setup::record_connection_closed(); }); } else { // Plain TCP connection @@ -496,6 +553,13 @@ pub(crate) fn spawn_monoio_connection( #[cfg(target_os = "linux")] let notifiers2 = all_notifiers.to_vec(); monoio::spawn(async move { + if !crate::admin::metrics_setup::try_accept_connection(maxclients) { + tracing::warn!( + "Shard {}: connection rejected: maxclients reached", + shard_id + ); + return; + } let _result = handle_connection_sharded_monoio( tcp_stream, peer_addr, @@ -512,6 +576,8 @@ pub(crate) fn spawn_monoio_connection( // Handle migration result: extract FD via dup() and send via SPSC. // libc::dup is only available on Linux (target-specific dependency). #[cfg(target_os = "linux")] + let mut _migrated = false; + #[cfg(target_os = "linux")] if let (crate::server::conn::handler_monoio::MonoioHandlerResult::MigrateConnection { state, target_shard }, Some(stream)) = (_result.0, _result.1) { use std::os::unix::io::{AsRawFd, FromRawFd}; use ringbuf::traits::Producer; @@ -533,6 +599,7 @@ pub(crate) fn spawn_monoio_connection( }; match push_result { Ok(()) => { + _migrated = true; notifiers2[target_shard].notify_one(); tracing::info!( "Shard {}: migrated connection {} to shard {} (monoio)", @@ -553,6 +620,14 @@ pub(crate) fn spawn_monoio_connection( } } } + + // Decrement connected_clients unless connection was migrated (stays alive on target shard) + #[cfg(target_os = "linux")] + if !_migrated { + crate::admin::metrics_setup::record_connection_closed(); + } + #[cfg(not(target_os = "linux"))] + crate::admin::metrics_setup::record_connection_closed(); }); } } diff --git a/src/storage/eviction.rs b/src/storage/eviction.rs index 1d625174..c3d5ec23 100644 --- a/src/storage/eviction.rs +++ b/src/storage/eviction.rs @@ -687,6 +687,9 @@ mod tests { requirepass: None, protected_mode: "yes".to_string(), acllog_max_len: 128, + maxclients: 10000, + timeout: 0, + tcp_keepalive: 300, } } diff --git a/tests/blocking_list_timeout.rs b/tests/blocking_list_timeout.rs index 95f250fc..0e364302 100644 --- a/tests/blocking_list_timeout.rs +++ b/tests/blocking_list_timeout.rs @@ -14,6 +14,17 @@ use redis::AsyncCommands; const MOON_PORT: u16 = 16479; +/// Skip test if moon server is not running on MOON_PORT. +macro_rules! require_moon_server { + () => { + let client = redis::Client::open(format!("redis://127.0.0.1:{}/", MOON_PORT)).unwrap(); + if client.get_multiplexed_async_connection().await.is_err() { + eprintln!("SKIP: moon server not running on port {}", MOON_PORT); + return; + } + }; +} + /// Get a multiplexed connection (good for non-blocking commands). async fn get_conn() -> redis::aio::MultiplexedConnection { let client = redis::Client::open(format!("redis://127.0.0.1:{}/", MOON_PORT)).unwrap(); @@ -34,6 +45,7 @@ async fn cleanup_keys(conn: &mut redis::aio::MultiplexedConnection, keys: &[&str #[tokio::test] #[ignore] // Requires running Moon server on port 16479 async fn blpop_timeout_returns_nil() { + require_moon_server!(); let mut conn = get_conn().await; cleanup_keys(&mut conn, &["empty_key_blpop"]).await; @@ -73,6 +85,7 @@ async fn blpop_timeout_returns_nil() { #[tokio::test] #[ignore] // Requires running Moon server on port 16479 async fn brpop_wakes_on_rpush() { + require_moon_server!(); let mut conn = get_conn().await; cleanup_keys(&mut conn, &["wake_key"]).await; @@ -115,6 +128,7 @@ async fn brpop_wakes_on_rpush() { #[tokio::test] #[ignore] // Requires running Moon server on port 16479 async fn blmpop_count_greater_than_one() { + require_moon_server!(); let mut conn = get_conn().await; cleanup_keys(&mut conn, &["blmpop_key"]).await; @@ -174,6 +188,7 @@ async fn blmpop_count_greater_than_one() { #[tokio::test] #[ignore] // Requires running Moon server on port 16479 async fn brpoplpush_legacy_alias() { + require_moon_server!(); let mut conn = get_conn().await; cleanup_keys(&mut conn, &["brpl_src", "brpl_dst"]).await; @@ -208,6 +223,7 @@ async fn brpoplpush_legacy_alias() { #[tokio::test] #[ignore] // Requires running Moon server on port 16479 async fn connection_drop_cleans_registry() { + require_moon_server!(); let mut conn = get_conn().await; cleanup_keys(&mut conn, &["drop_test_key"]).await; diff --git a/tests/functions_fcall.rs b/tests/functions_fcall.rs index cbdd356c..ccab15de 100644 --- a/tests/functions_fcall.rs +++ b/tests/functions_fcall.rs @@ -7,10 +7,21 @@ //! Requires a running moon server on the port specified by MOON_PORT (default 16479): //! ./target/release/moon --port 16479 --shards 1 //! -//! Run with: cargo test --release --test functions_fcall +//! Run with: cargo test --release --test functions_fcall -- --ignored const MOON_PORT: u16 = 16479; +/// Skip test if moon server is not running on MOON_PORT. +macro_rules! require_moon_server { + () => { + let client = redis::Client::open(format!("redis://127.0.0.1:{}/", MOON_PORT)).unwrap(); + if client.get_multiplexed_async_connection().await.is_err() { + eprintln!("SKIP: moon server not running on port {}", MOON_PORT); + return; + } + }; +} + /// Get a multiplexed connection. async fn get_conn() -> redis::aio::MultiplexedConnection { let client = redis::Client::open(format!("redis://127.0.0.1:{}/", MOON_PORT)).unwrap(); @@ -41,6 +52,7 @@ async fn flush_functions(con: &mut redis::aio::MultiplexedConnection) { #[tokio::test] #[ignore] // Requires running Moon server on port 16479 async fn function_load_and_fcall() { + require_moon_server!(); let mut con = get_conn().await; flush_functions(&mut con).await; @@ -60,6 +72,7 @@ async fn function_load_and_fcall() { #[tokio::test] #[ignore] // Requires running Moon server on port 16479 async fn function_load_missing_header_errors() { + require_moon_server!(); let mut con = get_conn().await; flush_functions(&mut con).await; @@ -76,6 +89,7 @@ async fn function_load_missing_header_errors() { #[tokio::test] #[ignore] // Requires running Moon server on port 16479 async fn function_load_duplicate_without_replace_errors() { + require_moon_server!(); let mut con = get_conn().await; flush_functions(&mut con).await; @@ -100,6 +114,7 @@ async fn function_load_duplicate_without_replace_errors() { #[tokio::test] #[ignore] // Requires running Moon server on port 16479 async fn function_load_replace_succeeds() { + require_moon_server!(); let mut con = get_conn().await; flush_functions(&mut con).await; @@ -128,6 +143,7 @@ async fn function_load_replace_succeeds() { #[tokio::test] #[ignore] // Requires running Moon server on port 16479 async fn function_list_returns_libraries() { + require_moon_server!(); let mut con = get_conn().await; flush_functions(&mut con).await; @@ -149,6 +165,7 @@ async fn function_list_returns_libraries() { #[tokio::test] #[ignore] // Requires running Moon server on port 16479 async fn function_delete_removes() { + require_moon_server!(); let mut con = get_conn().await; flush_functions(&mut con).await; @@ -177,6 +194,7 @@ async fn function_delete_removes() { #[tokio::test] #[ignore] // Requires running Moon server on port 16479 async fn fcall_ro_rejects_writes() { + require_moon_server!(); let mut con = get_conn().await; flush_functions(&mut con).await; @@ -198,6 +216,7 @@ async fn fcall_ro_rejects_writes() { #[tokio::test] #[ignore] // Requires running Moon server on port 16479 async fn function_dump_restore_stats_deferred() { + require_moon_server!(); let mut con = get_conn().await; // FUNCTION DUMP @@ -231,6 +250,7 @@ async fn function_dump_restore_stats_deferred() { #[tokio::test] #[ignore = "Phase 101: FUNCTION is RAM-only; persistence deferred"] async fn function_not_persistent_across_restart() { + require_moon_server!(); // This test documents the known limitation that functions are RAM-only // and will not survive a server restart. When persistence is added // (future phase), this test should be un-ignored and verify that diff --git a/tests/integration.rs b/tests/integration.rs index 04ee327d..d44c2078 100644 --- a/tests/integration.rs +++ b/tests/integration.rs @@ -72,6 +72,9 @@ async fn start_server() -> (u16, CancellationToken) { slowlog_log_slower_than: 10000, slowlog_max_len: 128, check_config: false, + maxclients: 10000, + timeout: 0, + tcp_keepalive: 300, }; tokio::spawn(async move { @@ -141,6 +144,9 @@ async fn start_server_with_pass(password: &str) -> (u16, CancellationToken) { slowlog_log_slower_than: 10000, slowlog_max_len: 128, check_config: false, + maxclients: 10000, + timeout: 0, + tcp_keepalive: 300, }; tokio::spawn(async move { @@ -1282,6 +1288,9 @@ async fn start_server_with_persistence( slowlog_log_slower_than: 10000, slowlog_max_len: 128, check_config: false, + maxclients: 10000, + timeout: 0, + tcp_keepalive: 300, }; tokio::spawn(async move { @@ -2135,6 +2144,9 @@ async fn start_server_with_maxmemory(maxmemory: usize, policy: &str) -> (u16, Ca slowlog_log_slower_than: 10000, slowlog_max_len: 128, check_config: false, + maxclients: 10000, + timeout: 0, + tcp_keepalive: 300, }; tokio::spawn(async move { @@ -2515,6 +2527,9 @@ async fn start_sharded_server(num_shards: usize) -> (u16, CancellationToken) { slowlog_log_slower_than: 10000, slowlog_max_len: 128, check_config: false, + maxclients: 10000, + timeout: 0, + tcp_keepalive: 300, }; let cancel = token.clone(); @@ -3664,6 +3679,9 @@ async fn start_cluster_server() -> (u16, CancellationToken) { slowlog_log_slower_than: 10000, slowlog_max_len: 128, check_config: false, + maxclients: 10000, + timeout: 0, + tcp_keepalive: 300, }; std::thread::spawn(move || { @@ -4295,6 +4313,9 @@ async fn start_server_with_aclfile(acl_path: &str) -> (u16, CancellationToken) { slowlog_log_slower_than: 10000, slowlog_max_len: 128, check_config: false, + maxclients: 10000, + timeout: 0, + tcp_keepalive: 300, }; tokio::spawn(async move { diff --git a/tests/replication_test.rs b/tests/replication_test.rs index d35b2763..493a6efa 100644 --- a/tests/replication_test.rs +++ b/tests/replication_test.rs @@ -70,6 +70,9 @@ async fn start_server() -> (u16, CancellationToken) { slowlog_log_slower_than: 10000, slowlog_max_len: 128, check_config: false, + maxclients: 10000, + timeout: 0, + tcp_keepalive: 300, }; tokio::spawn(async move {