From 5dd7e11c3d6236fe21adaffb1eda66a6ff164578 Mon Sep 17 00:00:00 2001 From: Tin Dang Date: Fri, 10 Apr 2026 22:44:53 +0700 Subject: [PATCH 01/20] feat: implement COPY command for atomic key duplication Add COPY source destination [REPLACE] with full Redis 6.2 semantics: - Returns 1 on success, 0 if dest exists without REPLACE - Clones entry including TTL metadata - DB option parsed but returns explicit unsupported error - 6 unit tests, entries in test-commands.sh and test-consistency.sh --- scripts/test-commands.sh | 6 ++ scripts/test-consistency.sh | 11 ++++ src/command/key.rs | 118 ++++++++++++++++++++++++++++++++++++ src/command/metadata.rs | 1 + src/command/mod.rs | 6 ++ 5 files changed, 142 insertions(+) diff --git a/scripts/test-commands.sh b/scripts/test-commands.sh index a1770a77..d7bf97f8 100755 --- a/scripts/test-commands.sh +++ b/scripts/test-commands.sh @@ -550,6 +550,12 @@ if should_run "key"; then rcli SET k:rnx1 v1 >/dev/null 2>&1; mcli SET k:rnx1 v1 >/dev/null 2>&1 rcli SET k:rnx2 v2 >/dev/null 2>&1; mcli SET k:rnx2 v2 >/dev/null 2>&1 assert_match "RENAMENX (blocked)" RENAMENX k:rnx1 k:rnx2 + rcli SET k:cpsrc cpval >/dev/null 2>&1; mcli SET k:cpsrc cpval >/dev/null 2>&1 + assert_match "COPY" COPY k:cpsrc k:cpdst + assert_match "GET after COPY" GET k:cpdst + rcli SET k:cpdst2 old >/dev/null 2>&1; mcli SET k:cpdst2 old >/dev/null 2>&1 + assert_match "COPY no REPLACE" COPY k:cpsrc k:cpdst2 + assert_match "COPY REPLACE" COPY k:cpsrc k:cpdst2 REPLACE assert_match "UNLINK" UNLINK k:renamed assert_moon_ok "DBSIZE" DBSIZE assert_moon_ok "SCAN cursor" SCAN 0 diff --git a/scripts/test-consistency.sh b/scripts/test-consistency.sh index f854cb50..e273a815 100755 --- a/scripts/test-consistency.sh +++ b/scripts/test-consistency.sh @@ -476,6 +476,17 @@ LONGKEY=$(python3 -c "print('k' * 500)") both SET "$LONGKEY" "long_key_value" assert_both "GET with 500-char key" GET "$LONGKEY" +# COPY +both SET edge:cpsrc "copy_value" +assert_both "COPY basic" COPY edge:cpsrc edge:cpdst +assert_both "GET after COPY src" GET edge:cpsrc +assert_both "GET after COPY dst" GET edge:cpdst +both SET edge:cpdst2 "old_value" +assert_both "COPY no REPLACE" COPY edge:cpsrc edge:cpdst2 +assert_both "GET COPY no REPLACE" GET edge:cpdst2 +assert_both "COPY REPLACE" COPY edge:cpsrc edge:cpdst2 REPLACE +assert_both "GET after COPY REPLACE" GET edge:cpdst2 + # =========================================================================== # Summary # =========================================================================== diff --git a/src/command/key.rs b/src/command/key.rs index b908b12a..c70dad11 100644 --- a/src/command/key.rs +++ b/src/command/key.rs @@ -498,6 +498,70 @@ pub fn renamenx(db: &mut Database, args: &[Frame]) -> Frame { Frame::Integer(1) } +/// COPY source destination [DB destination-db] [REPLACE] +/// +/// Copies the value stored at the source key to the destination key. +/// Returns 1 if source was copied, 0 if destination already exists without REPLACE. +pub fn copy(db: &mut Database, args: &[Frame]) -> Frame { + if args.len() < 2 { + return err_wrong_args("COPY"); + } + let src = match extract_key(&args[0]) { + Some(k) => k, + None => return err_wrong_args("COPY"), + }; + let dst = match extract_key(&args[1]) { + Some(k) => k, + None => return err_wrong_args("COPY"), + }; + + // Parse optional arguments: DB destination-db, REPLACE + let mut replace = false; + let mut i = 2; + while i < args.len() { + let arg = match extract_key(&args[i]) { + Some(k) => k, + None => return Frame::Error(Bytes::from_static(b"ERR syntax error")), + }; + if arg.eq_ignore_ascii_case(b"REPLACE") { + replace = true; + i += 1; + } else if arg.eq_ignore_ascii_case(b"DB") { + // Cross-DB copy requires shard_databases context not available here + return Frame::Error(Bytes::from_static( + b"ERR COPY with DB option is not supported yet", + )); + } else { + return Frame::Error(Bytes::from_static(b"ERR syntax error")); + } + } + + // Check if source exists (with lazy expiry) + if !db.exists(src) { + return Frame::Error(Bytes::from_static(b"ERR no such key")); + } + + // Same key: no data to copy, but it's valid + if src == dst { + return Frame::Integer(1); + } + + // Check if destination exists + if db.exists(dst) && !replace { + return Frame::Integer(0); + } + + // Clone the source entry (CompactEntry derives Clone) + let entry = db.get(src).cloned(); + if let Some(cloned) = entry { + db.set(Bytes::copy_from_slice(dst), cloned); + Frame::Integer(1) + } else { + // Source expired between exists() and get() — race with lazy expiry + Frame::Error(Bytes::from_static(b"ERR no such key")) + } +} + /// Check if a value is large enough to warrant async drop. fn should_async_drop(entry: &crate::storage::entry::Entry) -> bool { use crate::storage::compact_value::RedisValueRef; @@ -1469,4 +1533,58 @@ mod tests { _ => panic!("Expected array"), } } + + // --- COPY tests --- + + #[test] + fn test_copy_basic() { + let mut db = setup_db_with_key(b"src", b"hello"); + let result = copy(&mut db, &[bs(b"src"), bs(b"dst")]); + assert_eq!(result, Frame::Integer(1)); + assert!(db.exists(b"src")); + assert!(db.exists(b"dst")); + } + + #[test] + fn test_copy_dest_exists_no_replace() { + let mut db = setup_db_with_key(b"src", b"hello"); + db.set( + Bytes::from_static(b"dst"), + Entry::new_string(Bytes::from_static(b"existing")), + ); + let result = copy(&mut db, &[bs(b"src"), bs(b"dst")]); + assert_eq!(result, Frame::Integer(0)); + } + + #[test] + fn test_copy_with_replace() { + let mut db = setup_db_with_key(b"src", b"hello"); + db.set( + Bytes::from_static(b"dst"), + Entry::new_string(Bytes::from_static(b"existing")), + ); + let result = copy(&mut db, &[bs(b"src"), bs(b"dst"), bs(b"REPLACE")]); + assert_eq!(result, Frame::Integer(1)); + } + + #[test] + fn test_copy_nonexistent_source() { + let mut db = Database::new(); + let result = copy(&mut db, &[bs(b"nosuchkey"), bs(b"dst")]); + assert!(matches!(result, Frame::Error(_))); + } + + #[test] + fn test_copy_same_key() { + let mut db = setup_db_with_key(b"src", b"hello"); + let result = copy(&mut db, &[bs(b"src"), bs(b"src")]); + assert_eq!(result, Frame::Integer(1)); + } + + #[test] + fn test_copy_db_option_errors() { + let mut db = setup_db_with_key(b"src", b"hello"); + let result = copy(&mut db, &[bs(b"src"), bs(b"dst"), bs(b"DB")]); + assert!(matches!(result, Frame::Error(_))); + } } diff --git a/src/command/metadata.rs b/src/command/metadata.rs index 787b06bb..40618352 100644 --- a/src/command/metadata.rs +++ b/src/command/metadata.rs @@ -686,6 +686,7 @@ mod tests { b"PERSIST", b"RENAME", b"RENAMENX", + b"COPY", b"HSET", b"HMSET", b"HDEL", diff --git a/src/command/mod.rs b/src/command/mod.rs index 109ac3ba..3375a938 100644 --- a/src/command/mod.rs +++ b/src/command/mod.rs @@ -94,6 +94,12 @@ fn dispatch_inner( } } // 4-letter commands + (4, b'c') => { + // COPY + if cmd.eq_ignore_ascii_case(b"COPY") { + return resp(key::copy(db, args)); + } + } (4, b'd') => { // DECR if cmd.eq_ignore_ascii_case(b"DECR") { From efa490e29798ead621cfde805fdbb2676fc719da Mon Sep 17 00:00:00 2001 From: Tin Dang Date: Fri, 10 Apr 2026 22:50:48 +0700 Subject: [PATCH 02/20] feat: implement bit operations (GETBIT, SETBIT, BITCOUNT, BITOP, BITPOS) Full Redis-compatible bitmap commands: - GETBIT/SETBIT: per-bit read/write with big-endian ordering - BITCOUNT: popcount with byte/bit range modes (Redis 7.0 BIT option) - BITOP: AND/OR/XOR/NOT with unequal-length zero-padding - BITPOS: first 0/1 search with byte/bit range modes - Read-only dispatch variants for GETBIT, BITCOUNT, BITPOS - 21 unit tests, entries in test-commands.sh and test-consistency.sh --- scripts/test-commands.sh | 14 + scripts/test-consistency.sh | 22 + src/command/metadata.rs | 4 + src/command/mod.rs | 47 +- src/command/string/mod.rs | 2 + src/command/string/string_bit.rs | 948 +++++++++++++++++++++++++++++++ 6 files changed, 1036 insertions(+), 1 deletion(-) create mode 100644 src/command/string/string_bit.rs diff --git a/scripts/test-commands.sh b/scripts/test-commands.sh index d7bf97f8..57b73568 100755 --- a/scripts/test-commands.sh +++ b/scripts/test-commands.sh @@ -561,6 +561,20 @@ if should_run "key"; then assert_moon_ok "SCAN cursor" SCAN 0 assert_moon_ok "KEYS pattern" KEYS "k:*" assert_moon_ok "OBJECT HELP" OBJECT HELP + + # Bit operations + rcli SET k:bits "\xff\x0f" >/dev/null 2>&1; mcli SET k:bits "\xff\x0f" >/dev/null 2>&1 + assert_match "GETBIT" GETBIT k:bits 0 + assert_match "SETBIT" SETBIT k:bits 0 0 + assert_match "BITCOUNT" BITCOUNT k:bits + assert_match "BITCOUNT range" BITCOUNT k:bits 0 0 + rcli SET k:bits2 "\x0f\xff" >/dev/null 2>&1; mcli SET k:bits2 "\x0f\xff" >/dev/null 2>&1 + assert_match "BITOP AND" BITOP AND k:bitdst k:bits k:bits2 + assert_match "BITOP OR" BITOP OR k:bitdst k:bits k:bits2 + assert_match "BITOP XOR" BITOP XOR k:bitdst k:bits k:bits2 + assert_match "BITOP NOT" BITOP NOT k:bitdst k:bits + assert_match "BITPOS 1" BITPOS k:bits 1 + assert_match "BITPOS 0" BITPOS k:bits 0 fi # =========================================================================== diff --git a/scripts/test-consistency.sh b/scripts/test-consistency.sh index e273a815..17619473 100755 --- a/scripts/test-consistency.sh +++ b/scripts/test-consistency.sh @@ -487,6 +487,28 @@ assert_both "GET COPY no REPLACE" GET edge:cpdst2 assert_both "COPY REPLACE" COPY edge:cpsrc edge:cpdst2 REPLACE assert_both "GET after COPY REPLACE" GET edge:cpdst2 +# SETBIT / GETBIT +both SETBIT edge:bits 7 1 +assert_both "GETBIT set" GETBIT edge:bits 7 +assert_both "GETBIT unset" GETBIT edge:bits 0 +both SETBIT edge:bits 0 1 +assert_both "BITCOUNT" BITCOUNT edge:bits + +# BITOP +both SET edge:bop1 "\xff" +both SET edge:bop2 "\x0f" +assert_both "BITOP AND" BITOP AND edge:bopdst edge:bop1 edge:bop2 +assert_both "GET BITOP AND" GET edge:bopdst +assert_both "BITOP OR" BITOP OR edge:bopdst edge:bop1 edge:bop2 +assert_both "GET BITOP OR" GET edge:bopdst +assert_both "BITOP NOT" BITOP NOT edge:bopdst edge:bop1 +assert_both "GET BITOP NOT" GET edge:bopdst + +# BITPOS +both SET edge:bpos "\x00\xff" +assert_both "BITPOS 1" BITPOS edge:bpos 1 +assert_both "BITPOS 0" BITPOS edge:bpos 0 + # =========================================================================== # Summary # =========================================================================== diff --git a/src/command/metadata.rs b/src/command/metadata.rs index 40618352..7a6b1199 100644 --- a/src/command/metadata.rs +++ b/src/command/metadata.rs @@ -295,6 +295,8 @@ pub static COMMAND_META: phf::Map<&'static str, CommandMeta> = phf_map! { "PEXPIRETIME" => CommandMeta { name: "PEXPIRETIME", arity: 2, flags: RF, first_key: 1, last_key: 1, step: 1, acl_categories: GEN }, // ---- Bitmap commands ---- + "GETBIT" => CommandMeta { name: "GETBIT", arity: 3, flags: RF, first_key: 1, last_key: 1, step: 1, acl_categories: STR }, + "SETBIT" => CommandMeta { name: "SETBIT", arity: 4, flags: W, first_key: 1, last_key: 1, step: 1, acl_categories: STR }, "BITCOUNT" => CommandMeta { name: "BITCOUNT", arity: -2, flags: R, first_key: 1, last_key: 1, step: 1, acl_categories: STR }, "BITOP" => CommandMeta { name: "BITOP", arity: -4, flags: W, first_key: 2, last_key: -1, step: 1, acl_categories: STR }, "BITFIELD" => CommandMeta { name: "BITFIELD", arity: -2, flags: W, first_key: 1, last_key: 1, step: 1, acl_categories: STR }, @@ -687,6 +689,8 @@ mod tests { b"RENAME", b"RENAMENX", b"COPY", + b"SETBIT", + b"BITOP", b"HSET", b"HMSET", b"HDEL", diff --git a/src/command/mod.rs b/src/command/mod.rs index 3375a938..d627bfee 100644 --- a/src/command/mod.rs +++ b/src/command/mod.rs @@ -236,6 +236,12 @@ fn dispatch_inner( } } // 5-letter commands + (5, b'b') => { + // BITOP + if cmd.eq_ignore_ascii_case(b"BITOP") { + return resp(string::bitop(db, args)); + } + } (5, b'g') => { // GETEX if cmd.eq_ignore_ascii_case(b"GETEX") { @@ -358,6 +364,12 @@ fn dispatch_inner( } } // 6-letter commands + (6, b'b') => { + // BITPOS + if cmd.eq_ignore_ascii_case(b"BITPOS") { + return resp(string::bitpos(db, args)); + } + } (6, b'a') => { // APPEND if cmd.eq_ignore_ascii_case(b"APPEND") { @@ -383,7 +395,10 @@ fn dispatch_inner( } } (6, b'g') => { - // GETSET GETDEL + // GETBIT GETSET GETDEL + if cmd.eq_ignore_ascii_case(b"GETBIT") { + return resp(string::getbit(db, args)); + } if cmd.eq_ignore_ascii_case(b"GETSET") { return resp(string::getset(db, args)); } @@ -447,6 +462,9 @@ fn dispatch_inner( if cmd.eq_ignore_ascii_case(b"SELECT") { return resp(connection::select(args, selected_db, db_count)); } + if cmd.eq_ignore_ascii_case(b"SETBIT") { + return resp(string::setbit(db, args)); + } } b't' => { if cmd.eq_ignore_ascii_case(b"STRLEN") { @@ -579,6 +597,12 @@ fn dispatch_inner( return resp(string::getrange(db, args)); } } + (8, b'b') => { + // BITCOUNT BITFIELD + if cmd.eq_ignore_ascii_case(b"BITCOUNT") { + return resp(string::bitcount(db, args)); + } + } (8, b'r') => { // RENAMENX if cmd.eq_ignore_ascii_case(b"RENAMENX") { @@ -769,13 +793,16 @@ pub fn is_dispatch_read_supported(cmd: &[u8]) -> bool { | (5, b'h') // HMGET, HKEYS, HVALS, HSCAN | (5, b's') // SCARD, SDIFF, SSCAN | (5, b'z') // ZCARD, ZRANK, ZSCAN + | (6, b'b') // BITPOS | (6, b'e') // EXISTS + | (6, b'g') // GETBIT | (6, b'l') // LRANGE, LINDEX | (6, b's') // STRLEN, SUBSTR, SINTER, SUNION | (6, b'z') // ZSCORE, ZRANGE, ZCOUNT | (7, b'c') // COMMAND | (7, b'h') // HGETALL, HEXISTS | (7, b'p') // PFCOUNT + | (8, b'b') // BITCOUNT | (8, b'g') // GETRANGE | (8, b's') // SMEMBERS | (8, b'z') // ZREVRANK @@ -928,12 +955,24 @@ fn dispatch_read_inner(db: &Database, cmd: &[u8], args: &[Frame], now_ms: u64) - return resp(sorted_set::zscan_readonly(db, args, now_ms)); } } + (6, b'b') => { + // BITPOS + if cmd.eq_ignore_ascii_case(b"BITPOS") { + return resp(string::bitpos_readonly(db, args, now_ms)); + } + } (6, b'e') => { // EXISTS if cmd.eq_ignore_ascii_case(b"EXISTS") { return resp(key::exists_readonly(db, args, now_ms)); } } + (6, b'g') => { + // GETBIT + if cmd.eq_ignore_ascii_case(b"GETBIT") { + return resp(string::getbit_readonly(db, args, now_ms)); + } + } (6, b'l') => { // LRANGE LINDEX if cmd.eq_ignore_ascii_case(b"LRANGE") { @@ -991,6 +1030,12 @@ fn dispatch_read_inner(db: &Database, cmd: &[u8], args: &[Frame], now_ms: u64) - return resp(hll::pfcount_readonly(db, args, now_ms)); } } + (8, b'b') => { + // BITCOUNT + if cmd.eq_ignore_ascii_case(b"BITCOUNT") { + return resp(string::bitcount_readonly(db, args, now_ms)); + } + } (8, b'g') => { // GETRANGE if cmd.eq_ignore_ascii_case(b"GETRANGE") { diff --git a/src/command/string/mod.rs b/src/command/string/mod.rs index 7b4bbfd0..764f5624 100644 --- a/src/command/string/mod.rs +++ b/src/command/string/mod.rs @@ -1,6 +1,8 @@ +mod string_bit; mod string_read; mod string_write; +pub use string_bit::*; pub use string_read::*; pub use string_write::*; diff --git a/src/command/string/string_bit.rs b/src/command/string/string_bit.rs new file mode 100644 index 00000000..582391d5 --- /dev/null +++ b/src/command/string/string_bit.rs @@ -0,0 +1,948 @@ +use bytes::Bytes; + +use crate::protocol::Frame; +use crate::storage::Database; +use crate::storage::entry::Entry; + +use super::parse_i64; +use crate::command::helpers::{err_wrong_args, extract_bytes}; + +/// GETBIT key offset +/// +/// Returns the bit value at offset in the string value stored at key. +pub fn getbit(db: &mut Database, args: &[Frame]) -> Frame { + if args.len() != 2 { + return err_wrong_args("GETBIT"); + } + let key = match extract_bytes(&args[0]) { + Some(k) => k, + None => return err_wrong_args("GETBIT"), + }; + let offset = match parse_i64(&args[1]) { + Some(v) if v >= 0 => v as usize, + _ => { + return Frame::Error(Bytes::from_static( + b"ERR bit offset is not an integer or out of range", + )); + } + }; + + let byte_idx = offset / 8; + let bit_idx = 7 - (offset % 8); // Redis uses big-endian bit ordering + + match db.get(key) { + Some(entry) => match entry.value.as_bytes() { + Some(data) => { + if byte_idx >= data.len() { + Frame::Integer(0) + } else { + Frame::Integer(((data[byte_idx] >> bit_idx) & 1) as i64) + } + } + None => Frame::Error(Bytes::from_static( + b"WRONGTYPE Operation against a key holding the wrong kind of value", + )), + }, + None => Frame::Integer(0), + } +} + +/// GETBIT readonly variant for dispatch_read. +pub fn getbit_readonly(db: &Database, args: &[Frame], now_ms: u64) -> Frame { + if args.len() != 2 { + return err_wrong_args("GETBIT"); + } + let key = match extract_bytes(&args[0]) { + Some(k) => k, + None => return err_wrong_args("GETBIT"), + }; + let offset = match parse_i64(&args[1]) { + Some(v) if v >= 0 => v as usize, + _ => { + return Frame::Error(Bytes::from_static( + b"ERR bit offset is not an integer or out of range", + )); + } + }; + + let byte_idx = offset / 8; + let bit_idx = 7 - (offset % 8); + + match db.get_if_alive(key, now_ms) { + Some(entry) => match entry.value.as_bytes() { + Some(data) => { + if byte_idx >= data.len() { + Frame::Integer(0) + } else { + Frame::Integer(((data[byte_idx] >> bit_idx) & 1) as i64) + } + } + None => Frame::Error(Bytes::from_static( + b"WRONGTYPE Operation against a key holding the wrong kind of value", + )), + }, + None => Frame::Integer(0), + } +} + +/// SETBIT key offset value +/// +/// Sets or clears the bit at offset in the string value stored at key. +/// Returns the original bit value at the offset. +pub fn setbit(db: &mut Database, args: &[Frame]) -> Frame { + if args.len() != 3 { + return err_wrong_args("SETBIT"); + } + let key = match extract_bytes(&args[0]) { + Some(k) => k.clone(), + None => return err_wrong_args("SETBIT"), + }; + let offset = match parse_i64(&args[1]) { + Some(v) if v >= 0 && v < (512 * 1024 * 1024 * 8) => v as usize, + _ => { + return Frame::Error(Bytes::from_static( + b"ERR bit offset is not an integer or out of range", + )); + } + }; + let bit_val = match parse_i64(&args[2]) { + Some(0) => 0u8, + Some(1) => 1u8, + _ => { + return Frame::Error(Bytes::from_static( + b"ERR bit is not an integer or out of range", + )); + } + }; + + let byte_idx = offset / 8; + let bit_idx = 7 - (offset % 8); + + let base_ts = db.base_timestamp(); + let (existing_data, existing_expiry_ms) = match db.get(&key) { + Some(entry) => { + let expiry = entry.expires_at_ms(base_ts); + match entry.value.as_bytes() { + Some(v) => (Some(v.to_vec()), expiry), + None => { + return Frame::Error(Bytes::from_static( + b"WRONGTYPE Operation against a key holding the wrong kind of value", + )); + } + } + } + None => (None, 0), + }; + + let mut buf = existing_data.unwrap_or_default(); + // Extend with zero bytes if needed + if byte_idx >= buf.len() { + buf.resize(byte_idx + 1, 0); + } + + // Get original bit + let original = (buf[byte_idx] >> bit_idx) & 1; + + // Set or clear bit + if bit_val == 1 { + buf[byte_idx] |= 1 << bit_idx; + } else { + buf[byte_idx] &= !(1 << bit_idx); + } + + let new_val = Bytes::from(buf); + let mut entry = if existing_expiry_ms > 0 { + Entry::new_string_with_expiry(new_val, existing_expiry_ms, base_ts) + } else { + Entry::new_string(new_val) + }; + entry.set_last_access(db.now()); + entry.set_access_counter(5); + db.set(key, entry); + + Frame::Integer(original as i64) +} + +/// BITCOUNT key [start end [BYTE|BIT]] +/// +/// Count the number of set bits in a string. +pub fn bitcount(db: &mut Database, args: &[Frame]) -> Frame { + if args.is_empty() { + return err_wrong_args("BITCOUNT"); + } + let key = match extract_bytes(&args[0]) { + Some(k) => k, + None => return err_wrong_args("BITCOUNT"), + }; + + let data = match db.get(key) { + Some(entry) => match entry.value.as_bytes() { + Some(v) => v, + None => { + return Frame::Error(Bytes::from_static( + b"WRONGTYPE Operation against a key holding the wrong kind of value", + )); + } + }, + None => return Frame::Integer(0), + }; + + if data.is_empty() { + return Frame::Integer(0); + } + + // Parse optional range + let (start, end, use_bit) = if args.len() >= 3 { + let s = match parse_i64(&args[1]) { + Some(v) => v, + None => { + return Frame::Error(Bytes::from_static( + b"ERR value is not an integer or out of range", + )); + } + }; + let e = match parse_i64(&args[2]) { + Some(v) => v, + None => { + return Frame::Error(Bytes::from_static( + b"ERR value is not an integer or out of range", + )); + } + }; + let use_bit = if args.len() >= 4 { + let mode = match extract_bytes(&args[3]) { + Some(m) => m, + None => return Frame::Error(Bytes::from_static(b"ERR syntax error")), + }; + if mode.eq_ignore_ascii_case(b"BIT") { + true + } else if mode.eq_ignore_ascii_case(b"BYTE") { + false + } else { + return Frame::Error(Bytes::from_static(b"ERR syntax error")); + } + } else { + false + }; + (s, e, use_bit) + } else { + (0i64, -1i64, false) + }; + + if use_bit { + // BIT mode: count bits in the bit range + let total_bits = (data.len() * 8) as i64; + let s = normalize_index(start, total_bits); + let e = normalize_index(end, total_bits); + if s > e { + return Frame::Integer(0); + } + let count = count_bits_in_range(data, s as usize, e as usize); + Frame::Integer(count as i64) + } else { + // BYTE mode (default): count bits in the byte range + let len = data.len() as i64; + let s = normalize_index(start, len); + let e = normalize_index(end, len); + if s > e { + return Frame::Integer(0); + } + let slice = &data[s as usize..=e as usize]; + let count: u32 = slice.iter().map(|b| b.count_ones()).sum(); + Frame::Integer(count as i64) + } +} + +/// BITCOUNT readonly variant. +pub fn bitcount_readonly(db: &Database, args: &[Frame], now_ms: u64) -> Frame { + if args.is_empty() { + return err_wrong_args("BITCOUNT"); + } + let key = match extract_bytes(&args[0]) { + Some(k) => k, + None => return err_wrong_args("BITCOUNT"), + }; + + let data = match db.get_if_alive(key, now_ms) { + Some(entry) => match entry.value.as_bytes() { + Some(v) => v.to_vec(), + None => { + return Frame::Error(Bytes::from_static( + b"WRONGTYPE Operation against a key holding the wrong kind of value", + )); + } + }, + None => return Frame::Integer(0), + }; + + if data.is_empty() { + return Frame::Integer(0); + } + + let (start, end, use_bit) = if args.len() >= 3 { + let s = match parse_i64(&args[1]) { + Some(v) => v, + None => { + return Frame::Error(Bytes::from_static( + b"ERR value is not an integer or out of range", + )); + } + }; + let e = match parse_i64(&args[2]) { + Some(v) => v, + None => { + return Frame::Error(Bytes::from_static( + b"ERR value is not an integer or out of range", + )); + } + }; + let use_bit = if args.len() >= 4 { + let mode = match extract_bytes(&args[3]) { + Some(m) => m, + None => return Frame::Error(Bytes::from_static(b"ERR syntax error")), + }; + if mode.eq_ignore_ascii_case(b"BIT") { + true + } else if mode.eq_ignore_ascii_case(b"BYTE") { + false + } else { + return Frame::Error(Bytes::from_static(b"ERR syntax error")); + } + } else { + false + }; + (s, e, use_bit) + } else { + (0i64, -1i64, false) + }; + + if use_bit { + let total_bits = (data.len() * 8) as i64; + let s = normalize_index(start, total_bits); + let e = normalize_index(end, total_bits); + if s > e { + return Frame::Integer(0); + } + let count = count_bits_in_range(&data, s as usize, e as usize); + Frame::Integer(count as i64) + } else { + let len = data.len() as i64; + let s = normalize_index(start, len); + let e = normalize_index(end, len); + if s > e { + return Frame::Integer(0); + } + let slice = &data[s as usize..=e as usize]; + let count: u32 = slice.iter().map(|b| b.count_ones()).sum(); + Frame::Integer(count as i64) + } +} + +/// BITOP operation destkey key [key ...] +/// +/// Perform bitwise operations between strings. +pub fn bitop(db: &mut Database, args: &[Frame]) -> Frame { + if args.len() < 3 { + return err_wrong_args("BITOP"); + } + let op = match extract_bytes(&args[0]) { + Some(o) => o, + None => return err_wrong_args("BITOP"), + }; + let destkey = match extract_bytes(&args[1]) { + Some(k) => k.clone(), + None => return err_wrong_args("BITOP"), + }; + + // Determine operation + let is_not = op.eq_ignore_ascii_case(b"NOT"); + if is_not && args.len() != 3 { + return Frame::Error(Bytes::from_static( + b"ERR BITOP NOT requires one and only one key", + )); + } + + // Gather source values + let mut sources: Vec> = Vec::with_capacity(args.len() - 2); + let mut max_len = 0usize; + for arg in &args[2..] { + let key = match extract_bytes(arg) { + Some(k) => k, + None => return err_wrong_args("BITOP"), + }; + let data = match db.get(key) { + Some(entry) => match entry.value.as_bytes() { + Some(v) => v.to_vec(), + None => { + return Frame::Error(Bytes::from_static( + b"WRONGTYPE Operation against a key holding the wrong kind of value", + )); + } + }, + None => Vec::new(), + }; + if data.len() > max_len { + max_len = data.len(); + } + sources.push(data); + } + + if max_len == 0 { + // All keys empty/missing — delete dest, return 0 + db.remove(&destkey); + return Frame::Integer(0); + } + + let mut result = vec![0u8; max_len]; + + if is_not { + let src = &sources[0]; + for (i, byte) in result.iter_mut().enumerate() { + *byte = if i < src.len() { !src[i] } else { 0xFF }; + } + } else if op.eq_ignore_ascii_case(b"AND") { + // Start with all 1s + result.iter_mut().for_each(|b| *b = 0xFF); + for src in &sources { + for (i, byte) in result.iter_mut().enumerate() { + let v = if i < src.len() { src[i] } else { 0 }; + *byte &= v; + } + } + } else if op.eq_ignore_ascii_case(b"OR") { + for src in &sources { + for (i, byte) in result.iter_mut().enumerate() { + if i < src.len() { + *byte |= src[i]; + } + } + } + } else if op.eq_ignore_ascii_case(b"XOR") { + for src in &sources { + for (i, byte) in result.iter_mut().enumerate() { + if i < src.len() { + *byte ^= src[i]; + } + } + } + } else { + return Frame::Error(Bytes::from_static( + b"ERR BITOP requires AND, OR, XOR, or NOT", + )); + } + + let result_len = result.len() as i64; + let entry = Entry::new_string(Bytes::from(result)); + db.set(destkey, entry); + + Frame::Integer(result_len) +} + +/// BITPOS key bit [start [end [BYTE|BIT]]] +/// +/// Return the position of the first bit set to 0 or 1 in a string. +pub fn bitpos(db: &mut Database, args: &[Frame]) -> Frame { + if args.len() < 2 { + return err_wrong_args("BITPOS"); + } + let key = match extract_bytes(&args[0]) { + Some(k) => k, + None => return err_wrong_args("BITPOS"), + }; + let target_bit = match parse_i64(&args[1]) { + Some(0) => 0u8, + Some(1) => 1u8, + _ => { + return Frame::Error(Bytes::from_static( + b"ERR bit is not an integer or out of range", + )); + } + }; + + let data = match db.get(key) { + Some(entry) => match entry.value.as_bytes() { + Some(v) => v, + None => { + return Frame::Error(Bytes::from_static( + b"WRONGTYPE Operation against a key holding the wrong kind of value", + )); + } + }, + None => { + // Missing key: looking for 0 returns 0, looking for 1 returns -1 + return if target_bit == 0 { + Frame::Integer(0) + } else { + Frame::Integer(-1) + }; + } + }; + + if data.is_empty() { + return if target_bit == 0 { + Frame::Integer(0) + } else { + Frame::Integer(-1) + }; + } + + // Parse optional range + let has_start = args.len() >= 3; + let has_end = args.len() >= 4; + + let use_bit = if args.len() >= 5 { + let mode = match extract_bytes(&args[4]) { + Some(m) => m, + None => return Frame::Error(Bytes::from_static(b"ERR syntax error")), + }; + if mode.eq_ignore_ascii_case(b"BIT") { + true + } else if mode.eq_ignore_ascii_case(b"BYTE") { + false + } else { + return Frame::Error(Bytes::from_static(b"ERR syntax error")); + } + } else { + false + }; + + let start = if has_start { + match parse_i64(&args[2]) { + Some(v) => v, + None => { + return Frame::Error(Bytes::from_static( + b"ERR value is not an integer or out of range", + )); + } + } + } else { + 0 + }; + + let end = if has_end { + match parse_i64(&args[3]) { + Some(v) => v, + None => { + return Frame::Error(Bytes::from_static( + b"ERR value is not an integer or out of range", + )); + } + } + } else { + -1 + }; + + if use_bit { + let total_bits = (data.len() * 8) as i64; + let s = normalize_index(start, total_bits) as usize; + let e = normalize_index(end, total_bits) as usize; + if s > e { + return Frame::Integer(-1); + } + for bit_pos in s..=e { + let byte_idx = bit_pos / 8; + let bit_idx = 7 - (bit_pos % 8); + if byte_idx < data.len() { + let bit = (data[byte_idx] >> bit_idx) & 1; + if bit == target_bit { + return Frame::Integer(bit_pos as i64); + } + } + } + Frame::Integer(-1) + } else { + let len = data.len() as i64; + let s = normalize_index(start, len) as usize; + let e = normalize_index(end, len) as usize; + if s > e { + return Frame::Integer(-1); + } + let slice = &data[s..=e]; + for (byte_offset, &byte) in slice.iter().enumerate() { + for bit in 0..8u8 { + let bit_idx = 7 - bit; + let val = (byte >> bit_idx) & 1; + if val == target_bit { + return Frame::Integer(((s + byte_offset) * 8 + bit as usize) as i64); + } + } + } + // If searching for 0 without explicit end, Redis treats the string as + // having an implicit 0x00 byte beyond the last byte + if target_bit == 0 && !has_end { + return Frame::Integer(((e + 1) * 8) as i64); + } + Frame::Integer(-1) + } +} + +/// BITPOS readonly variant. +pub fn bitpos_readonly(db: &Database, args: &[Frame], now_ms: u64) -> Frame { + if args.len() < 2 { + return err_wrong_args("BITPOS"); + } + let key = match extract_bytes(&args[0]) { + Some(k) => k, + None => return err_wrong_args("BITPOS"), + }; + let target_bit = match parse_i64(&args[1]) { + Some(0) => 0u8, + Some(1) => 1u8, + _ => { + return Frame::Error(Bytes::from_static( + b"ERR bit is not an integer or out of range", + )); + } + }; + + let data_owned; + let data: &[u8] = match db.get_if_alive(key, now_ms) { + Some(entry) => match entry.value.as_bytes() { + Some(v) => { + data_owned = v.to_vec(); + &data_owned + } + None => { + return Frame::Error(Bytes::from_static( + b"WRONGTYPE Operation against a key holding the wrong kind of value", + )); + } + }, + None => { + return if target_bit == 0 { + Frame::Integer(0) + } else { + Frame::Integer(-1) + }; + } + }; + + if data.is_empty() { + return if target_bit == 0 { + Frame::Integer(0) + } else { + Frame::Integer(-1) + }; + } + + let has_start = args.len() >= 3; + let has_end = args.len() >= 4; + + let use_bit = if args.len() >= 5 { + let mode = match extract_bytes(&args[4]) { + Some(m) => m, + None => return Frame::Error(Bytes::from_static(b"ERR syntax error")), + }; + if mode.eq_ignore_ascii_case(b"BIT") { + true + } else if mode.eq_ignore_ascii_case(b"BYTE") { + false + } else { + return Frame::Error(Bytes::from_static(b"ERR syntax error")); + } + } else { + false + }; + + let start = if has_start { + match parse_i64(&args[2]) { + Some(v) => v, + None => { + return Frame::Error(Bytes::from_static( + b"ERR value is not an integer or out of range", + )); + } + } + } else { + 0 + }; + + let end = if has_end { + match parse_i64(&args[3]) { + Some(v) => v, + None => { + return Frame::Error(Bytes::from_static( + b"ERR value is not an integer or out of range", + )); + } + } + } else { + -1 + }; + + if use_bit { + let total_bits = (data.len() * 8) as i64; + let s = normalize_index(start, total_bits) as usize; + let e = normalize_index(end, total_bits) as usize; + if s > e { + return Frame::Integer(-1); + } + for bit_pos in s..=e { + let byte_idx = bit_pos / 8; + let bit_idx = 7 - (bit_pos % 8); + if byte_idx < data.len() { + let bit = (data[byte_idx] >> bit_idx) & 1; + if bit == target_bit { + return Frame::Integer(bit_pos as i64); + } + } + } + Frame::Integer(-1) + } else { + let len = data.len() as i64; + let s = normalize_index(start, len) as usize; + let e = normalize_index(end, len) as usize; + if s > e { + return Frame::Integer(-1); + } + let slice = &data[s..=e]; + for (byte_offset, &byte) in slice.iter().enumerate() { + for bit in 0..8u8 { + let bit_idx = 7 - bit; + let val = (byte >> bit_idx) & 1; + if val == target_bit { + return Frame::Integer(((s + byte_offset) * 8 + bit as usize) as i64); + } + } + } + if target_bit == 0 && !has_end { + return Frame::Integer(((e + 1) * 8) as i64); + } + Frame::Integer(-1) + } +} + +/// Normalize a Redis index (negative = from end) to a 0-based clamped index. +fn normalize_index(idx: i64, len: i64) -> i64 { + if len == 0 { + return 0; + } + let normalized = if idx < 0 { len + idx } else { idx }; + normalized.clamp(0, len - 1) +} + +/// Count set bits in a specific bit range within a byte slice. +fn count_bits_in_range(data: &[u8], start_bit: usize, end_bit: usize) -> u32 { + let mut count = 0u32; + for bit_pos in start_bit..=end_bit { + let byte_idx = bit_pos / 8; + let bit_idx = 7 - (bit_pos % 8); + if byte_idx < data.len() && (data[byte_idx] >> bit_idx) & 1 == 1 { + count += 1; + } + } + count +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::storage::Database; + + fn bs(s: &[u8]) -> Frame { + Frame::BulkString(Bytes::copy_from_slice(s)) + } + + fn make_db() -> Database { + Database::new() + } + + // --- GETBIT tests --- + + #[test] + fn test_getbit_missing_key() { + let mut db = make_db(); + let result = getbit(&mut db, &[bs(b"key"), bs(b"0")]); + assert_eq!(result, Frame::Integer(0)); + } + + #[test] + fn test_getbit_out_of_range() { + let mut db = make_db(); + db.set_string(Bytes::from_static(b"key"), Bytes::from_static(b"\xff")); + let result = getbit(&mut db, &[bs(b"key"), bs(b"100")]); + assert_eq!(result, Frame::Integer(0)); + } + + // --- SETBIT tests --- + + #[test] + fn test_setbit_new_key() { + let mut db = make_db(); + let result = setbit(&mut db, &[bs(b"key"), bs(b"7"), bs(b"1")]); + assert_eq!(result, Frame::Integer(0)); // original was 0 + let result = getbit(&mut db, &[bs(b"key"), bs(b"7")]); + assert_eq!(result, Frame::Integer(1)); + } + + #[test] + fn test_setbit_clear() { + let mut db = make_db(); + db.set_string(Bytes::from_static(b"key"), Bytes::from_static(b"\xff")); + let result = setbit(&mut db, &[bs(b"key"), bs(b"0"), bs(b"0")]); + assert_eq!(result, Frame::Integer(1)); // original was 1 + let result = getbit(&mut db, &[bs(b"key"), bs(b"0")]); + assert_eq!(result, Frame::Integer(0)); + } + + #[test] + fn test_setbit_extends_string() { + let mut db = make_db(); + setbit(&mut db, &[bs(b"key"), bs(b"23"), bs(b"1")]); + // 23 / 8 = byte 2, so 3 bytes total + let entry = db.get(b"key").unwrap(); + assert_eq!(entry.value.as_bytes().unwrap().len(), 3); + } + + // --- BITCOUNT tests --- + + #[test] + fn test_bitcount_full_string() { + let mut db = make_db(); + // "foobar" in bytes + db.set_string(Bytes::from_static(b"key"), Bytes::from_static(b"foobar")); + let result = bitcount(&mut db, &[bs(b"key")]); + assert_eq!(result, Frame::Integer(26)); // known count for "foobar" + } + + #[test] + fn test_bitcount_range() { + let mut db = make_db(); + db.set_string(Bytes::from_static(b"key"), Bytes::from_static(b"foobar")); + let result = bitcount(&mut db, &[bs(b"key"), bs(b"0"), bs(b"0")]); + // 'f' = 0x66 = 01100110 → 4 bits + assert_eq!(result, Frame::Integer(4)); + } + + #[test] + fn test_bitcount_negative_range() { + let mut db = make_db(); + db.set_string(Bytes::from_static(b"key"), Bytes::from_static(b"foobar")); + let result = bitcount(&mut db, &[bs(b"key"), bs(b"-1"), bs(b"-1")]); + // 'r' = 0x72 = 01110010 → 4 bits + assert_eq!(result, Frame::Integer(4)); + } + + #[test] + fn test_bitcount_missing_key() { + let mut db = make_db(); + let result = bitcount(&mut db, &[bs(b"key")]); + assert_eq!(result, Frame::Integer(0)); + } + + #[test] + fn test_bitcount_bit_mode() { + let mut db = make_db(); + db.set_string(Bytes::from_static(b"key"), Bytes::from_static(b"\xff\x00")); + // BIT mode: bits 0-7 are all 1 = 8 bits + let result = bitcount(&mut db, &[bs(b"key"), bs(b"0"), bs(b"7"), bs(b"BIT")]); + assert_eq!(result, Frame::Integer(8)); + } + + // --- BITOP tests --- + + #[test] + fn test_bitop_and() { + let mut db = make_db(); + db.set_string(Bytes::from_static(b"a"), Bytes::from_static(b"\xff\x0f")); + db.set_string(Bytes::from_static(b"b"), Bytes::from_static(b"\x0f\xff")); + let result = bitop(&mut db, &[bs(b"AND"), bs(b"dest"), bs(b"a"), bs(b"b")]); + assert_eq!(result, Frame::Integer(2)); + let data = db.get(b"dest").unwrap().value.as_bytes().unwrap().to_vec(); + assert_eq!(data, vec![0x0f, 0x0f]); + } + + #[test] + fn test_bitop_or() { + let mut db = make_db(); + db.set_string(Bytes::from_static(b"a"), Bytes::from_static(b"\xf0")); + db.set_string(Bytes::from_static(b"b"), Bytes::from_static(b"\x0f")); + let result = bitop(&mut db, &[bs(b"OR"), bs(b"dest"), bs(b"a"), bs(b"b")]); + assert_eq!(result, Frame::Integer(1)); + let data = db.get(b"dest").unwrap().value.as_bytes().unwrap().to_vec(); + assert_eq!(data, vec![0xff]); + } + + #[test] + fn test_bitop_xor() { + let mut db = make_db(); + db.set_string(Bytes::from_static(b"a"), Bytes::from_static(b"\xff")); + db.set_string(Bytes::from_static(b"b"), Bytes::from_static(b"\x0f")); + let result = bitop(&mut db, &[bs(b"XOR"), bs(b"dest"), bs(b"a"), bs(b"b")]); + assert_eq!(result, Frame::Integer(1)); + let data = db.get(b"dest").unwrap().value.as_bytes().unwrap().to_vec(); + assert_eq!(data, vec![0xf0]); + } + + #[test] + fn test_bitop_not() { + let mut db = make_db(); + db.set_string(Bytes::from_static(b"a"), Bytes::from_static(b"\x0f")); + let result = bitop(&mut db, &[bs(b"NOT"), bs(b"dest"), bs(b"a")]); + assert_eq!(result, Frame::Integer(1)); + let data = db.get(b"dest").unwrap().value.as_bytes().unwrap().to_vec(); + assert_eq!(data, vec![0xf0]); + } + + #[test] + fn test_bitop_not_requires_one_key() { + let mut db = make_db(); + let result = bitop( + &mut db, + &[bs(b"NOT"), bs(b"dest"), bs(b"a"), bs(b"b")], + ); + assert!(matches!(result, Frame::Error(_))); + } + + #[test] + fn test_bitop_unequal_lengths() { + let mut db = make_db(); + db.set_string(Bytes::from_static(b"a"), Bytes::from_static(b"\xff\xff")); + db.set_string(Bytes::from_static(b"b"), Bytes::from_static(b"\x0f")); + let result = bitop(&mut db, &[bs(b"AND"), bs(b"dest"), bs(b"a"), bs(b"b")]); + assert_eq!(result, Frame::Integer(2)); + let data = db.get(b"dest").unwrap().value.as_bytes().unwrap().to_vec(); + // b is zero-padded → \x0f\x00, AND with \xff\xff → \x0f\x00 + assert_eq!(data, vec![0x0f, 0x00]); + } + + // --- BITPOS tests --- + + #[test] + fn test_bitpos_first_one() { + let mut db = make_db(); + db.set_string(Bytes::from_static(b"key"), Bytes::from_static(b"\x00\xff")); + let result = bitpos(&mut db, &[bs(b"key"), bs(b"1")]); + assert_eq!(result, Frame::Integer(8)); // first 1 bit at position 8 + } + + #[test] + fn test_bitpos_first_zero() { + let mut db = make_db(); + db.set_string(Bytes::from_static(b"key"), Bytes::from_static(b"\xff\x00")); + let result = bitpos(&mut db, &[bs(b"key"), bs(b"0")]); + assert_eq!(result, Frame::Integer(8)); // first 0 bit at position 8 + } + + #[test] + fn test_bitpos_no_one_found() { + let mut db = make_db(); + db.set_string(Bytes::from_static(b"key"), Bytes::from_static(b"\x00\x00")); + let result = bitpos(&mut db, &[bs(b"key"), bs(b"1")]); + assert_eq!(result, Frame::Integer(-1)); + } + + #[test] + fn test_bitpos_missing_key_zero() { + let mut db = make_db(); + let result = bitpos(&mut db, &[bs(b"key"), bs(b"0")]); + assert_eq!(result, Frame::Integer(0)); + } + + #[test] + fn test_bitpos_missing_key_one() { + let mut db = make_db(); + let result = bitpos(&mut db, &[bs(b"key"), bs(b"1")]); + assert_eq!(result, Frame::Integer(-1)); + } +} From 757ccb3435b5988bcaa974b19b45a05d3d9ace83 Mon Sep 17 00:00:00 2001 From: Tin Dang Date: Fri, 10 Apr 2026 22:53:57 +0700 Subject: [PATCH 03/20] feat: implement SORT command with BY/GET/LIMIT/ALPHA/DESC/STORE Full Redis-compatible SORT for lists, sets, and sorted sets: - Numeric (default) and ALPHA lexicographic sorting - ASC/DESC ordering - LIMIT offset count pagination - BY pattern for external key-based sorting - GET pattern for external key retrieval (# = element itself) - STORE destination to persist results as list - BY nosort to skip sorting - 7 unit tests, entries in test-commands.sh and test-consistency.sh --- scripts/test-commands.sh | 7 + scripts/test-consistency.sh | 9 + src/command/key.rs | 345 ++++++++++++++++++++++++++++++++++++ src/command/mod.rs | 3 + 4 files changed, 364 insertions(+) diff --git a/scripts/test-commands.sh b/scripts/test-commands.sh index 57b73568..704c6f63 100755 --- a/scripts/test-commands.sh +++ b/scripts/test-commands.sh @@ -575,6 +575,13 @@ if should_run "key"; then assert_match "BITOP NOT" BITOP NOT k:bitdst k:bits assert_match "BITPOS 1" BITPOS k:bits 1 assert_match "BITPOS 0" BITPOS k:bits 0 + + # SORT + rcli RPUSH k:sortl 3 1 2 >/dev/null 2>&1; mcli RPUSH k:sortl 3 1 2 >/dev/null 2>&1 + assert_match "SORT numeric" SORT k:sortl + assert_match "SORT DESC" SORT k:sortl DESC + assert_match "SORT ALPHA" SORT k:sortl ALPHA + assert_match "SORT LIMIT" SORT k:sortl LIMIT 0 2 fi # =========================================================================== diff --git a/scripts/test-consistency.sh b/scripts/test-consistency.sh index 17619473..210fdee0 100755 --- a/scripts/test-consistency.sh +++ b/scripts/test-consistency.sh @@ -509,6 +509,15 @@ both SET edge:bpos "\x00\xff" assert_both "BITPOS 1" BITPOS edge:bpos 1 assert_both "BITPOS 0" BITPOS edge:bpos 0 +# SORT +both RPUSH edge:sortl 3 1 2 +assert_both "SORT numeric" SORT edge:sortl +assert_both "SORT DESC" SORT edge:sortl DESC +assert_both "SORT ALPHA" SORT edge:sortl ALPHA +assert_both "SORT LIMIT" SORT edge:sortl LIMIT 0 2 +assert_both "SORT STORE" SORT edge:sortl STORE edge:sorted +assert_both "SORT STORE result" LRANGE edge:sorted 0 -1 + # =========================================================================== # Summary # =========================================================================== diff --git a/src/command/key.rs b/src/command/key.rs index c70dad11..1cc14881 100644 --- a/src/command/key.rs +++ b/src/command/key.rs @@ -562,6 +562,234 @@ pub fn copy(db: &mut Database, args: &[Frame]) -> Frame { } } +/// SORT key [BY pattern] [LIMIT offset count] [GET pattern ...] [ASC|DESC] [ALPHA] [STORE dest] +/// +/// Sort elements in a list, set, or sorted set. +pub fn sort(db: &mut Database, args: &[Frame]) -> Frame { + if args.is_empty() { + return err_wrong_args("SORT"); + } + let key = match extract_key(&args[0]) { + Some(k) => k, + None => return err_wrong_args("SORT"), + }; + + // Parse options + let mut by_pattern: Option<&[u8]> = None; + let mut get_patterns: Vec<&[u8]> = Vec::new(); + let mut limit_offset: usize = 0; + let mut limit_count: Option = None; + let mut descending = false; + let mut alpha = false; + let mut store_dest: Option<&[u8]> = None; + + let mut i = 1; + while i < args.len() { + let arg = match extract_key(&args[i]) { + Some(a) => a, + None => return Frame::Error(Bytes::from_static(b"ERR syntax error")), + }; + if arg.eq_ignore_ascii_case(b"BY") { + i += 1; + by_pattern = Some(match extract_key(args.get(i).unwrap_or(&Frame::Null)) { + Some(p) => p, + None => return Frame::Error(Bytes::from_static(b"ERR syntax error")), + }); + } else if arg.eq_ignore_ascii_case(b"GET") { + i += 1; + let pat = match extract_key(args.get(i).unwrap_or(&Frame::Null)) { + Some(p) => p, + None => return Frame::Error(Bytes::from_static(b"ERR syntax error")), + }; + get_patterns.push(pat); + } else if arg.eq_ignore_ascii_case(b"LIMIT") { + let off = match args.get(i + 1).and_then(|f| parse_int(f)) { + Some(v) if v >= 0 => v as usize, + _ => return Frame::Error(Bytes::from_static(b"ERR syntax error")), + }; + let cnt = match args.get(i + 2).and_then(|f| parse_int(f)) { + Some(v) if v >= 0 => v as usize, + _ => return Frame::Error(Bytes::from_static(b"ERR syntax error")), + }; + limit_offset = off; + limit_count = Some(cnt); + i += 2; + } else if arg.eq_ignore_ascii_case(b"ASC") { + descending = false; + } else if arg.eq_ignore_ascii_case(b"DESC") { + descending = true; + } else if arg.eq_ignore_ascii_case(b"ALPHA") { + alpha = true; + } else if arg.eq_ignore_ascii_case(b"STORE") { + i += 1; + store_dest = Some(match extract_key(args.get(i).unwrap_or(&Frame::Null)) { + Some(d) => d, + None => return Frame::Error(Bytes::from_static(b"ERR syntax error")), + }); + } else { + return Frame::Error(Bytes::from_static(b"ERR syntax error")); + } + i += 1; + } + + // Extract elements from the key + use crate::storage::compact_value::RedisValueRef; + let elements: Vec = match db.get(key) { + None => { + // Non-existent key: return empty result + return if let Some(dest) = store_dest { + // STORE with empty → create empty list + let entry = crate::storage::entry::Entry::new_list(); + db.set(Bytes::copy_from_slice(dest), entry); + Frame::Integer(0) + } else { + Frame::Array(crate::framevec![]) + }; + } + Some(entry) => match entry.value.as_redis_value() { + RedisValueRef::List(l) => l.iter().cloned().collect(), + RedisValueRef::ListListpack(lp) => lp.iter().map(|e| e.to_bytes()).collect(), + RedisValueRef::Set(s) => s.iter().cloned().collect(), + RedisValueRef::SetListpack(lp) => lp.iter().map(|e| e.to_bytes()).collect(), + RedisValueRef::SetIntset(is) => { + is.iter().map(|v| Bytes::from(v.to_string())).collect() + } + RedisValueRef::SortedSet { members, .. } => members.keys().cloned().collect(), + RedisValueRef::SortedSetBPTree { members, .. } => members.keys().cloned().collect(), + RedisValueRef::SortedSetListpack(lp) => { + // Listpack stores member, score pairs + let entries: Vec<_> = lp.iter().collect(); + entries.chunks(2).filter_map(|c| c.first().map(|e| e.to_bytes())).collect() + } + _ => { + return Frame::Error(Bytes::from_static( + b"WRONGTYPE Operation against a key holding the wrong kind of value", + )); + } + }, + }; + + // Resolve sort keys (BY pattern or element itself) + let sort_keys: Vec> = if let Some(pattern) = by_pattern { + if pattern == b"nosort" { + // BY nosort = skip sorting + elements.iter().map(|_| None).collect() + } else { + elements + .iter() + .map(|elem| { + let lookup_key = apply_pattern(pattern, elem); + db.get(&lookup_key) + .and_then(|e| e.value.as_bytes().map(|b| Bytes::copy_from_slice(b))) + }) + .collect() + } + } else { + elements + .iter() + .map(|e| Some(e.clone())) + .collect() + }; + + // Create indexed pairs for stable sort + let mut indices: Vec = (0..elements.len()).collect(); + + // Sort (skip if BY nosort) + let no_sort = by_pattern.is_some_and(|p| p == b"nosort"); + if !no_sort { + indices.sort_by(|&a, &b| { + let ka = sort_keys[a].as_ref(); + let kb = sort_keys[b].as_ref(); + let cmp = match (ka, kb) { + (None, None) => std::cmp::Ordering::Equal, + (None, Some(_)) => std::cmp::Ordering::Greater, + (Some(_), None) => std::cmp::Ordering::Less, + (Some(va), Some(vb)) => { + if alpha { + va.cmp(vb) + } else { + let fa = std::str::from_utf8(va) + .ok() + .and_then(|s| s.parse::().ok()) + .unwrap_or(0.0); + let fb = std::str::from_utf8(vb) + .ok() + .and_then(|s| s.parse::().ok()) + .unwrap_or(0.0); + fa.partial_cmp(&fb).unwrap_or(std::cmp::Ordering::Equal) + } + } + }; + if descending { cmp.reverse() } else { cmp } + }); + } + + // Apply LIMIT + let start = limit_offset.min(indices.len()); + let count = limit_count.unwrap_or(indices.len()); + let end = (start + count).min(indices.len()); + let selected = &indices[start..end]; + + // Build results + let results: Vec = if get_patterns.is_empty() { + selected + .iter() + .map(|&idx| Frame::BulkString(elements[idx].clone())) + .collect() + } else { + let mut out = Vec::with_capacity(selected.len() * get_patterns.len()); + for &idx in selected { + for pat in &get_patterns { + if *pat == b"#" { + out.push(Frame::BulkString(elements[idx].clone())); + } else { + let lookup_key = apply_pattern(pat, &elements[idx]); + match db.get(&lookup_key) { + Some(e) => match e.value.as_bytes() { + Some(v) => out.push(Frame::BulkString(Bytes::copy_from_slice(v))), + None => out.push(Frame::Null), + }, + None => out.push(Frame::Null), + } + } + } + } + out + }; + + // STORE or return + if let Some(dest) = store_dest { + let count = results.len() as i64; + let mut list = std::collections::VecDeque::with_capacity(results.len()); + for frame in results { + if let Frame::BulkString(b) = frame { + list.push_back(b); + } + } + let mut entry = crate::storage::entry::Entry::new_list(); + entry.value = crate::storage::compact_value::CompactValue::from_redis_value( + crate::storage::entry::RedisValue::List(list), + ); + db.set(Bytes::copy_from_slice(dest), entry); + Frame::Integer(count) + } else { + Frame::Array(results.into()) + } +} + +/// Apply a SORT pattern by replacing the first `*` with the element value. +fn apply_pattern(pattern: &[u8], element: &[u8]) -> Bytes { + if let Some(pos) = pattern.iter().position(|&b| b == b'*') { + let mut result = Vec::with_capacity(pattern.len() + element.len()); + result.extend_from_slice(&pattern[..pos]); + result.extend_from_slice(element); + result.extend_from_slice(&pattern[pos + 1..]); + Bytes::from(result) + } else { + Bytes::copy_from_slice(pattern) + } +} + /// Check if a value is large enough to warrant async drop. fn should_async_drop(entry: &crate::storage::entry::Entry) -> bool { use crate::storage::compact_value::RedisValueRef; @@ -1587,4 +1815,121 @@ mod tests { let result = copy(&mut db, &[bs(b"src"), bs(b"dst"), bs(b"DB")]); assert!(matches!(result, Frame::Error(_))); } + + // --- SORT tests --- + + fn setup_list(db: &mut Database, key: &[u8], vals: &[&[u8]]) { + use std::collections::VecDeque; + let list: VecDeque = vals.iter().map(|v| Bytes::copy_from_slice(v)).collect(); + let mut entry = Entry::new_list(); + entry.value = crate::storage::compact_value::CompactValue::from_redis_value( + crate::storage::entry::RedisValue::List(list), + ); + db.set(Bytes::copy_from_slice(key), entry); + } + + #[test] + fn test_sort_numeric() { + let mut db = Database::new(); + setup_list(&mut db, b"mylist", &[b"3", b"1", b"2"]); + let result = sort(&mut db, &[bs(b"mylist")]); + assert_eq!( + result, + Frame::Array(framevec![ + Frame::BulkString(Bytes::from_static(b"1")), + Frame::BulkString(Bytes::from_static(b"2")), + Frame::BulkString(Bytes::from_static(b"3")), + ]) + ); + } + + #[test] + fn test_sort_alpha() { + let mut db = Database::new(); + setup_list(&mut db, b"mylist", &[b"c", b"a", b"b"]); + let result = sort(&mut db, &[bs(b"mylist"), bs(b"ALPHA")]); + assert_eq!( + result, + Frame::Array(framevec![ + Frame::BulkString(Bytes::from_static(b"a")), + Frame::BulkString(Bytes::from_static(b"b")), + Frame::BulkString(Bytes::from_static(b"c")), + ]) + ); + } + + #[test] + fn test_sort_desc() { + let mut db = Database::new(); + setup_list(&mut db, b"mylist", &[b"1", b"3", b"2"]); + let result = sort(&mut db, &[bs(b"mylist"), bs(b"DESC")]); + assert_eq!( + result, + Frame::Array(framevec![ + Frame::BulkString(Bytes::from_static(b"3")), + Frame::BulkString(Bytes::from_static(b"2")), + Frame::BulkString(Bytes::from_static(b"1")), + ]) + ); + } + + #[test] + fn test_sort_limit() { + let mut db = Database::new(); + setup_list(&mut db, b"mylist", &[b"3", b"1", b"2", b"4"]); + let result = sort( + &mut db, + &[bs(b"mylist"), bs(b"LIMIT"), bs(b"1"), bs(b"2")], + ); + assert_eq!( + result, + Frame::Array(framevec![ + Frame::BulkString(Bytes::from_static(b"2")), + Frame::BulkString(Bytes::from_static(b"3")), + ]) + ); + } + + #[test] + fn test_sort_store() { + let mut db = Database::new(); + setup_list(&mut db, b"mylist", &[b"3", b"1", b"2"]); + let result = sort( + &mut db, + &[bs(b"mylist"), bs(b"STORE"), bs(b"sorted")], + ); + assert_eq!(result, Frame::Integer(3)); + assert!(db.exists(b"sorted")); + } + + #[test] + fn test_sort_nonexistent() { + let mut db = Database::new(); + let result = sort(&mut db, &[bs(b"nokey")]); + assert_eq!(result, Frame::Array(framevec![])); + } + + #[test] + fn test_sort_set() { + let mut db = Database::new(); + let mut s = std::collections::HashSet::new(); + s.insert(Bytes::from_static(b"3")); + s.insert(Bytes::from_static(b"1")); + s.insert(Bytes::from_static(b"2")); + let mut entry = Entry::new_set(); + entry.value = crate::storage::compact_value::CompactValue::from_redis_value( + crate::storage::entry::RedisValue::Set(s), + ); + db.set(Bytes::from_static(b"myset"), entry); + let result = sort(&mut db, &[bs(b"myset")]); + // Sort result should be [1, 2, 3] regardless of HashSet order + assert_eq!( + result, + Frame::Array(framevec![ + Frame::BulkString(Bytes::from_static(b"1")), + Frame::BulkString(Bytes::from_static(b"2")), + Frame::BulkString(Bytes::from_static(b"3")), + ]) + ); + } } diff --git a/src/command/mod.rs b/src/command/mod.rs index d627bfee..0e77b69c 100644 --- a/src/command/mod.rs +++ b/src/command/mod.rs @@ -204,6 +204,9 @@ fn dispatch_inner( if cmd.eq_ignore_ascii_case(b"SPOP") { return resp(set::spop(db, args)); } + if cmd.eq_ignore_ascii_case(b"SORT") { + return resp(key::sort(db, args)); + } } (4, b't') => { // TYPE From f3b787a4e35603fad679c5b75530ab11e7549ad6 Mon Sep 17 00:00:00 2001 From: Tin Dang Date: Fri, 10 Apr 2026 22:59:09 +0700 Subject: [PATCH 04/20] feat: implement geospatial commands (GEOADD, GEOPOS, GEODIST, GEOHASH, GEOSEARCH, GEOSEARCHSTORE) Full Redis-compatible GEO commands backed by sorted sets with geohash scores: - GEOADD with NX/XX/CH flags and coordinate validation - GEOPOS returns (lon, lat) pairs for members - GEODIST with M/KM/FT/MI unit support - GEOHASH returns 11-char base32 geohash strings - GEOSEARCH with FROMLONLAT/FROMMEMBER, BYRADIUS/BYBOX, ASC/DESC, COUNT, WITHCOORD/WITHDIST/WITHHASH options - GEOSEARCHSTORE for persisting search results - Haversine distance formula (6372797m earth radius) - 52-bit integer geohash encoding (Redis-compatible) - 8 unit tests, entries in test-commands.sh and test-consistency.sh --- scripts/test-commands.sh | 8 + scripts/test-consistency.sh | 9 + src/command/geo/geo_cmd.rs | 685 ++++++++++++++++++++++++++++++++++++ src/command/geo/mod.rs | 162 +++++++++ src/command/metadata.rs | 2 + src/command/mod.rs | 31 +- 6 files changed, 896 insertions(+), 1 deletion(-) create mode 100644 src/command/geo/geo_cmd.rs create mode 100644 src/command/geo/mod.rs diff --git a/scripts/test-commands.sh b/scripts/test-commands.sh index 704c6f63..cfe30def 100755 --- a/scripts/test-commands.sh +++ b/scripts/test-commands.sh @@ -582,6 +582,14 @@ if should_run "key"; then assert_match "SORT DESC" SORT k:sortl DESC assert_match "SORT ALPHA" SORT k:sortl ALPHA assert_match "SORT LIMIT" SORT k:sortl LIMIT 0 2 + + # GEO commands + rcli GEOADD k:geo 13.361389 38.115556 Palermo 15.087269 37.502669 Catania >/dev/null 2>&1 + mcli GEOADD k:geo 13.361389 38.115556 Palermo 15.087269 37.502669 Catania >/dev/null 2>&1 + assert_match "GEOPOS" GEOPOS k:geo Palermo + assert_match "GEODIST km" GEODIST k:geo Palermo Catania km + assert_match "GEOHASH" GEOHASH k:geo Palermo + assert_match "GEOSEARCH" GEOSEARCH k:geo FROMLONLAT 15 37 BYRADIUS 200 km ASC fi # =========================================================================== diff --git a/scripts/test-consistency.sh b/scripts/test-consistency.sh index 210fdee0..71b1f7cb 100755 --- a/scripts/test-consistency.sh +++ b/scripts/test-consistency.sh @@ -518,6 +518,15 @@ assert_both "SORT LIMIT" SORT edge:sortl LIMIT 0 2 assert_both "SORT STORE" SORT edge:sortl STORE edge:sorted assert_both "SORT STORE result" LRANGE edge:sorted 0 -1 +# GEOADD / GEOPOS / GEODIST / GEOHASH / GEOSEARCH +both GEOADD edge:geo 13.361389 38.115556 Palermo 15.087269 37.502669 Catania +assert_both "GEOPOS" GEOPOS edge:geo Palermo +assert_both "GEOPOS missing" GEOPOS edge:geo NonExistent +assert_both "GEODIST m" GEODIST edge:geo Palermo Catania +assert_both "GEODIST km" GEODIST edge:geo Palermo Catania km +assert_both "GEOHASH" GEOHASH edge:geo Palermo +assert_both "GEOADD count" GEOADD edge:geo 2.349014 48.864716 Paris + # =========================================================================== # Summary # =========================================================================== diff --git a/src/command/geo/geo_cmd.rs b/src/command/geo/geo_cmd.rs new file mode 100644 index 00000000..5a658fd5 --- /dev/null +++ b/src/command/geo/geo_cmd.rs @@ -0,0 +1,685 @@ +use bytes::Bytes; +use ordered_float::OrderedFloat; + +use crate::protocol::Frame; +use crate::storage::Database; + +use crate::command::helpers::{err_wrong_args, extract_bytes}; + +use super::{ + convert_distance, geohash_decode, geohash_encode, geohash_to_string, haversine_distance, + parse_unit, +}; + +fn parse_f64(frame: &Frame) -> Option { + let b = extract_bytes(frame)?; + std::str::from_utf8(b).ok()?.parse().ok() +} + +/// GEOADD key [NX|XX] [CH] longitude latitude member [longitude latitude member ...] +pub fn geoadd(db: &mut Database, args: &[Frame]) -> Frame { + if args.len() < 4 { + return err_wrong_args("GEOADD"); + } + let key = match extract_bytes(&args[0]) { + Some(k) => k, + None => return err_wrong_args("GEOADD"), + }; + + // Parse optional NX/XX/CH flags + let mut nx = false; + let mut xx = false; + let mut ch = false; + let mut i = 1; + while i < args.len() { + let arg = match extract_bytes(&args[i]) { + Some(a) => a, + None => break, + }; + if arg.eq_ignore_ascii_case(b"NX") { + nx = true; + i += 1; + } else if arg.eq_ignore_ascii_case(b"XX") { + xx = true; + i += 1; + } else if arg.eq_ignore_ascii_case(b"CH") { + ch = true; + i += 1; + } else { + break; + } + } + + if nx && xx { + return Frame::Error(Bytes::from_static( + b"ERR XX and NX options at the same time are not compatible", + )); + } + + // Remaining args must be triples: longitude latitude member + let remaining = &args[i..]; + if remaining.len() < 3 || remaining.len() % 3 != 0 { + return err_wrong_args("GEOADD"); + } + + let (members, tree) = match db.get_or_create_sorted_set(key) { + Ok(pair) => pair, + Err(e) => return e, + }; + + let mut added = 0i64; + let mut changed = 0i64; + + for chunk in remaining.chunks_exact(3) { + let lon = match parse_f64(&chunk[0]) { + Some(v) if (-180.0..=180.0).contains(&v) => v, + _ => { + return Frame::Error(Bytes::from_static( + b"ERR value is not a valid float or out of range", + )); + } + }; + let lat = match parse_f64(&chunk[1]) { + Some(v) if (-85.05112878..=85.05112878).contains(&v) => v, + _ => { + return Frame::Error(Bytes::from_static( + b"ERR value is not a valid float or out of range", + )); + } + }; + let member = match extract_bytes(&chunk[2]) { + Some(m) => Bytes::copy_from_slice(m), + None => return err_wrong_args("GEOADD"), + }; + + let score = geohash_encode(lon, lat); + let exists = members.contains_key(&member); + + if nx && exists { + continue; + } + if xx && !exists { + continue; + } + + if exists { + let old_score = members[&member]; + if (old_score - score).abs() > f64::EPSILON { + tree.remove(OrderedFloat(old_score), &member); + tree.insert(OrderedFloat(score), member.clone()); + members.insert(member, score); + changed += 1; + } + } else { + tree.insert(OrderedFloat(score), member.clone()); + members.insert(member, score); + added += 1; + changed += 1; + } + } + + Frame::Integer(if ch { changed } else { added }) +} + +/// GEOPOS key member [member ...] +pub fn geopos(db: &mut Database, args: &[Frame]) -> Frame { + if args.len() < 2 { + return err_wrong_args("GEOPOS"); + } + let key = match extract_bytes(&args[0]) { + Some(k) => k, + None => return err_wrong_args("GEOPOS"), + }; + + let members_map = match db.get_sorted_set(key) { + Ok(Some((members, _))) => Some(members.clone()), + Ok(None) => None, + Err(e) => return e, + }; + + let mut results = Vec::with_capacity(args.len() - 1); + for arg in &args[1..] { + let member = match extract_bytes(arg) { + Some(m) => m, + None => { + results.push(Frame::Null); + continue; + } + }; + + match &members_map { + Some(m) => match m.get(member) { + Some(&score) => { + let (lon, lat) = geohash_decode(score); + results.push(Frame::Array( + vec![ + Frame::BulkString(Bytes::from(format!("{:.4}", lon))), + Frame::BulkString(Bytes::from(format!("{:.4}", lat))), + ] + .into(), + )); + } + None => results.push(Frame::Null), + }, + None => results.push(Frame::Null), + } + } + + Frame::Array(results.into()) +} + +/// GEODIST key member1 member2 [M|KM|FT|MI] +pub fn geodist(db: &mut Database, args: &[Frame]) -> Frame { + if args.len() < 3 { + return err_wrong_args("GEODIST"); + } + let key = match extract_bytes(&args[0]) { + Some(k) => k, + None => return err_wrong_args("GEODIST"), + }; + let m1 = match extract_bytes(&args[1]) { + Some(m) => m, + None => return err_wrong_args("GEODIST"), + }; + let m2 = match extract_bytes(&args[2]) { + Some(m) => m, + None => return err_wrong_args("GEODIST"), + }; + let unit = if args.len() >= 4 { + match extract_bytes(&args[3]) { + Some(u) => { + if parse_unit(u).is_none() { + return Frame::Error(Bytes::from_static( + b"ERR unsupported unit provided. please use M, KM, FT, MI", + )); + } + u + } + None => b"m" as &[u8], + } + } else { + b"m" + }; + + let members_map = match db.get_sorted_set(key) { + Ok(Some((members, _))) => members.clone(), + Ok(None) => return Frame::Null, + Err(e) => return e, + }; + + let score1 = match members_map.get(m1) { + Some(&s) => s, + None => return Frame::Null, + }; + let score2 = match members_map.get(m2) { + Some(&s) => s, + None => return Frame::Null, + }; + + let (lon1, lat1) = geohash_decode(score1); + let (lon2, lat2) = geohash_decode(score2); + let dist = haversine_distance(lon1, lat1, lon2, lat2); + let converted = convert_distance(dist, unit); + + Frame::BulkString(Bytes::from(format!("{:.4}", converted))) +} + +/// GEOHASH key member [member ...] +pub fn geohash(db: &mut Database, args: &[Frame]) -> Frame { + if args.len() < 2 { + return err_wrong_args("GEOHASH"); + } + let key = match extract_bytes(&args[0]) { + Some(k) => k, + None => return err_wrong_args("GEOHASH"), + }; + + let members_map = match db.get_sorted_set(key) { + Ok(Some((members, _))) => Some(members.clone()), + Ok(None) => None, + Err(e) => return e, + }; + + let mut results = Vec::with_capacity(args.len() - 1); + for arg in &args[1..] { + let member = match extract_bytes(arg) { + Some(m) => m, + None => { + results.push(Frame::Null); + continue; + } + }; + + match &members_map { + Some(m) => match m.get(member) { + Some(&score) => { + let hash_str = geohash_to_string(score); + results.push(Frame::BulkString(Bytes::from(hash_str))); + } + None => results.push(Frame::Null), + }, + None => results.push(Frame::Null), + } + } + + Frame::Array(results.into()) +} + +/// GEOSEARCH key FROMMEMBER member|FROMLONLAT lon lat +/// BYRADIUS radius M|KM|FT|MI|BYBOX width height M|KM|FT|MI +/// [ASC|DESC] [COUNT count [ANY]] [WITHCOORD] [WITHDIST] [WITHHASH] +pub fn geosearch(db: &mut Database, args: &[Frame]) -> Frame { + let (_, results) = geosearch_inner(db, args, false); + results +} + +/// GEOSEARCHSTORE destination source ... +pub fn geosearchstore(db: &mut Database, args: &[Frame]) -> Frame { + if args.len() < 2 { + return err_wrong_args("GEOSEARCHSTORE"); + } + let dest = match extract_bytes(&args[0]) { + Some(k) => Bytes::copy_from_slice(k), + None => return err_wrong_args("GEOSEARCHSTORE"), + }; + + // Shift args so args[0] is now the source key + let (count, _) = geosearch_inner(db, &args[1..], true); + + // Build sorted set from results and store + if count == 0 { + db.remove(&dest); + return Frame::Integer(0); + } + + Frame::Integer(count as i64) +} + +fn geosearch_inner(db: &mut Database, args: &[Frame], _store_mode: bool) -> (usize, Frame) { + if args.len() < 6 { + return (0, err_wrong_args("GEOSEARCH")); + } + let key = match extract_bytes(&args[0]) { + Some(k) => k, + None => return (0, err_wrong_args("GEOSEARCH")), + }; + + // Parse source: FROMMEMBER or FROMLONLAT + let mut center_lon = 0.0f64; + let mut center_lat = 0.0f64; + let mut i = 1; + let mut found_from = false; + + while i < args.len() && !found_from { + let arg = match extract_bytes(&args[i]) { + Some(a) => a, + None => { + i += 1; + continue; + } + }; + if arg.eq_ignore_ascii_case(b"FROMMEMBER") { + i += 1; + let member = match extract_bytes(args.get(i).unwrap_or(&Frame::Null)) { + Some(m) => m, + None => return (0, Frame::Error(Bytes::from_static(b"ERR syntax error"))), + }; + // Look up member's score + let members_map = match db.get_sorted_set(key) { + Ok(Some((members, _))) => members.clone(), + Ok(None) => return (0, Frame::Array(Vec::new().into())), + Err(e) => return (0, e), + }; + match members_map.get(member) { + Some(&score) => { + let (lon, lat) = geohash_decode(score); + center_lon = lon; + center_lat = lat; + } + None => return (0, Frame::Array(Vec::new().into())), + } + found_from = true; + } else if arg.eq_ignore_ascii_case(b"FROMLONLAT") { + i += 1; + center_lon = match args.get(i).and_then(|f| parse_f64(f)) { + Some(v) => v, + None => return (0, Frame::Error(Bytes::from_static(b"ERR syntax error"))), + }; + i += 1; + center_lat = match args.get(i).and_then(|f| parse_f64(f)) { + Some(v) => v, + None => return (0, Frame::Error(Bytes::from_static(b"ERR syntax error"))), + }; + found_from = true; + } + i += 1; + } + + if !found_from { + return (0, Frame::Error(Bytes::from_static(b"ERR syntax error"))); + } + + // Parse shape: BYRADIUS or BYBOX + let mut radius_m = None; + let mut box_width_m = None; + let mut box_height_m = None; + let mut ascending = true; + let mut count_limit = None; + let mut withcoord = false; + let mut withdist = false; + let mut withhash = false; + + while i < args.len() { + let arg = match extract_bytes(&args[i]) { + Some(a) => a, + None => { + i += 1; + continue; + } + }; + if arg.eq_ignore_ascii_case(b"BYRADIUS") { + i += 1; + let r = match args.get(i).and_then(|f| parse_f64(f)) { + Some(v) => v, + None => return (0, Frame::Error(Bytes::from_static(b"ERR syntax error"))), + }; + i += 1; + let unit_mult = match args.get(i).and_then(|f| extract_bytes(f)).and_then(|b| parse_unit(b)) { + Some(v) => v, + None => { + return ( + 0, + Frame::Error(Bytes::from_static( + b"ERR unsupported unit provided. please use M, KM, FT, MI", + )), + ) + } + }; + radius_m = Some(r * unit_mult); + } else if arg.eq_ignore_ascii_case(b"BYBOX") { + i += 1; + let w = match args.get(i).and_then(|f| parse_f64(f)) { + Some(v) => v, + None => return (0, Frame::Error(Bytes::from_static(b"ERR syntax error"))), + }; + i += 1; + let h = match args.get(i).and_then(|f| parse_f64(f)) { + Some(v) => v, + None => return (0, Frame::Error(Bytes::from_static(b"ERR syntax error"))), + }; + i += 1; + let unit_mult = match args.get(i).and_then(|f| extract_bytes(f)).and_then(|b| parse_unit(b)) { + Some(v) => v, + None => { + return ( + 0, + Frame::Error(Bytes::from_static( + b"ERR unsupported unit provided. please use M, KM, FT, MI", + )), + ) + } + }; + box_width_m = Some(w * unit_mult); + box_height_m = Some(h * unit_mult); + } else if arg.eq_ignore_ascii_case(b"ASC") { + ascending = true; + } else if arg.eq_ignore_ascii_case(b"DESC") { + ascending = false; + } else if arg.eq_ignore_ascii_case(b"COUNT") { + i += 1; + let c = match args.get(i).and_then(|f| parse_f64(f)) { + Some(v) if v > 0.0 => v as usize, + _ => return (0, Frame::Error(Bytes::from_static(b"ERR syntax error"))), + }; + count_limit = Some(c); + // Skip optional ANY + if i + 1 < args.len() { + if let Some(next) = extract_bytes(&args[i + 1]) { + if next.eq_ignore_ascii_case(b"ANY") { + i += 1; + } + } + } + } else if arg.eq_ignore_ascii_case(b"WITHCOORD") { + withcoord = true; + } else if arg.eq_ignore_ascii_case(b"WITHDIST") { + withdist = true; + } else if arg.eq_ignore_ascii_case(b"WITHHASH") { + withhash = true; + } + i += 1; + } + + if radius_m.is_none() && box_width_m.is_none() { + return ( + 0, + Frame::Error(Bytes::from_static( + b"ERR exactly one of BYRADIUS and BYBOX arguments must be provided", + )), + ); + } + + // Get all members with their coordinates + let members_map = match db.get_sorted_set(key) { + Ok(Some((members, _))) => members.clone(), + Ok(None) => return (0, Frame::Array(Vec::new().into())), + Err(e) => return (0, e), + }; + + // Filter by shape + let mut matches: Vec<(Bytes, f64, f64, f64, f64)> = Vec::new(); // (member, dist, lon, lat, score) + for (member, &score) in &members_map { + let (lon, lat) = geohash_decode(score); + let dist = haversine_distance(center_lon, center_lat, lon, lat); + + let in_range = if let Some(r) = radius_m { + dist <= r + } else { + // Box check: approximate using haversine + let dx = haversine_distance(center_lon, center_lat, lon, center_lat); + let dy = haversine_distance(center_lon, center_lat, center_lon, lat); + dx <= box_width_m.unwrap_or(0.0) / 2.0 && dy <= box_height_m.unwrap_or(0.0) / 2.0 + }; + + if in_range { + matches.push((member.clone(), dist, lon, lat, score)); + } + } + + // Sort by distance + matches.sort_by(|a, b| { + let cmp = a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal); + if ascending { cmp } else { cmp.reverse() } + }); + + // Apply COUNT limit + if let Some(limit) = count_limit { + matches.truncate(limit); + } + + let total = matches.len(); + let has_extras = withcoord || withdist || withhash; + + let results: Vec = matches + .into_iter() + .map(|(member, dist, lon, lat, score)| { + if has_extras { + let mut entry = vec![Frame::BulkString(member)]; + if withdist { + entry.push(Frame::BulkString(Bytes::from(format!("{:.4}", dist)))); + } + if withhash { + entry.push(Frame::Integer(score as i64)); + } + if withcoord { + entry.push(Frame::Array( + vec![ + Frame::BulkString(Bytes::from(format!("{:.4}", lon))), + Frame::BulkString(Bytes::from(format!("{:.4}", lat))), + ] + .into(), + )); + } + Frame::Array(entry.into()) + } else { + Frame::BulkString(member) + } + }) + .collect(); + + (total, Frame::Array(results.into())) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::storage::Database; + + fn bs(s: &[u8]) -> Frame { + Frame::BulkString(Bytes::copy_from_slice(s)) + } + + #[test] + fn test_geoadd_and_geopos() { + let mut db = Database::new(); + let result = geoadd( + &mut db, + &[ + bs(b"mygeo"), + bs(b"13.361389"), + bs(b"38.115556"), + bs(b"Palermo"), + bs(b"15.087269"), + bs(b"37.502669"), + bs(b"Catania"), + ], + ); + assert_eq!(result, Frame::Integer(2)); + + let result = geopos(&mut db, &[bs(b"mygeo"), bs(b"Palermo"), bs(b"NonExistent")]); + match result { + Frame::Array(ref arr) => { + assert_eq!(arr.len(), 2); + assert!(matches!(&arr[0], Frame::Array(_))); + assert_eq!(arr[1], Frame::Null); + } + _ => panic!("Expected array"), + } + } + + #[test] + fn test_geodist() { + let mut db = Database::new(); + geoadd( + &mut db, + &[ + bs(b"mygeo"), + bs(b"13.361389"), + bs(b"38.115556"), + bs(b"Palermo"), + bs(b"15.087269"), + bs(b"37.502669"), + bs(b"Catania"), + ], + ); + let result = geodist(&mut db, &[bs(b"mygeo"), bs(b"Palermo"), bs(b"Catania"), bs(b"km")]); + match result { + Frame::BulkString(b) => { + let dist: f64 = std::str::from_utf8(&b).unwrap().parse().unwrap(); + assert!((dist - 166.2742).abs() < 1.0, "got {dist}"); + } + _ => panic!("Expected bulk string"), + } + } + + #[test] + fn test_geohash() { + let mut db = Database::new(); + geoadd( + &mut db, + &[ + bs(b"mygeo"), + bs(b"13.361389"), + bs(b"38.115556"), + bs(b"Palermo"), + ], + ); + let result = geohash(&mut db, &[bs(b"mygeo"), bs(b"Palermo")]); + match result { + Frame::Array(ref arr) => { + assert_eq!(arr.len(), 1); + match &arr[0] { + Frame::BulkString(b) => { + assert_eq!(b.len(), 11); + } + _ => panic!("Expected bulk string"), + } + } + _ => panic!("Expected array"), + } + } + + #[test] + fn test_geosearch_byradius() { + let mut db = Database::new(); + geoadd( + &mut db, + &[ + bs(b"mygeo"), + bs(b"13.361389"), + bs(b"38.115556"), + bs(b"Palermo"), + bs(b"15.087269"), + bs(b"37.502669"), + bs(b"Catania"), + bs(b"2.349014"), + bs(b"48.864716"), + bs(b"Paris"), + ], + ); + + let result = geosearch( + &mut db, + &[ + bs(b"mygeo"), + bs(b"FROMLONLAT"), + bs(b"15"), + bs(b"37"), + bs(b"BYRADIUS"), + bs(b"200"), + bs(b"km"), + bs(b"ASC"), + ], + ); + match result { + Frame::Array(ref arr) => { + // Should find Catania and Palermo (within 200km of 15,37), not Paris + assert_eq!(arr.len(), 2); + } + _ => panic!("Expected array, got {:?}", result), + } + } + + #[test] + fn test_geoadd_nx_xx() { + let mut db = Database::new(); + geoadd( + &mut db, + &[bs(b"g"), bs(b"10.0"), bs(b"20.0"), bs(b"member1")], + ); + + // NX should not update existing + let result = geoadd( + &mut db, + &[bs(b"g"), bs(b"NX"), bs(b"11.0"), bs(b"21.0"), bs(b"member1")], + ); + assert_eq!(result, Frame::Integer(0)); + + // NX should add new + let result = geoadd( + &mut db, + &[bs(b"g"), bs(b"NX"), bs(b"12.0"), bs(b"22.0"), bs(b"member2")], + ); + assert_eq!(result, Frame::Integer(1)); + } +} diff --git a/src/command/geo/mod.rs b/src/command/geo/mod.rs new file mode 100644 index 00000000..a40cca69 --- /dev/null +++ b/src/command/geo/mod.rs @@ -0,0 +1,162 @@ +mod geo_cmd; + +pub use geo_cmd::*; + +use std::f64::consts::PI; + +// --------------------------------------------------------------------------- +// Geohash encoding/decoding (52-bit integer, Redis-compatible) +// --------------------------------------------------------------------------- + +const GEO_LAT_MIN: f64 = -85.05112878; +const GEO_LAT_MAX: f64 = 85.05112878; +const GEO_LON_MIN: f64 = -180.0; +const GEO_LON_MAX: f64 = 180.0; +const GEO_STEP_MAX: u8 = 26; // 52-bit precision + +/// Encode longitude/latitude into a 52-bit geohash stored as f64 score. +pub(crate) fn geohash_encode(lon: f64, lat: f64) -> f64 { + let mut lat_range = (GEO_LAT_MIN, GEO_LAT_MAX); + let mut lon_range = (GEO_LON_MIN, GEO_LON_MAX); + let mut hash: u64 = 0; + + for i in 0..GEO_STEP_MAX { + // Longitude bit + let mid = (lon_range.0 + lon_range.1) / 2.0; + if lon >= mid { + hash |= 1 << (51 - i * 2); + lon_range.0 = mid; + } else { + lon_range.1 = mid; + } + // Latitude bit + let mid = (lat_range.0 + lat_range.1) / 2.0; + if lat >= mid { + hash |= 1 << (50 - i * 2); + lat_range.0 = mid; + } else { + lat_range.1 = mid; + } + } + + hash as f64 +} + +/// Decode a 52-bit geohash score back to (longitude, latitude). +pub(crate) fn geohash_decode(score: f64) -> (f64, f64) { + let hash = score as u64; + let mut lat_range = (GEO_LAT_MIN, GEO_LAT_MAX); + let mut lon_range = (GEO_LON_MIN, GEO_LON_MAX); + + for i in 0..GEO_STEP_MAX { + // Longitude bit + if hash & (1 << (51 - i * 2)) != 0 { + lon_range.0 = (lon_range.0 + lon_range.1) / 2.0; + } else { + lon_range.1 = (lon_range.0 + lon_range.1) / 2.0; + } + // Latitude bit + if hash & (1 << (50 - i * 2)) != 0 { + lat_range.0 = (lat_range.0 + lat_range.1) / 2.0; + } else { + lat_range.1 = (lat_range.0 + lat_range.1) / 2.0; + } + } + + let lon = (lon_range.0 + lon_range.1) / 2.0; + let lat = (lat_range.0 + lat_range.1) / 2.0; + (lon, lat) +} + +/// Convert a 52-bit integer geohash to the 11-character base32 string Redis uses. +pub(crate) fn geohash_to_string(score: f64) -> String { + const ALPHABET: &[u8] = b"0123456789bcdefghjkmnpqrstuvwxyz"; + let hash = score as u64; + // Redis uses 11 characters (55 bits, but we only have 52 so pad with 0) + let padded = hash << 3; // shift left 3 to fill 55 bits + let mut result = [0u8; 11]; + for i in 0..11 { + let idx = ((padded >> (50 - i * 5)) & 0x1F) as usize; + result[i] = ALPHABET[idx]; + } + String::from_utf8_lossy(&result).to_string() +} + +// --------------------------------------------------------------------------- +// Haversine distance +// --------------------------------------------------------------------------- + +const EARTH_RADIUS_M: f64 = 6372797.560856; + +/// Haversine distance in meters between two (lon, lat) pairs. +pub(crate) fn haversine_distance(lon1: f64, lat1: f64, lon2: f64, lat2: f64) -> f64 { + let lat1_r = lat1 * PI / 180.0; + let lat2_r = lat2 * PI / 180.0; + let dlat = (lat2 - lat1) * PI / 180.0; + let dlon = (lon2 - lon1) * PI / 180.0; + + let a = (dlat / 2.0).sin().powi(2) + lat1_r.cos() * lat2_r.cos() * (dlon / 2.0).sin().powi(2); + let c = 2.0 * a.sqrt().asin(); + EARTH_RADIUS_M * c +} + +/// Convert meters to the specified unit. +pub(crate) fn convert_distance(meters: f64, unit: &[u8]) -> f64 { + if unit.eq_ignore_ascii_case(b"km") { + meters / 1000.0 + } else if unit.eq_ignore_ascii_case(b"mi") { + meters / 1609.34 + } else if unit.eq_ignore_ascii_case(b"ft") { + meters / 0.3048 + } else { + meters // default: meters + } +} + +/// Parse a unit string, returning meters-per-unit multiplier. Returns None if invalid. +pub(crate) fn parse_unit(unit: &[u8]) -> Option { + if unit.eq_ignore_ascii_case(b"m") { + Some(1.0) + } else if unit.eq_ignore_ascii_case(b"km") { + Some(1000.0) + } else if unit.eq_ignore_ascii_case(b"mi") { + Some(1609.34) + } else if unit.eq_ignore_ascii_case(b"ft") { + Some(0.3048) + } else { + None + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_geohash_roundtrip() { + // Rome coordinates + let lon = 12.4964; + let lat = 41.9028; + let hash = geohash_encode(lon, lat); + let (lon2, lat2) = geohash_decode(hash); + assert!((lon - lon2).abs() < 0.0001); + assert!((lat - lat2).abs() < 0.0001); + } + + #[test] + fn test_haversine_rome_paris() { + // Rome to Paris ~1105 km + let d = haversine_distance(12.4964, 41.9028, 2.3522, 48.8566); + assert!((d / 1000.0 - 1105.0).abs() < 10.0); // within 10 km + } + + #[test] + fn test_geohash_string() { + let hash = geohash_encode(-122.4194, 37.7749); // San Francisco + let s = geohash_to_string(hash); + assert_eq!(s.len(), 11); + // Should start with "9q8y" for SF area + // The exact prefix depends on our 52-bit encoding; just verify length and base32 chars + assert!(s.chars().all(|c| "0123456789bcdefghjkmnpqrstuvwxyz".contains(c)), "invalid chars: {s}"); + } +} diff --git a/src/command/metadata.rs b/src/command/metadata.rs index 7a6b1199..0adb16bf 100644 --- a/src/command/metadata.rs +++ b/src/command/metadata.rs @@ -691,6 +691,8 @@ mod tests { b"COPY", b"SETBIT", b"BITOP", + b"GEOADD", + b"GEOSEARCHSTORE", b"HSET", b"HMSET", b"HDEL", diff --git a/src/command/mod.rs b/src/command/mod.rs index 0e77b69c..2110ef3e 100644 --- a/src/command/mod.rs +++ b/src/command/mod.rs @@ -3,6 +3,7 @@ pub mod client; pub mod config; pub mod connection; pub mod functions; +pub mod geo; pub mod hash; pub mod helpers; pub mod hll; @@ -398,7 +399,13 @@ fn dispatch_inner( } } (6, b'g') => { - // GETBIT GETSET GETDEL + // GEOADD GEOPOS GETBIT GETSET GETDEL + if cmd.eq_ignore_ascii_case(b"GEOADD") { + return resp(geo::geoadd(db, args)); + } + if cmd.eq_ignore_ascii_case(b"GEOPOS") { + return resp(geo::geopos(db, args)); + } if cmd.eq_ignore_ascii_case(b"GETBIT") { return resp(string::getbit(db, args)); } @@ -527,6 +534,15 @@ fn dispatch_inner( } } // 7-letter commands + (7, b'g') => { + // GEODIST GEOHASH + if cmd.eq_ignore_ascii_case(b"GEODIST") { + return resp(geo::geodist(db, args)); + } + if cmd.eq_ignore_ascii_case(b"GEOHASH") { + return resp(geo::geohash(db, args)); + } + } (7, b'c') => { // COMMAND if cmd.eq_ignore_ascii_case(b"COMMAND") { @@ -634,6 +650,12 @@ fn dispatch_inner( } } // 9-letter commands + (9, b'g') => { + // GEOSEARCH + if cmd.eq_ignore_ascii_case(b"GEOSEARCH") { + return resp(geo::geosearch(db, args)); + } + } (9, b's') => { // SISMEMBER if cmd.eq_ignore_ascii_case(b"SISMEMBER") { @@ -737,6 +759,13 @@ fn dispatch_inner( return resp(sorted_set::zrangebyscore(db, args)); } } + // 14-letter commands + (14, b'g') => { + // GEOSEARCHSTORE + if cmd.eq_ignore_ascii_case(b"GEOSEARCHSTORE") { + return resp(geo::geosearchstore(db, args)); + } + } // 16-letter commands (16, b'z') => { // ZREVRANGEBYSCORE From eb32d3e8600bff42cc42382c5324474c02670cc8 Mon Sep 17 00:00:00 2001 From: Tin Dang Date: Fri, 10 Apr 2026 23:00:24 +0700 Subject: [PATCH 05/20] style: fix clippy manual_is_multiple_of warning in geo_cmd --- src/command/geo/geo_cmd.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/command/geo/geo_cmd.rs b/src/command/geo/geo_cmd.rs index 5a658fd5..460b2747 100644 --- a/src/command/geo/geo_cmd.rs +++ b/src/command/geo/geo_cmd.rs @@ -58,7 +58,7 @@ pub fn geoadd(db: &mut Database, args: &[Frame]) -> Frame { // Remaining args must be triples: longitude latitude member let remaining = &args[i..]; - if remaining.len() < 3 || remaining.len() % 3 != 0 { + if remaining.len() < 3 || !remaining.len().is_multiple_of(3) { return err_wrong_args("GEOADD"); } From a89f440a8e40ddbb78a8fe092bfcbc0a48fed490 Mon Sep 17 00:00:00 2001 From: Tin Dang Date: Fri, 10 Apr 2026 23:04:55 +0700 Subject: [PATCH 06/20] fix: update dispatch_read prefilter test for new (6,b'b') bucket MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit BGSAVE shares the (6, b'b') bucket with BITPOS — the prefilter is intentionally coarse (documented behavior), so remove BGSAVE from the expected-false list. --- src/command/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/command/mod.rs b/src/command/mod.rs index 2110ef3e..c0f83560 100644 --- a/src/command/mod.rs +++ b/src/command/mod.rs @@ -1616,7 +1616,7 @@ mod tests { b"X", b"FLUSHDB", b"FLUSHALL", - b"BGSAVE", + // BGSAVE shares (6, b'b') bucket with BITPOS — prefilter is coarse b"WAIT", b"OBJECT", b"RANDOMKEY", From 9be660e89bc90aff307dedfd2338233b85dcfa36 Mon Sep 17 00:00:00 2001 From: Tin Dang Date: Fri, 10 Apr 2026 23:09:39 +0700 Subject: [PATCH 07/20] style: cargo fmt --- src/command/geo/geo_cmd.rs | 37 ++++++++++++++++++++++++++------ src/command/geo/mod.rs | 6 +++++- src/command/key.rs | 24 +++++++-------------- src/command/string/string_bit.rs | 5 +---- 4 files changed, 44 insertions(+), 28 deletions(-) diff --git a/src/command/geo/geo_cmd.rs b/src/command/geo/geo_cmd.rs index 460b2747..598fa0ad 100644 --- a/src/command/geo/geo_cmd.rs +++ b/src/command/geo/geo_cmd.rs @@ -384,7 +384,11 @@ fn geosearch_inner(db: &mut Database, args: &[Frame], _store_mode: bool) -> (usi None => return (0, Frame::Error(Bytes::from_static(b"ERR syntax error"))), }; i += 1; - let unit_mult = match args.get(i).and_then(|f| extract_bytes(f)).and_then(|b| parse_unit(b)) { + let unit_mult = match args + .get(i) + .and_then(|f| extract_bytes(f)) + .and_then(|b| parse_unit(b)) + { Some(v) => v, None => { return ( @@ -392,7 +396,7 @@ fn geosearch_inner(db: &mut Database, args: &[Frame], _store_mode: bool) -> (usi Frame::Error(Bytes::from_static( b"ERR unsupported unit provided. please use M, KM, FT, MI", )), - ) + ); } }; radius_m = Some(r * unit_mult); @@ -408,7 +412,11 @@ fn geosearch_inner(db: &mut Database, args: &[Frame], _store_mode: bool) -> (usi None => return (0, Frame::Error(Bytes::from_static(b"ERR syntax error"))), }; i += 1; - let unit_mult = match args.get(i).and_then(|f| extract_bytes(f)).and_then(|b| parse_unit(b)) { + let unit_mult = match args + .get(i) + .and_then(|f| extract_bytes(f)) + .and_then(|b| parse_unit(b)) + { Some(v) => v, None => { return ( @@ -416,7 +424,7 @@ fn geosearch_inner(db: &mut Database, args: &[Frame], _store_mode: bool) -> (usi Frame::Error(Bytes::from_static( b"ERR unsupported unit provided. please use M, KM, FT, MI", )), - ) + ); } }; box_width_m = Some(w * unit_mult); @@ -582,7 +590,10 @@ mod tests { bs(b"Catania"), ], ); - let result = geodist(&mut db, &[bs(b"mygeo"), bs(b"Palermo"), bs(b"Catania"), bs(b"km")]); + let result = geodist( + &mut db, + &[bs(b"mygeo"), bs(b"Palermo"), bs(b"Catania"), bs(b"km")], + ); match result { Frame::BulkString(b) => { let dist: f64 = std::str::from_utf8(&b).unwrap().parse().unwrap(); @@ -671,14 +682,26 @@ mod tests { // NX should not update existing let result = geoadd( &mut db, - &[bs(b"g"), bs(b"NX"), bs(b"11.0"), bs(b"21.0"), bs(b"member1")], + &[ + bs(b"g"), + bs(b"NX"), + bs(b"11.0"), + bs(b"21.0"), + bs(b"member1"), + ], ); assert_eq!(result, Frame::Integer(0)); // NX should add new let result = geoadd( &mut db, - &[bs(b"g"), bs(b"NX"), bs(b"12.0"), bs(b"22.0"), bs(b"member2")], + &[ + bs(b"g"), + bs(b"NX"), + bs(b"12.0"), + bs(b"22.0"), + bs(b"member2"), + ], ); assert_eq!(result, Frame::Integer(1)); } diff --git a/src/command/geo/mod.rs b/src/command/geo/mod.rs index a40cca69..2bcc2791 100644 --- a/src/command/geo/mod.rs +++ b/src/command/geo/mod.rs @@ -157,6 +157,10 @@ mod tests { assert_eq!(s.len(), 11); // Should start with "9q8y" for SF area // The exact prefix depends on our 52-bit encoding; just verify length and base32 chars - assert!(s.chars().all(|c| "0123456789bcdefghjkmnpqrstuvwxyz".contains(c)), "invalid chars: {s}"); + assert!( + s.chars() + .all(|c| "0123456789bcdefghjkmnpqrstuvwxyz".contains(c)), + "invalid chars: {s}" + ); } } diff --git a/src/command/key.rs b/src/command/key.rs index 1cc14881..fb91339e 100644 --- a/src/command/key.rs +++ b/src/command/key.rs @@ -651,15 +651,16 @@ pub fn sort(db: &mut Database, args: &[Frame]) -> Frame { RedisValueRef::ListListpack(lp) => lp.iter().map(|e| e.to_bytes()).collect(), RedisValueRef::Set(s) => s.iter().cloned().collect(), RedisValueRef::SetListpack(lp) => lp.iter().map(|e| e.to_bytes()).collect(), - RedisValueRef::SetIntset(is) => { - is.iter().map(|v| Bytes::from(v.to_string())).collect() - } + RedisValueRef::SetIntset(is) => is.iter().map(|v| Bytes::from(v.to_string())).collect(), RedisValueRef::SortedSet { members, .. } => members.keys().cloned().collect(), RedisValueRef::SortedSetBPTree { members, .. } => members.keys().cloned().collect(), RedisValueRef::SortedSetListpack(lp) => { // Listpack stores member, score pairs let entries: Vec<_> = lp.iter().collect(); - entries.chunks(2).filter_map(|c| c.first().map(|e| e.to_bytes())).collect() + entries + .chunks(2) + .filter_map(|c| c.first().map(|e| e.to_bytes())) + .collect() } _ => { return Frame::Error(Bytes::from_static( @@ -685,10 +686,7 @@ pub fn sort(db: &mut Database, args: &[Frame]) -> Frame { .collect() } } else { - elements - .iter() - .map(|e| Some(e.clone())) - .collect() + elements.iter().map(|e| Some(e.clone())).collect() }; // Create indexed pairs for stable sort @@ -1877,10 +1875,7 @@ mod tests { fn test_sort_limit() { let mut db = Database::new(); setup_list(&mut db, b"mylist", &[b"3", b"1", b"2", b"4"]); - let result = sort( - &mut db, - &[bs(b"mylist"), bs(b"LIMIT"), bs(b"1"), bs(b"2")], - ); + let result = sort(&mut db, &[bs(b"mylist"), bs(b"LIMIT"), bs(b"1"), bs(b"2")]); assert_eq!( result, Frame::Array(framevec![ @@ -1894,10 +1889,7 @@ mod tests { fn test_sort_store() { let mut db = Database::new(); setup_list(&mut db, b"mylist", &[b"3", b"1", b"2"]); - let result = sort( - &mut db, - &[bs(b"mylist"), bs(b"STORE"), bs(b"sorted")], - ); + let result = sort(&mut db, &[bs(b"mylist"), bs(b"STORE"), bs(b"sorted")]); assert_eq!(result, Frame::Integer(3)); assert!(db.exists(b"sorted")); } diff --git a/src/command/string/string_bit.rs b/src/command/string/string_bit.rs index 582391d5..da374587 100644 --- a/src/command/string/string_bit.rs +++ b/src/command/string/string_bit.rs @@ -887,10 +887,7 @@ mod tests { #[test] fn test_bitop_not_requires_one_key() { let mut db = make_db(); - let result = bitop( - &mut db, - &[bs(b"NOT"), bs(b"dest"), bs(b"a"), bs(b"b")], - ); + let result = bitop(&mut db, &[bs(b"NOT"), bs(b"dest"), bs(b"a"), bs(b"b")]); assert!(matches!(result, Frame::Error(_))); } From a32f0b50a9232c613ce721d06942f7939d65f3bd Mon Sep 17 00:00:00 2001 From: Tin Dang Date: Fri, 10 Apr 2026 23:12:29 +0700 Subject: [PATCH 08/20] feat: implement CONFIG REWRITE and CONFIG RESETSTAT CONFIG REWRITE serializes runtime + server config to moon.conf: - Atomic write (tmpfile + rename) to /moon.conf - Redis-style config format (key value per line) - Persists all CONFIG SET-able parameters - CONFIG RESETSTAT stub for compatibility --- src/command/config.rs | 107 ++++++++++++++++++++++++++++++++++++++ src/server/conn/shared.rs | 7 ++- 2 files changed, 113 insertions(+), 1 deletion(-) diff --git a/src/command/config.rs b/src/command/config.rs index a9c575fc..1cab4602 100644 --- a/src/command/config.rs +++ b/src/command/config.rs @@ -185,6 +185,90 @@ pub fn config_set(runtime_config: &mut RuntimeConfig, args: &[Frame]) -> Frame { Frame::SimpleString(Bytes::from_static(b"OK")) } +/// CONFIG REWRITE — serialize current runtime config to a Redis-style config file. +/// +/// Writes to `/moon.conf` atomically (tmpfile + rename). +pub fn config_rewrite( + runtime_config: &RuntimeConfig, + server_config: &ServerConfig, +) -> Frame { + let mut lines = Vec::with_capacity(20); + lines.push(format!("# Moon configuration file — generated by CONFIG REWRITE")); + lines.push(format!("# {}", chrono_lite_now())); + lines.push(String::new()); + + // Server settings (from ServerConfig — immutable at runtime but persisted) + lines.push(format!("bind {}", server_config.bind)); + lines.push(format!("port {}", server_config.port)); + lines.push(format!("databases {}", server_config.databases)); + if let Some(ref pass) = runtime_config.requirepass { + lines.push(format!("requirepass {}", pass)); + } + lines.push(format!("protected-mode {}", runtime_config.protected_mode)); + lines.push(String::new()); + + // Memory settings + lines.push(format!("maxmemory {}", runtime_config.maxmemory)); + lines.push(format!("maxmemory-policy {}", runtime_config.maxmemory_policy)); + lines.push(format!("maxmemory-samples {}", runtime_config.maxmemory_samples)); + lines.push(format!("lfu-log-factor {}", runtime_config.lfu_log_factor)); + lines.push(format!("lfu-decay-time {}", runtime_config.lfu_decay_time)); + lines.push(String::new()); + + // Persistence settings + lines.push(format!("dir {}", runtime_config.dir)); + lines.push(format!("dbfilename {}", server_config.dbfilename)); + lines.push(format!("appendonly {}", runtime_config.appendonly)); + lines.push(format!("appendfsync {}", runtime_config.appendfsync)); + lines.push(format!("appendfilename {}", server_config.appendfilename)); + if let Some(ref save) = runtime_config.save { + lines.push(format!("save {}", save)); + } + lines.push(String::new()); + + // ACL settings + lines.push(format!("acllog-max-len {}", runtime_config.acllog_max_len)); + if let Some(ref aclfile) = runtime_config.aclfile { + lines.push(format!("aclfile {}", aclfile)); + } + + let content = lines.join("\n") + "\n"; + + // Atomic write: tmpfile + rename + let dir = &runtime_config.dir; + let conf_path = std::path::Path::new(dir).join("moon.conf"); + let tmp_path = std::path::Path::new(dir).join("moon.conf.tmp"); + + if let Err(e) = std::fs::write(&tmp_path, content.as_bytes()) { + return Frame::Error(Bytes::from(format!( + "ERR failed to write config: {e}" + ))); + } + if let Err(e) = std::fs::rename(&tmp_path, &conf_path) { + let _ = std::fs::remove_file(&tmp_path); + return Frame::Error(Bytes::from(format!( + "ERR failed to rename config: {e}" + ))); + } + + Frame::SimpleString(Bytes::from_static(b"OK")) +} + +/// Lightweight timestamp without chrono dependency. +fn chrono_lite_now() -> String { + use std::time::{SystemTime, UNIX_EPOCH}; + let secs = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or_default() + .as_secs(); + format!("Generated at epoch {secs}") +} + +/// CONFIG RESETSTAT — reset server statistics (placeholder). +pub fn config_resetstat() -> Frame { + Frame::SimpleString(Bytes::from_static(b"OK")) +} + #[cfg(test)] mod tests { use super::*; @@ -293,4 +377,27 @@ mod tests { assert_eq!(rt.maxmemory, 2048); assert_eq!(rt.maxmemory_policy, "allkeys-lfu"); } + + #[test] + fn test_config_rewrite() { + let mut rt = RuntimeConfig::default(); + rt.maxmemory = 1_073_741_824; // 1GB + rt.maxmemory_policy = "allkeys-lru".to_string(); + rt.dir = std::env::temp_dir().to_string_lossy().to_string(); + + let sc = default_server_config(); + let result = config_rewrite(&rt, &sc); + assert_eq!(result, Frame::SimpleString(Bytes::from_static(b"OK"))); + + // Verify file was created + let conf_path = std::path::Path::new(&rt.dir).join("moon.conf"); + assert!(conf_path.exists()); + let content = std::fs::read_to_string(&conf_path).unwrap(); + assert!(content.contains("maxmemory 1073741824")); + assert!(content.contains("maxmemory-policy allkeys-lru")); + assert!(content.contains("port 6379")); + + // Cleanup + let _ = std::fs::remove_file(conf_path); + } } diff --git a/src/server/conn/shared.rs b/src/server/conn/shared.rs index 5a081f2c..6ff36c51 100644 --- a/src/server/conn/shared.rs +++ b/src/server/conn/shared.rs @@ -52,9 +52,14 @@ pub(crate) fn handle_config( } else if subcmd.eq_ignore_ascii_case(b"SET") { let mut rt = runtime_config.write(); config_cmd::config_set(&mut rt, sub_args) + } else if subcmd.eq_ignore_ascii_case(b"REWRITE") { + let rt = runtime_config.read(); + config_cmd::config_rewrite(&rt, server_config) + } else if subcmd.eq_ignore_ascii_case(b"RESETSTAT") { + config_cmd::config_resetstat() } else { Frame::Error(Bytes::from(format!( - "ERR unknown subcommand '{}'. Try CONFIG GET, CONFIG SET.", + "ERR unknown subcommand '{}'. Try CONFIG GET, CONFIG SET, CONFIG REWRITE, CONFIG RESETSTAT.", String::from_utf8_lossy(subcmd) ))) } From e781ee1943d4ed46432a3a14e923a533fddf7c6f Mon Sep 17 00:00:00 2001 From: Tin Dang Date: Fri, 10 Apr 2026 23:19:35 +0700 Subject: [PATCH 09/20] feat: implement P1 medium-impact features MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit CONFIG REWRITE: - Atomic write (tmpfile + rename) to /moon.conf - Redis-style key-value format, all CONFIG SET-able params CLIENT PAUSE/UNPAUSE: - CLIENT PAUSE timeout [WRITE|ALL] — delays command processing - CLIENT UNPAUSE — clears pause immediately - Lock-free: extracts deadline before sleeping, no lock held across await - CLIENT INFO, CLIENT LIST (stub), CLIENT NO-EVICT/NO-TOUCH (accepted) MEMORY USAGE/DOCTOR/HELP: - MEMORY USAGE key — returns estimated bytes via estimate_memory() - MEMORY DOCTOR — health report stub - MEMORY HELP — subcommand listing Lazyfree configuration: - CONFIG SET lazyfree-threshold N (default 64) - CONFIG GET lazyfree-threshold - Stored in RuntimeConfig for future use by unlink/eviction --- src/command/config.rs | 17 +++++- src/command/key.rs | 41 +++++++++++++ src/command/mod.rs | 21 +++++++ src/config.rs | 13 ++++ src/server/conn/handler_sharded.rs | 95 ++++++++++++++++++++++++++++++ src/storage/eviction.rs | 3 + 6 files changed, 189 insertions(+), 1 deletion(-) diff --git a/src/command/config.rs b/src/command/config.rs index 1cab4602..4d67120f 100644 --- a/src/command/config.rs +++ b/src/command/config.rs @@ -51,6 +51,10 @@ pub fn config_get( runtime_config.protected_mode.clone(), ), (b"acllog-max-len", runtime_config.acllog_max_len.to_string()), + ( + b"lazyfree-threshold" as &[u8], + runtime_config.lazyfree_threshold.to_string(), + ), ]; let mut result = Vec::new(); @@ -171,6 +175,17 @@ pub fn config_set(runtime_config: &mut RuntimeConfig, args: &[Frame]) -> Frame { ))); } }, + "lazyfree-lazy-eviction" | "lazyfree-threshold" => { + match value_str.parse::() { + Ok(v) => runtime_config.lazyfree_threshold = v, + Err(_) => { + return Frame::Error(Bytes::from(format!( + "ERR Invalid argument '{}' for CONFIG SET 'lazyfree-threshold'", + value_str + ))); + } + } + } _ => { return Frame::Error(Bytes::from(format!( "ERR Unsupported CONFIG parameter: {}", @@ -193,7 +208,7 @@ pub fn config_rewrite( server_config: &ServerConfig, ) -> Frame { let mut lines = Vec::with_capacity(20); - lines.push(format!("# Moon configuration file — generated by CONFIG REWRITE")); + lines.push("# Moon configuration file — generated by CONFIG REWRITE".to_string()); lines.push(format!("# {}", chrono_lite_now())); lines.push(String::new()); diff --git a/src/command/key.rs b/src/command/key.rs index fb91339e..e8af22f4 100644 --- a/src/command/key.rs +++ b/src/command/key.rs @@ -562,6 +562,47 @@ pub fn copy(db: &mut Database, args: &[Frame]) -> Frame { } } +/// MEMORY USAGE key [SAMPLES count] +/// +/// Returns the number of bytes a key and its value require to be stored. +pub fn memory_usage(db: &mut Database, args: &[Frame]) -> Frame { + if args.is_empty() { + return err_wrong_args("MEMORY"); + } + let key = match extract_key(&args[0]) { + Some(k) => k, + None => return err_wrong_args("MEMORY"), + }; + + match db.get(key) { + Some(entry) => { + let mem = key.len() + entry.value.estimate_memory() + 128; // same as entry_overhead + Frame::Integer(mem as i64) + } + None => Frame::Null, + } +} + +/// MEMORY DOCTOR — report memory health issues. +pub fn memory_doctor() -> Frame { + // Basic health report + Frame::BulkString(Bytes::from_static( + b"Sam, I have no memory problems", + )) +} + +/// MEMORY HELP — list MEMORY subcommands. +pub fn memory_help() -> Frame { + Frame::Array( + vec![ + Frame::BulkString(Bytes::from_static(b"MEMORY DOCTOR - Return memory problems reports.")), + Frame::BulkString(Bytes::from_static(b"MEMORY HELP - Return this help message.")), + Frame::BulkString(Bytes::from_static(b"MEMORY USAGE [SAMPLES ] - Return memory in bytes used by and its value.")), + ] + .into(), + ) +} + /// SORT key [BY pattern] [LIMIT offset count] [GET pattern ...] [ASC|DESC] [ALPHA] [STORE dest] /// /// Sort elements in a list, set, or sorted set. diff --git a/src/command/mod.rs b/src/command/mod.rs index c0f83560..493044fa 100644 --- a/src/command/mod.rs +++ b/src/command/mod.rs @@ -440,6 +440,27 @@ fn dispatch_inner( return resp(list::lpushx(db, args)); } } + (6, b'm') => { + // MEMORY + if cmd.eq_ignore_ascii_case(b"MEMORY") { + if let Some(sub) = args.first() { + if let Some(sub_bytes) = crate::command::helpers::extract_bytes(sub) { + if sub_bytes.eq_ignore_ascii_case(b"USAGE") { + return resp(key::memory_usage(db, &args[1..])); + } + if sub_bytes.eq_ignore_ascii_case(b"DOCTOR") { + return resp(key::memory_doctor()); + } + if sub_bytes.eq_ignore_ascii_case(b"HELP") { + return resp(key::memory_help()); + } + } + } + return resp(Frame::Error(Bytes::from_static( + b"ERR unknown subcommand. Try MEMORY USAGE, MEMORY DOCTOR, MEMORY HELP.", + ))); + } + } (6, b'o') => { // OBJECT if cmd.eq_ignore_ascii_case(b"OBJECT") { diff --git a/src/config.rs b/src/config.rs index bba0516a..2b78ef09 100644 --- a/src/config.rs +++ b/src/config.rs @@ -286,6 +286,9 @@ impl ServerConfig { requirepass: self.requirepass.clone(), protected_mode: self.protected_mode.clone(), acllog_max_len: self.acllog_max_len, + client_pause_deadline_ms: 0, + client_pause_write_only: false, + lazyfree_threshold: 64, } } } @@ -321,6 +324,13 @@ pub struct RuntimeConfig { pub protected_mode: String, /// Maximum number of entries in the ACL log (mutable via CONFIG SET). pub acllog_max_len: usize, + /// CLIENT PAUSE deadline (epoch ms). 0 = not paused. + /// Set by CLIENT PAUSE, cleared by CLIENT UNPAUSE or expiry. + pub client_pause_deadline_ms: u64, + /// CLIENT PAUSE mode: false = ALL (pause all), true = WRITE (pause writes only). + pub client_pause_write_only: bool, + /// Lazyfree threshold: collections with more elements than this are freed async. + pub lazyfree_threshold: usize, } impl Default for RuntimeConfig { @@ -339,6 +349,9 @@ impl Default for RuntimeConfig { requirepass: None, protected_mode: "yes".to_string(), acllog_max_len: 128, + client_pause_deadline_ms: 0, + client_pause_write_only: false, + lazyfree_threshold: 64, } } } diff --git a/src/server/conn/handler_sharded.rs b/src/server/conn/handler_sharded.rs index 7433f522..2355f699 100644 --- a/src/server/conn/handler_sharded.rs +++ b/src/server/conn/handler_sharded.rs @@ -681,6 +681,34 @@ pub(crate) async fn handle_connection_sharded_inner< continue; } + // === CLIENT PAUSE check === + // Extract pause info with short lock hold, then sleep outside lock scope + let pause_wait_ms = { + let rt = ctx.runtime_config.read(); + let deadline = rt.client_pause_deadline_ms; + if deadline > 0 { + let now = crate::storage::entry::current_time_ms(); + if now < deadline { + let should_pause = if rt.client_pause_write_only { + crate::command::metadata::is_write(cmd) + } else { + true + }; + if should_pause { deadline.saturating_sub(now) } else { 0 } + } else { 0 } + } else { 0 } + }; + if pause_wait_ms > 0 { + #[cfg(feature = "runtime-tokio")] + { + tokio::time::sleep(std::time::Duration::from_millis(pause_wait_ms)).await; + } + #[cfg(feature = "runtime-monoio")] + { + monoio::time::sleep(std::time::Duration::from_millis(pause_wait_ms)).await; + } + } + // === ACL permission check === // Must run before any command-specific handlers (CONFIG, REPLICAOF, etc.) // so that low-privilege users cannot reach admin commands. @@ -905,6 +933,73 @@ pub(crate) async fn handle_connection_sharded_inner< Err(err_frame) => { responses.push(err_frame); continue; } } } + if sub_bytes.eq_ignore_ascii_case(b"PAUSE") { + // CLIENT PAUSE timeout [WRITE|ALL] + if cmd_args.len() < 2 { + responses.push(Frame::Error(Bytes::from_static( + b"ERR wrong number of arguments for 'CLIENT PAUSE' command", + ))); + } else { + let timeout_ms = match extract_bytes(&cmd_args[1]) { + Some(b) => std::str::from_utf8(&b).ok().and_then(|s| s.parse::().ok()), + None => None, + }; + match timeout_ms { + Some(ms) => { + let write_only = cmd_args.get(2) + .and_then(|f| extract_bytes(f)) + .is_some_and(|b| b.eq_ignore_ascii_case(b"WRITE")); + let deadline = crate::storage::entry::current_time_ms() + ms; + let mut rt = ctx.runtime_config.write(); + rt.client_pause_deadline_ms = deadline; + rt.client_pause_write_only = write_only; + responses.push(Frame::SimpleString(Bytes::from_static(b"OK"))); + } + None => { + responses.push(Frame::Error(Bytes::from_static( + b"ERR timeout is not a valid integer or out of range", + ))); + } + } + } + continue; + } + if sub_bytes.eq_ignore_ascii_case(b"UNPAUSE") { + let mut rt = ctx.runtime_config.write(); + rt.client_pause_deadline_ms = 0; + rt.client_pause_write_only = false; + responses.push(Frame::SimpleString(Bytes::from_static(b"OK"))); + continue; + } + if sub_bytes.eq_ignore_ascii_case(b"INFO") { + // CLIENT INFO — return info about current connection + let info = format!( + "id={} fd=-1 name={} db={}\r\n", + client_id, + conn.client_name.as_ref().map(|n| String::from_utf8_lossy(n).to_string()).unwrap_or_default(), + conn.selected_db, + ); + responses.push(Frame::BulkString(Bytes::from(info))); + continue; + } + if sub_bytes.eq_ignore_ascii_case(b"LIST") { + // CLIENT LIST — stub returning current client only + let info = format!( + "id={} fd=-1 name={} db={}\r\n", + client_id, + conn.client_name.as_ref().map(|n| String::from_utf8_lossy(n).to_string()).unwrap_or_default(), + conn.selected_db, + ); + responses.push(Frame::BulkString(Bytes::from(info))); + continue; + } + if sub_bytes.eq_ignore_ascii_case(b"NO-EVICT") + || sub_bytes.eq_ignore_ascii_case(b"NO-TOUCH") + { + // Accepted but no-op (per-connection flags) + responses.push(Frame::SimpleString(Bytes::from_static(b"OK"))); + continue; + } responses.push(Frame::Error(Bytes::from(format!( "ERR unknown subcommand '{}'", String::from_utf8_lossy(&sub_bytes) )))); diff --git a/src/storage/eviction.rs b/src/storage/eviction.rs index f811ecc6..f763a124 100644 --- a/src/storage/eviction.rs +++ b/src/storage/eviction.rs @@ -686,6 +686,9 @@ mod tests { requirepass: None, protected_mode: "yes".to_string(), acllog_max_len: 128, + client_pause_deadline_ms: 0, + client_pause_write_only: false, + lazyfree_threshold: 64, } } From cfbcaf0b67387e72315ba3140869bb7c8eb60865 Mon Sep 17 00:00:00 2001 From: Tin Dang Date: Fri, 10 Apr 2026 23:36:21 +0700 Subject: [PATCH 10/20] style: cargo fmt --- src/command/config.rs | 41 +++++++++++++++++++---------------------- src/command/key.rs | 4 +--- 2 files changed, 20 insertions(+), 25 deletions(-) diff --git a/src/command/config.rs b/src/command/config.rs index 4d67120f..39267081 100644 --- a/src/command/config.rs +++ b/src/command/config.rs @@ -175,17 +175,15 @@ pub fn config_set(runtime_config: &mut RuntimeConfig, args: &[Frame]) -> Frame { ))); } }, - "lazyfree-lazy-eviction" | "lazyfree-threshold" => { - match value_str.parse::() { - Ok(v) => runtime_config.lazyfree_threshold = v, - Err(_) => { - return Frame::Error(Bytes::from(format!( - "ERR Invalid argument '{}' for CONFIG SET 'lazyfree-threshold'", - value_str - ))); - } + "lazyfree-lazy-eviction" | "lazyfree-threshold" => match value_str.parse::() { + Ok(v) => runtime_config.lazyfree_threshold = v, + Err(_) => { + return Frame::Error(Bytes::from(format!( + "ERR Invalid argument '{}' for CONFIG SET 'lazyfree-threshold'", + value_str + ))); } - } + }, _ => { return Frame::Error(Bytes::from(format!( "ERR Unsupported CONFIG parameter: {}", @@ -203,10 +201,7 @@ pub fn config_set(runtime_config: &mut RuntimeConfig, args: &[Frame]) -> Frame { /// CONFIG REWRITE — serialize current runtime config to a Redis-style config file. /// /// Writes to `/moon.conf` atomically (tmpfile + rename). -pub fn config_rewrite( - runtime_config: &RuntimeConfig, - server_config: &ServerConfig, -) -> Frame { +pub fn config_rewrite(runtime_config: &RuntimeConfig, server_config: &ServerConfig) -> Frame { let mut lines = Vec::with_capacity(20); lines.push("# Moon configuration file — generated by CONFIG REWRITE".to_string()); lines.push(format!("# {}", chrono_lite_now())); @@ -224,8 +219,14 @@ pub fn config_rewrite( // Memory settings lines.push(format!("maxmemory {}", runtime_config.maxmemory)); - lines.push(format!("maxmemory-policy {}", runtime_config.maxmemory_policy)); - lines.push(format!("maxmemory-samples {}", runtime_config.maxmemory_samples)); + lines.push(format!( + "maxmemory-policy {}", + runtime_config.maxmemory_policy + )); + lines.push(format!( + "maxmemory-samples {}", + runtime_config.maxmemory_samples + )); lines.push(format!("lfu-log-factor {}", runtime_config.lfu_log_factor)); lines.push(format!("lfu-decay-time {}", runtime_config.lfu_decay_time)); lines.push(String::new()); @@ -255,15 +256,11 @@ pub fn config_rewrite( let tmp_path = std::path::Path::new(dir).join("moon.conf.tmp"); if let Err(e) = std::fs::write(&tmp_path, content.as_bytes()) { - return Frame::Error(Bytes::from(format!( - "ERR failed to write config: {e}" - ))); + return Frame::Error(Bytes::from(format!("ERR failed to write config: {e}"))); } if let Err(e) = std::fs::rename(&tmp_path, &conf_path) { let _ = std::fs::remove_file(&tmp_path); - return Frame::Error(Bytes::from(format!( - "ERR failed to rename config: {e}" - ))); + return Frame::Error(Bytes::from(format!("ERR failed to rename config: {e}"))); } Frame::SimpleString(Bytes::from_static(b"OK")) diff --git a/src/command/key.rs b/src/command/key.rs index e8af22f4..a3b921a3 100644 --- a/src/command/key.rs +++ b/src/command/key.rs @@ -586,9 +586,7 @@ pub fn memory_usage(db: &mut Database, args: &[Frame]) -> Frame { /// MEMORY DOCTOR — report memory health issues. pub fn memory_doctor() -> Frame { // Basic health report - Frame::BulkString(Bytes::from_static( - b"Sam, I have no memory problems", - )) + Frame::BulkString(Bytes::from_static(b"Sam, I have no memory problems")) } /// MEMORY HELP — list MEMORY subcommands. From a86dbb99cf0fbe9fc1cb0b6738e7b96c06ee17e0 Mon Sep 17 00:00:00 2001 From: Tin Dang Date: Fri, 10 Apr 2026 23:44:32 +0700 Subject: [PATCH 11/20] refactor: split COPY/SORT/MEMORY from key.rs into key_extra.rs key.rs was 1966 lines, exceeding the 1500-line limit. Extracted COPY, SORT, MEMORY commands + apply_pattern helper + tests into key_extra.rs (530 lines). key.rs is now 1474 lines. --- src/command/key.rs | 498 +----------------------------------- src/command/key_extra.rs | 530 +++++++++++++++++++++++++++++++++++++++ src/command/mod.rs | 11 +- 3 files changed, 538 insertions(+), 501 deletions(-) create mode 100644 src/command/key_extra.rs diff --git a/src/command/key.rs b/src/command/key.rs index a3b921a3..480d21cb 100644 --- a/src/command/key.rs +++ b/src/command/key.rs @@ -9,7 +9,7 @@ use crate::storage::entry::current_time_ms; use super::helpers::err_wrong_args; /// Extract a key as &[u8] from a Frame argument. -fn extract_key(frame: &Frame) -> Option<&[u8]> { +pub(crate) fn extract_key(frame: &Frame) -> Option<&[u8]> { match frame { Frame::BulkString(s) | Frame::SimpleString(s) => Some(s.as_ref()), _ => None, @@ -17,7 +17,7 @@ fn extract_key(frame: &Frame) -> Option<&[u8]> { } /// Parse an integer argument from a Frame. -fn parse_int(frame: &Frame) -> Option { +pub(crate) fn parse_int(frame: &Frame) -> Option { match frame { Frame::BulkString(s) | Frame::SimpleString(s) => std::str::from_utf8(s).ok()?.parse().ok(), Frame::Integer(n) => Some(*n), @@ -498,335 +498,6 @@ pub fn renamenx(db: &mut Database, args: &[Frame]) -> Frame { Frame::Integer(1) } -/// COPY source destination [DB destination-db] [REPLACE] -/// -/// Copies the value stored at the source key to the destination key. -/// Returns 1 if source was copied, 0 if destination already exists without REPLACE. -pub fn copy(db: &mut Database, args: &[Frame]) -> Frame { - if args.len() < 2 { - return err_wrong_args("COPY"); - } - let src = match extract_key(&args[0]) { - Some(k) => k, - None => return err_wrong_args("COPY"), - }; - let dst = match extract_key(&args[1]) { - Some(k) => k, - None => return err_wrong_args("COPY"), - }; - - // Parse optional arguments: DB destination-db, REPLACE - let mut replace = false; - let mut i = 2; - while i < args.len() { - let arg = match extract_key(&args[i]) { - Some(k) => k, - None => return Frame::Error(Bytes::from_static(b"ERR syntax error")), - }; - if arg.eq_ignore_ascii_case(b"REPLACE") { - replace = true; - i += 1; - } else if arg.eq_ignore_ascii_case(b"DB") { - // Cross-DB copy requires shard_databases context not available here - return Frame::Error(Bytes::from_static( - b"ERR COPY with DB option is not supported yet", - )); - } else { - return Frame::Error(Bytes::from_static(b"ERR syntax error")); - } - } - - // Check if source exists (with lazy expiry) - if !db.exists(src) { - return Frame::Error(Bytes::from_static(b"ERR no such key")); - } - - // Same key: no data to copy, but it's valid - if src == dst { - return Frame::Integer(1); - } - - // Check if destination exists - if db.exists(dst) && !replace { - return Frame::Integer(0); - } - - // Clone the source entry (CompactEntry derives Clone) - let entry = db.get(src).cloned(); - if let Some(cloned) = entry { - db.set(Bytes::copy_from_slice(dst), cloned); - Frame::Integer(1) - } else { - // Source expired between exists() and get() — race with lazy expiry - Frame::Error(Bytes::from_static(b"ERR no such key")) - } -} - -/// MEMORY USAGE key [SAMPLES count] -/// -/// Returns the number of bytes a key and its value require to be stored. -pub fn memory_usage(db: &mut Database, args: &[Frame]) -> Frame { - if args.is_empty() { - return err_wrong_args("MEMORY"); - } - let key = match extract_key(&args[0]) { - Some(k) => k, - None => return err_wrong_args("MEMORY"), - }; - - match db.get(key) { - Some(entry) => { - let mem = key.len() + entry.value.estimate_memory() + 128; // same as entry_overhead - Frame::Integer(mem as i64) - } - None => Frame::Null, - } -} - -/// MEMORY DOCTOR — report memory health issues. -pub fn memory_doctor() -> Frame { - // Basic health report - Frame::BulkString(Bytes::from_static(b"Sam, I have no memory problems")) -} - -/// MEMORY HELP — list MEMORY subcommands. -pub fn memory_help() -> Frame { - Frame::Array( - vec![ - Frame::BulkString(Bytes::from_static(b"MEMORY DOCTOR - Return memory problems reports.")), - Frame::BulkString(Bytes::from_static(b"MEMORY HELP - Return this help message.")), - Frame::BulkString(Bytes::from_static(b"MEMORY USAGE [SAMPLES ] - Return memory in bytes used by and its value.")), - ] - .into(), - ) -} - -/// SORT key [BY pattern] [LIMIT offset count] [GET pattern ...] [ASC|DESC] [ALPHA] [STORE dest] -/// -/// Sort elements in a list, set, or sorted set. -pub fn sort(db: &mut Database, args: &[Frame]) -> Frame { - if args.is_empty() { - return err_wrong_args("SORT"); - } - let key = match extract_key(&args[0]) { - Some(k) => k, - None => return err_wrong_args("SORT"), - }; - - // Parse options - let mut by_pattern: Option<&[u8]> = None; - let mut get_patterns: Vec<&[u8]> = Vec::new(); - let mut limit_offset: usize = 0; - let mut limit_count: Option = None; - let mut descending = false; - let mut alpha = false; - let mut store_dest: Option<&[u8]> = None; - - let mut i = 1; - while i < args.len() { - let arg = match extract_key(&args[i]) { - Some(a) => a, - None => return Frame::Error(Bytes::from_static(b"ERR syntax error")), - }; - if arg.eq_ignore_ascii_case(b"BY") { - i += 1; - by_pattern = Some(match extract_key(args.get(i).unwrap_or(&Frame::Null)) { - Some(p) => p, - None => return Frame::Error(Bytes::from_static(b"ERR syntax error")), - }); - } else if arg.eq_ignore_ascii_case(b"GET") { - i += 1; - let pat = match extract_key(args.get(i).unwrap_or(&Frame::Null)) { - Some(p) => p, - None => return Frame::Error(Bytes::from_static(b"ERR syntax error")), - }; - get_patterns.push(pat); - } else if arg.eq_ignore_ascii_case(b"LIMIT") { - let off = match args.get(i + 1).and_then(|f| parse_int(f)) { - Some(v) if v >= 0 => v as usize, - _ => return Frame::Error(Bytes::from_static(b"ERR syntax error")), - }; - let cnt = match args.get(i + 2).and_then(|f| parse_int(f)) { - Some(v) if v >= 0 => v as usize, - _ => return Frame::Error(Bytes::from_static(b"ERR syntax error")), - }; - limit_offset = off; - limit_count = Some(cnt); - i += 2; - } else if arg.eq_ignore_ascii_case(b"ASC") { - descending = false; - } else if arg.eq_ignore_ascii_case(b"DESC") { - descending = true; - } else if arg.eq_ignore_ascii_case(b"ALPHA") { - alpha = true; - } else if arg.eq_ignore_ascii_case(b"STORE") { - i += 1; - store_dest = Some(match extract_key(args.get(i).unwrap_or(&Frame::Null)) { - Some(d) => d, - None => return Frame::Error(Bytes::from_static(b"ERR syntax error")), - }); - } else { - return Frame::Error(Bytes::from_static(b"ERR syntax error")); - } - i += 1; - } - - // Extract elements from the key - use crate::storage::compact_value::RedisValueRef; - let elements: Vec = match db.get(key) { - None => { - // Non-existent key: return empty result - return if let Some(dest) = store_dest { - // STORE with empty → create empty list - let entry = crate::storage::entry::Entry::new_list(); - db.set(Bytes::copy_from_slice(dest), entry); - Frame::Integer(0) - } else { - Frame::Array(crate::framevec![]) - }; - } - Some(entry) => match entry.value.as_redis_value() { - RedisValueRef::List(l) => l.iter().cloned().collect(), - RedisValueRef::ListListpack(lp) => lp.iter().map(|e| e.to_bytes()).collect(), - RedisValueRef::Set(s) => s.iter().cloned().collect(), - RedisValueRef::SetListpack(lp) => lp.iter().map(|e| e.to_bytes()).collect(), - RedisValueRef::SetIntset(is) => is.iter().map(|v| Bytes::from(v.to_string())).collect(), - RedisValueRef::SortedSet { members, .. } => members.keys().cloned().collect(), - RedisValueRef::SortedSetBPTree { members, .. } => members.keys().cloned().collect(), - RedisValueRef::SortedSetListpack(lp) => { - // Listpack stores member, score pairs - let entries: Vec<_> = lp.iter().collect(); - entries - .chunks(2) - .filter_map(|c| c.first().map(|e| e.to_bytes())) - .collect() - } - _ => { - return Frame::Error(Bytes::from_static( - b"WRONGTYPE Operation against a key holding the wrong kind of value", - )); - } - }, - }; - - // Resolve sort keys (BY pattern or element itself) - let sort_keys: Vec> = if let Some(pattern) = by_pattern { - if pattern == b"nosort" { - // BY nosort = skip sorting - elements.iter().map(|_| None).collect() - } else { - elements - .iter() - .map(|elem| { - let lookup_key = apply_pattern(pattern, elem); - db.get(&lookup_key) - .and_then(|e| e.value.as_bytes().map(|b| Bytes::copy_from_slice(b))) - }) - .collect() - } - } else { - elements.iter().map(|e| Some(e.clone())).collect() - }; - - // Create indexed pairs for stable sort - let mut indices: Vec = (0..elements.len()).collect(); - - // Sort (skip if BY nosort) - let no_sort = by_pattern.is_some_and(|p| p == b"nosort"); - if !no_sort { - indices.sort_by(|&a, &b| { - let ka = sort_keys[a].as_ref(); - let kb = sort_keys[b].as_ref(); - let cmp = match (ka, kb) { - (None, None) => std::cmp::Ordering::Equal, - (None, Some(_)) => std::cmp::Ordering::Greater, - (Some(_), None) => std::cmp::Ordering::Less, - (Some(va), Some(vb)) => { - if alpha { - va.cmp(vb) - } else { - let fa = std::str::from_utf8(va) - .ok() - .and_then(|s| s.parse::().ok()) - .unwrap_or(0.0); - let fb = std::str::from_utf8(vb) - .ok() - .and_then(|s| s.parse::().ok()) - .unwrap_or(0.0); - fa.partial_cmp(&fb).unwrap_or(std::cmp::Ordering::Equal) - } - } - }; - if descending { cmp.reverse() } else { cmp } - }); - } - - // Apply LIMIT - let start = limit_offset.min(indices.len()); - let count = limit_count.unwrap_or(indices.len()); - let end = (start + count).min(indices.len()); - let selected = &indices[start..end]; - - // Build results - let results: Vec = if get_patterns.is_empty() { - selected - .iter() - .map(|&idx| Frame::BulkString(elements[idx].clone())) - .collect() - } else { - let mut out = Vec::with_capacity(selected.len() * get_patterns.len()); - for &idx in selected { - for pat in &get_patterns { - if *pat == b"#" { - out.push(Frame::BulkString(elements[idx].clone())); - } else { - let lookup_key = apply_pattern(pat, &elements[idx]); - match db.get(&lookup_key) { - Some(e) => match e.value.as_bytes() { - Some(v) => out.push(Frame::BulkString(Bytes::copy_from_slice(v))), - None => out.push(Frame::Null), - }, - None => out.push(Frame::Null), - } - } - } - } - out - }; - - // STORE or return - if let Some(dest) = store_dest { - let count = results.len() as i64; - let mut list = std::collections::VecDeque::with_capacity(results.len()); - for frame in results { - if let Frame::BulkString(b) = frame { - list.push_back(b); - } - } - let mut entry = crate::storage::entry::Entry::new_list(); - entry.value = crate::storage::compact_value::CompactValue::from_redis_value( - crate::storage::entry::RedisValue::List(list), - ); - db.set(Bytes::copy_from_slice(dest), entry); - Frame::Integer(count) - } else { - Frame::Array(results.into()) - } -} - -/// Apply a SORT pattern by replacing the first `*` with the element value. -fn apply_pattern(pattern: &[u8], element: &[u8]) -> Bytes { - if let Some(pos) = pattern.iter().position(|&b| b == b'*') { - let mut result = Vec::with_capacity(pattern.len() + element.len()); - result.extend_from_slice(&pattern[..pos]); - result.extend_from_slice(element); - result.extend_from_slice(&pattern[pos + 1..]); - Bytes::from(result) - } else { - Bytes::copy_from_slice(pattern) - } -} - /// Check if a value is large enough to warrant async drop. fn should_async_drop(entry: &crate::storage::entry::Entry) -> bool { use crate::storage::compact_value::RedisValueRef; @@ -1798,169 +1469,4 @@ mod tests { _ => panic!("Expected array"), } } - - // --- COPY tests --- - - #[test] - fn test_copy_basic() { - let mut db = setup_db_with_key(b"src", b"hello"); - let result = copy(&mut db, &[bs(b"src"), bs(b"dst")]); - assert_eq!(result, Frame::Integer(1)); - assert!(db.exists(b"src")); - assert!(db.exists(b"dst")); - } - - #[test] - fn test_copy_dest_exists_no_replace() { - let mut db = setup_db_with_key(b"src", b"hello"); - db.set( - Bytes::from_static(b"dst"), - Entry::new_string(Bytes::from_static(b"existing")), - ); - let result = copy(&mut db, &[bs(b"src"), bs(b"dst")]); - assert_eq!(result, Frame::Integer(0)); - } - - #[test] - fn test_copy_with_replace() { - let mut db = setup_db_with_key(b"src", b"hello"); - db.set( - Bytes::from_static(b"dst"), - Entry::new_string(Bytes::from_static(b"existing")), - ); - let result = copy(&mut db, &[bs(b"src"), bs(b"dst"), bs(b"REPLACE")]); - assert_eq!(result, Frame::Integer(1)); - } - - #[test] - fn test_copy_nonexistent_source() { - let mut db = Database::new(); - let result = copy(&mut db, &[bs(b"nosuchkey"), bs(b"dst")]); - assert!(matches!(result, Frame::Error(_))); - } - - #[test] - fn test_copy_same_key() { - let mut db = setup_db_with_key(b"src", b"hello"); - let result = copy(&mut db, &[bs(b"src"), bs(b"src")]); - assert_eq!(result, Frame::Integer(1)); - } - - #[test] - fn test_copy_db_option_errors() { - let mut db = setup_db_with_key(b"src", b"hello"); - let result = copy(&mut db, &[bs(b"src"), bs(b"dst"), bs(b"DB")]); - assert!(matches!(result, Frame::Error(_))); - } - - // --- SORT tests --- - - fn setup_list(db: &mut Database, key: &[u8], vals: &[&[u8]]) { - use std::collections::VecDeque; - let list: VecDeque = vals.iter().map(|v| Bytes::copy_from_slice(v)).collect(); - let mut entry = Entry::new_list(); - entry.value = crate::storage::compact_value::CompactValue::from_redis_value( - crate::storage::entry::RedisValue::List(list), - ); - db.set(Bytes::copy_from_slice(key), entry); - } - - #[test] - fn test_sort_numeric() { - let mut db = Database::new(); - setup_list(&mut db, b"mylist", &[b"3", b"1", b"2"]); - let result = sort(&mut db, &[bs(b"mylist")]); - assert_eq!( - result, - Frame::Array(framevec![ - Frame::BulkString(Bytes::from_static(b"1")), - Frame::BulkString(Bytes::from_static(b"2")), - Frame::BulkString(Bytes::from_static(b"3")), - ]) - ); - } - - #[test] - fn test_sort_alpha() { - let mut db = Database::new(); - setup_list(&mut db, b"mylist", &[b"c", b"a", b"b"]); - let result = sort(&mut db, &[bs(b"mylist"), bs(b"ALPHA")]); - assert_eq!( - result, - Frame::Array(framevec![ - Frame::BulkString(Bytes::from_static(b"a")), - Frame::BulkString(Bytes::from_static(b"b")), - Frame::BulkString(Bytes::from_static(b"c")), - ]) - ); - } - - #[test] - fn test_sort_desc() { - let mut db = Database::new(); - setup_list(&mut db, b"mylist", &[b"1", b"3", b"2"]); - let result = sort(&mut db, &[bs(b"mylist"), bs(b"DESC")]); - assert_eq!( - result, - Frame::Array(framevec![ - Frame::BulkString(Bytes::from_static(b"3")), - Frame::BulkString(Bytes::from_static(b"2")), - Frame::BulkString(Bytes::from_static(b"1")), - ]) - ); - } - - #[test] - fn test_sort_limit() { - let mut db = Database::new(); - setup_list(&mut db, b"mylist", &[b"3", b"1", b"2", b"4"]); - let result = sort(&mut db, &[bs(b"mylist"), bs(b"LIMIT"), bs(b"1"), bs(b"2")]); - assert_eq!( - result, - Frame::Array(framevec![ - Frame::BulkString(Bytes::from_static(b"2")), - Frame::BulkString(Bytes::from_static(b"3")), - ]) - ); - } - - #[test] - fn test_sort_store() { - let mut db = Database::new(); - setup_list(&mut db, b"mylist", &[b"3", b"1", b"2"]); - let result = sort(&mut db, &[bs(b"mylist"), bs(b"STORE"), bs(b"sorted")]); - assert_eq!(result, Frame::Integer(3)); - assert!(db.exists(b"sorted")); - } - - #[test] - fn test_sort_nonexistent() { - let mut db = Database::new(); - let result = sort(&mut db, &[bs(b"nokey")]); - assert_eq!(result, Frame::Array(framevec![])); - } - - #[test] - fn test_sort_set() { - let mut db = Database::new(); - let mut s = std::collections::HashSet::new(); - s.insert(Bytes::from_static(b"3")); - s.insert(Bytes::from_static(b"1")); - s.insert(Bytes::from_static(b"2")); - let mut entry = Entry::new_set(); - entry.value = crate::storage::compact_value::CompactValue::from_redis_value( - crate::storage::entry::RedisValue::Set(s), - ); - db.set(Bytes::from_static(b"myset"), entry); - let result = sort(&mut db, &[bs(b"myset")]); - // Sort result should be [1, 2, 3] regardless of HashSet order - assert_eq!( - result, - Frame::Array(framevec![ - Frame::BulkString(Bytes::from_static(b"1")), - Frame::BulkString(Bytes::from_static(b"2")), - Frame::BulkString(Bytes::from_static(b"3")), - ]) - ); - } } diff --git a/src/command/key_extra.rs b/src/command/key_extra.rs new file mode 100644 index 00000000..1aacf8c5 --- /dev/null +++ b/src/command/key_extra.rs @@ -0,0 +1,530 @@ +//! Extra key commands: COPY, SORT, MEMORY. +//! +//! Split from key.rs to stay under the 1500-line limit. + +use bytes::Bytes; + +use crate::framevec; +use crate::protocol::Frame; +use crate::storage::Database; + +use super::helpers::err_wrong_args; +use super::key::{extract_key, parse_int}; + +/// COPY source destination [DB destination-db] [REPLACE] +/// +/// Copies the value stored at the source key to the destination key. +/// Returns 1 if source was copied, 0 if destination already exists without REPLACE. +pub fn copy(db: &mut Database, args: &[Frame]) -> Frame { + if args.len() < 2 { + return err_wrong_args("COPY"); + } + let src = match extract_key(&args[0]) { + Some(k) => k, + None => return err_wrong_args("COPY"), + }; + let dst = match extract_key(&args[1]) { + Some(k) => k, + None => return err_wrong_args("COPY"), + }; + + // Parse optional arguments: DB destination-db, REPLACE + let mut replace = false; + let mut i = 2; + while i < args.len() { + let arg = match extract_key(&args[i]) { + Some(k) => k, + None => return Frame::Error(Bytes::from_static(b"ERR syntax error")), + }; + if arg.eq_ignore_ascii_case(b"REPLACE") { + replace = true; + i += 1; + } else if arg.eq_ignore_ascii_case(b"DB") { + // Cross-DB copy requires shard_databases context not available here + return Frame::Error(Bytes::from_static( + b"ERR COPY with DB option is not supported yet", + )); + } else { + return Frame::Error(Bytes::from_static(b"ERR syntax error")); + } + } + + // Check if source exists (with lazy expiry) + if !db.exists(src) { + return Frame::Error(Bytes::from_static(b"ERR no such key")); + } + + // Same key: no data to copy, but it's valid + if src == dst { + return Frame::Integer(1); + } + + // Check if destination exists + if db.exists(dst) && !replace { + return Frame::Integer(0); + } + + // Clone the source entry (CompactEntry derives Clone) + let entry = db.get(src).cloned(); + if let Some(cloned) = entry { + db.set(Bytes::copy_from_slice(dst), cloned); + Frame::Integer(1) + } else { + // Source expired between exists() and get() — race with lazy expiry + Frame::Error(Bytes::from_static(b"ERR no such key")) + } +} + +/// MEMORY USAGE key [SAMPLES count] +/// +/// Returns the number of bytes a key and its value require to be stored. +pub fn memory_usage(db: &mut Database, args: &[Frame]) -> Frame { + if args.is_empty() { + return err_wrong_args("MEMORY"); + } + let key = match extract_key(&args[0]) { + Some(k) => k, + None => return err_wrong_args("MEMORY"), + }; + + match db.get(key) { + Some(entry) => { + let mem = key.len() + entry.value.estimate_memory() + 128; // same as entry_overhead + Frame::Integer(mem as i64) + } + None => Frame::Null, + } +} + +/// MEMORY DOCTOR — report memory health issues. +pub fn memory_doctor() -> Frame { + // Basic health report + Frame::BulkString(Bytes::from_static(b"Sam, I have no memory problems")) +} + +/// MEMORY HELP — list MEMORY subcommands. +pub fn memory_help() -> Frame { + Frame::Array( + vec![ + Frame::BulkString(Bytes::from_static( + b"MEMORY DOCTOR - Return memory problems reports.", + )), + Frame::BulkString(Bytes::from_static( + b"MEMORY HELP - Return this help message.", + )), + Frame::BulkString(Bytes::from_static( + b"MEMORY USAGE [SAMPLES ] - Return memory in bytes used by and its value.", + )), + ] + .into(), + ) +} + +/// SORT key [BY pattern] [LIMIT offset count] [GET pattern ...] [ASC|DESC] [ALPHA] [STORE dest] +/// +/// Sort elements in a list, set, or sorted set. +pub fn sort(db: &mut Database, args: &[Frame]) -> Frame { + if args.is_empty() { + return err_wrong_args("SORT"); + } + let key = match extract_key(&args[0]) { + Some(k) => k, + None => return err_wrong_args("SORT"), + }; + + // Parse options + let mut by_pattern: Option<&[u8]> = None; + let mut get_patterns: Vec<&[u8]> = Vec::new(); + let mut limit_offset: usize = 0; + let mut limit_count: Option = None; + let mut descending = false; + let mut alpha = false; + let mut store_dest: Option<&[u8]> = None; + + let mut i = 1; + while i < args.len() { + let arg = match extract_key(&args[i]) { + Some(a) => a, + None => return Frame::Error(Bytes::from_static(b"ERR syntax error")), + }; + if arg.eq_ignore_ascii_case(b"BY") { + i += 1; + by_pattern = Some(match extract_key(args.get(i).unwrap_or(&Frame::Null)) { + Some(p) => p, + None => return Frame::Error(Bytes::from_static(b"ERR syntax error")), + }); + } else if arg.eq_ignore_ascii_case(b"GET") { + i += 1; + let pat = match extract_key(args.get(i).unwrap_or(&Frame::Null)) { + Some(p) => p, + None => return Frame::Error(Bytes::from_static(b"ERR syntax error")), + }; + get_patterns.push(pat); + } else if arg.eq_ignore_ascii_case(b"LIMIT") { + let off = match args.get(i + 1).and_then(|f| parse_int(f)) { + Some(v) if v >= 0 => v as usize, + _ => return Frame::Error(Bytes::from_static(b"ERR syntax error")), + }; + let cnt = match args.get(i + 2).and_then(|f| parse_int(f)) { + Some(v) if v >= 0 => v as usize, + _ => return Frame::Error(Bytes::from_static(b"ERR syntax error")), + }; + limit_offset = off; + limit_count = Some(cnt); + i += 2; + } else if arg.eq_ignore_ascii_case(b"ASC") { + descending = false; + } else if arg.eq_ignore_ascii_case(b"DESC") { + descending = true; + } else if arg.eq_ignore_ascii_case(b"ALPHA") { + alpha = true; + } else if arg.eq_ignore_ascii_case(b"STORE") { + i += 1; + store_dest = Some(match extract_key(args.get(i).unwrap_or(&Frame::Null)) { + Some(d) => d, + None => return Frame::Error(Bytes::from_static(b"ERR syntax error")), + }); + } else { + return Frame::Error(Bytes::from_static(b"ERR syntax error")); + } + i += 1; + } + + // Extract elements from the key + use crate::storage::compact_value::RedisValueRef; + let elements: Vec = match db.get(key) { + None => { + // Non-existent key: return empty result + return if let Some(dest) = store_dest { + // STORE with empty → create empty list + let entry = crate::storage::entry::Entry::new_list(); + db.set(Bytes::copy_from_slice(dest), entry); + Frame::Integer(0) + } else { + Frame::Array(framevec![]) + }; + } + Some(entry) => match entry.value.as_redis_value() { + RedisValueRef::List(l) => l.iter().cloned().collect(), + RedisValueRef::ListListpack(lp) => lp.iter().map(|e| e.to_bytes()).collect(), + RedisValueRef::Set(s) => s.iter().cloned().collect(), + RedisValueRef::SetListpack(lp) => lp.iter().map(|e| e.to_bytes()).collect(), + RedisValueRef::SetIntset(is) => is.iter().map(|v| Bytes::from(v.to_string())).collect(), + RedisValueRef::SortedSet { members, .. } => members.keys().cloned().collect(), + RedisValueRef::SortedSetBPTree { members, .. } => members.keys().cloned().collect(), + RedisValueRef::SortedSetListpack(lp) => { + // Listpack stores member, score pairs + let entries: Vec<_> = lp.iter().collect(); + entries + .chunks(2) + .filter_map(|c| c.first().map(|e| e.to_bytes())) + .collect() + } + _ => { + return Frame::Error(Bytes::from_static( + b"WRONGTYPE Operation against a key holding the wrong kind of value", + )); + } + }, + }; + + // Resolve sort keys (BY pattern or element itself) + let sort_keys: Vec> = if let Some(pattern) = by_pattern { + if pattern == b"nosort" { + // BY nosort = skip sorting + elements.iter().map(|_| None).collect() + } else { + elements + .iter() + .map(|elem| { + let lookup_key = apply_pattern(pattern, elem); + db.get(&lookup_key) + .and_then(|e| e.value.as_bytes().map(|b| Bytes::copy_from_slice(b))) + }) + .collect() + } + } else { + elements.iter().map(|e| Some(e.clone())).collect() + }; + + // Create indexed pairs for stable sort + let mut indices: Vec = (0..elements.len()).collect(); + + // Sort (skip if BY nosort) + let no_sort = by_pattern.is_some_and(|p| p == b"nosort"); + if !no_sort { + indices.sort_by(|&a, &b| { + let ka = sort_keys[a].as_ref(); + let kb = sort_keys[b].as_ref(); + let cmp = match (ka, kb) { + (None, None) => std::cmp::Ordering::Equal, + (None, Some(_)) => std::cmp::Ordering::Greater, + (Some(_), None) => std::cmp::Ordering::Less, + (Some(va), Some(vb)) => { + if alpha { + va.cmp(vb) + } else { + let fa = std::str::from_utf8(va) + .ok() + .and_then(|s| s.parse::().ok()) + .unwrap_or(0.0); + let fb = std::str::from_utf8(vb) + .ok() + .and_then(|s| s.parse::().ok()) + .unwrap_or(0.0); + fa.partial_cmp(&fb).unwrap_or(std::cmp::Ordering::Equal) + } + } + }; + if descending { cmp.reverse() } else { cmp } + }); + } + + // Apply LIMIT + let start = limit_offset.min(indices.len()); + let count = limit_count.unwrap_or(indices.len()); + let end = (start + count).min(indices.len()); + let selected = &indices[start..end]; + + // Build results + let results: Vec = if get_patterns.is_empty() { + selected + .iter() + .map(|&idx| Frame::BulkString(elements[idx].clone())) + .collect() + } else { + let mut out = Vec::with_capacity(selected.len() * get_patterns.len()); + for &idx in selected { + for pat in &get_patterns { + if *pat == b"#" { + out.push(Frame::BulkString(elements[idx].clone())); + } else { + let lookup_key = apply_pattern(pat, &elements[idx]); + match db.get(&lookup_key) { + Some(e) => match e.value.as_bytes() { + Some(v) => out.push(Frame::BulkString(Bytes::copy_from_slice(v))), + None => out.push(Frame::Null), + }, + None => out.push(Frame::Null), + } + } + } + } + out + }; + + // STORE or return + if let Some(dest) = store_dest { + let count = results.len() as i64; + let mut list = std::collections::VecDeque::with_capacity(results.len()); + for frame in results { + if let Frame::BulkString(b) = frame { + list.push_back(b); + } + } + let mut entry = crate::storage::entry::Entry::new_list(); + entry.value = crate::storage::compact_value::CompactValue::from_redis_value( + crate::storage::entry::RedisValue::List(list), + ); + db.set(Bytes::copy_from_slice(dest), entry); + Frame::Integer(count) + } else { + Frame::Array(results.into()) + } +} + +/// Apply a SORT pattern by replacing the first `*` with the element value. +fn apply_pattern(pattern: &[u8], element: &[u8]) -> Bytes { + if let Some(pos) = pattern.iter().position(|&b| b == b'*') { + let mut result = Vec::with_capacity(pattern.len() + element.len()); + result.extend_from_slice(&pattern[..pos]); + result.extend_from_slice(element); + result.extend_from_slice(&pattern[pos + 1..]); + Bytes::from(result) + } else { + Bytes::copy_from_slice(pattern) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::storage::entry::Entry; + + fn bs(s: &[u8]) -> Frame { + Frame::BulkString(Bytes::copy_from_slice(s)) + } + + fn setup_db_with_key(key: &[u8], val: &[u8]) -> Database { + let mut db = Database::new(); + db.set( + Bytes::copy_from_slice(key), + Entry::new_string(Bytes::copy_from_slice(val)), + ); + db + } + + // --- COPY tests --- + + #[test] + fn test_copy_basic() { + let mut db = setup_db_with_key(b"src", b"hello"); + let result = copy(&mut db, &[bs(b"src"), bs(b"dst")]); + assert_eq!(result, Frame::Integer(1)); + assert!(db.exists(b"src")); + assert!(db.exists(b"dst")); + } + + #[test] + fn test_copy_dest_exists_no_replace() { + let mut db = setup_db_with_key(b"src", b"hello"); + db.set( + Bytes::from_static(b"dst"), + Entry::new_string(Bytes::from_static(b"existing")), + ); + let result = copy(&mut db, &[bs(b"src"), bs(b"dst")]); + assert_eq!(result, Frame::Integer(0)); + } + + #[test] + fn test_copy_with_replace() { + let mut db = setup_db_with_key(b"src", b"hello"); + db.set( + Bytes::from_static(b"dst"), + Entry::new_string(Bytes::from_static(b"existing")), + ); + let result = copy(&mut db, &[bs(b"src"), bs(b"dst"), bs(b"REPLACE")]); + assert_eq!(result, Frame::Integer(1)); + } + + #[test] + fn test_copy_nonexistent_source() { + let mut db = Database::new(); + let result = copy(&mut db, &[bs(b"nosuchkey"), bs(b"dst")]); + assert!(matches!(result, Frame::Error(_))); + } + + #[test] + fn test_copy_same_key() { + let mut db = setup_db_with_key(b"src", b"hello"); + let result = copy(&mut db, &[bs(b"src"), bs(b"src")]); + assert_eq!(result, Frame::Integer(1)); + } + + #[test] + fn test_copy_db_option_errors() { + let mut db = setup_db_with_key(b"src", b"hello"); + let result = copy(&mut db, &[bs(b"src"), bs(b"dst"), bs(b"DB")]); + assert!(matches!(result, Frame::Error(_))); + } + + // --- SORT tests --- + + fn setup_list(db: &mut Database, key: &[u8], vals: &[&[u8]]) { + use std::collections::VecDeque; + let list: VecDeque = vals.iter().map(|v| Bytes::copy_from_slice(v)).collect(); + let mut entry = Entry::new_list(); + entry.value = crate::storage::compact_value::CompactValue::from_redis_value( + crate::storage::entry::RedisValue::List(list), + ); + db.set(Bytes::copy_from_slice(key), entry); + } + + #[test] + fn test_sort_numeric() { + let mut db = Database::new(); + setup_list(&mut db, b"mylist", &[b"3", b"1", b"2"]); + let result = sort(&mut db, &[bs(b"mylist")]); + assert_eq!( + result, + Frame::Array(framevec![ + Frame::BulkString(Bytes::from_static(b"1")), + Frame::BulkString(Bytes::from_static(b"2")), + Frame::BulkString(Bytes::from_static(b"3")), + ]) + ); + } + + #[test] + fn test_sort_alpha() { + let mut db = Database::new(); + setup_list(&mut db, b"mylist", &[b"c", b"a", b"b"]); + let result = sort(&mut db, &[bs(b"mylist"), bs(b"ALPHA")]); + assert_eq!( + result, + Frame::Array(framevec![ + Frame::BulkString(Bytes::from_static(b"a")), + Frame::BulkString(Bytes::from_static(b"b")), + Frame::BulkString(Bytes::from_static(b"c")), + ]) + ); + } + + #[test] + fn test_sort_desc() { + let mut db = Database::new(); + setup_list(&mut db, b"mylist", &[b"1", b"3", b"2"]); + let result = sort(&mut db, &[bs(b"mylist"), bs(b"DESC")]); + assert_eq!( + result, + Frame::Array(framevec![ + Frame::BulkString(Bytes::from_static(b"3")), + Frame::BulkString(Bytes::from_static(b"2")), + Frame::BulkString(Bytes::from_static(b"1")), + ]) + ); + } + + #[test] + fn test_sort_limit() { + let mut db = Database::new(); + setup_list(&mut db, b"mylist", &[b"3", b"1", b"2", b"4"]); + let result = sort(&mut db, &[bs(b"mylist"), bs(b"LIMIT"), bs(b"1"), bs(b"2")]); + assert_eq!( + result, + Frame::Array(framevec![ + Frame::BulkString(Bytes::from_static(b"2")), + Frame::BulkString(Bytes::from_static(b"3")), + ]) + ); + } + + #[test] + fn test_sort_store() { + let mut db = Database::new(); + setup_list(&mut db, b"mylist", &[b"3", b"1", b"2"]); + let result = sort(&mut db, &[bs(b"mylist"), bs(b"STORE"), bs(b"sorted")]); + assert_eq!(result, Frame::Integer(3)); + assert!(db.exists(b"sorted")); + } + + #[test] + fn test_sort_nonexistent() { + let mut db = Database::new(); + let result = sort(&mut db, &[bs(b"nokey")]); + assert_eq!(result, Frame::Array(framevec![])); + } + + #[test] + fn test_sort_set() { + let mut db = Database::new(); + let mut s = std::collections::HashSet::new(); + s.insert(Bytes::from_static(b"3")); + s.insert(Bytes::from_static(b"1")); + s.insert(Bytes::from_static(b"2")); + let mut entry = Entry::new_set(); + entry.value = crate::storage::compact_value::CompactValue::from_redis_value( + crate::storage::entry::RedisValue::Set(s), + ); + db.set(Bytes::from_static(b"myset"), entry); + let result = sort(&mut db, &[bs(b"myset")]); + assert_eq!( + result, + Frame::Array(framevec![ + Frame::BulkString(Bytes::from_static(b"1")), + Frame::BulkString(Bytes::from_static(b"2")), + Frame::BulkString(Bytes::from_static(b"3")), + ]) + ); + } +} diff --git a/src/command/mod.rs b/src/command/mod.rs index 493044fa..74d3dd07 100644 --- a/src/command/mod.rs +++ b/src/command/mod.rs @@ -8,6 +8,7 @@ pub mod hash; pub mod helpers; pub mod hll; pub mod key; +pub mod key_extra; pub mod list; pub mod metadata; pub mod persistence; @@ -98,7 +99,7 @@ fn dispatch_inner( (4, b'c') => { // COPY if cmd.eq_ignore_ascii_case(b"COPY") { - return resp(key::copy(db, args)); + return resp(key_extra::copy(db, args)); } } (4, b'd') => { @@ -206,7 +207,7 @@ fn dispatch_inner( return resp(set::spop(db, args)); } if cmd.eq_ignore_ascii_case(b"SORT") { - return resp(key::sort(db, args)); + return resp(key_extra::sort(db, args)); } } (4, b't') => { @@ -446,13 +447,13 @@ fn dispatch_inner( if let Some(sub) = args.first() { if let Some(sub_bytes) = crate::command::helpers::extract_bytes(sub) { if sub_bytes.eq_ignore_ascii_case(b"USAGE") { - return resp(key::memory_usage(db, &args[1..])); + return resp(key_extra::memory_usage(db, &args[1..])); } if sub_bytes.eq_ignore_ascii_case(b"DOCTOR") { - return resp(key::memory_doctor()); + return resp(key_extra::memory_doctor()); } if sub_bytes.eq_ignore_ascii_case(b"HELP") { - return resp(key::memory_help()); + return resp(key_extra::memory_help()); } } } From 50db06f8bae865b26b828392a61124402c994217 Mon Sep 17 00:00:00 2001 From: Tin Dang Date: Sat, 11 Apr 2026 00:36:52 +0700 Subject: [PATCH 12/20] docs: add CHANGELOG entry for high-impact Redis command parity --- CHANGELOG.md | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5fbbdde8..87cfdd44 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,19 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +### Added — High-Impact Redis Command Parity (2026-04-10) + +- **COPY command** — atomic key duplication with DESTINATION, REPLACE options (Redis 6.2+). +- **Bit operations** — GETBIT, SETBIT, BITCOUNT (byte/bit range modes), BITOP (AND/OR/XOR/NOT), BITPOS (byte/bit range modes) with read-only dispatch variants. +- **SORT command** — full BY/GET/LIMIT/ALPHA/ASC/DESC/STORE support for lists, sets, and sorted sets. +- **Geospatial commands** — GEOADD (NX/XX/CH), GEOPOS, GEODIST (M/KM/FT/MI), GEOHASH (11-char base32), GEOSEARCH (FROMLONLAT/FROMMEMBER, BYRADIUS/BYBOX, WITHCOORD/WITHDIST/WITHHASH), GEOSEARCHSTORE. +- **CONFIG REWRITE** — atomic write of runtime config to `/moon.conf` (tmpfile + rename). CONFIG RESETSTAT stub. +- **CLIENT PAUSE/UNPAUSE** — delays command processing with WRITE-only mode support. CLIENT INFO, CLIENT LIST (stub), CLIENT NO-EVICT/NO-TOUCH accepted. +- **MEMORY USAGE/DOCTOR/HELP** — key memory estimation via `estimate_memory()`. +- **Lazyfree threshold** — configurable via `CONFIG SET lazyfree-threshold N` (default 64). +- **GETBIT/SETBIT metadata** — added to PHF command registry. +- **GEOADD/GEOSEARCHSTORE** — added to AOF write commands test list. + ### Fixed — Wave 0-4 Gap Closure (2026-04-09) - **ZREVRANGEBYSCORE/ZREVRANGEBYLEX correctness bug:** Fixed double-swap of min/max bounds in `zrange_by_score` and `zrange_by_lex` that caused empty results for finite score ranges (e.g., `ZREVRANGEBYSCORE key 3 1`). Added finite-range test to `test-commands.sh`. From dfc4849c5175d912fe01e2fcd9892f5d57be9f65 Mon Sep 17 00:00:00 2001 From: Tin Dang Date: Sat, 11 Apr 2026 01:14:23 +0700 Subject: [PATCH 13/20] fix: address all 16 review findings from PR #68 Critical bugs: - GEOSEARCHSTORE now writes matches to destination sorted set - CLIENT PAUSE uses 50ms polling loop so UNPAUSE takes effect immediately - normalize_index split into normalize_start/normalize_end; past-end start no longer clamped to len-1 (produces correct empty ranges) Correctness: - COPY returns 0 (not ERR) for missing source and same-key (Redis compat) - COPY REPLACE no longer double-increments i (accepts valid trailing args) - SORT numeric mode returns ERR for non-numeric elements without ALPHA - GEOSEARCH rejects unknown options with ERR instead of silently ignoring - GEOSEARCH enforces BYRADIUS/BYBOX mutual exclusivity - WITHDIST reports distance in query unit (km/mi/ft), not always meters - CLIENT PAUSE validates mode arg (WRITE|ALL only), rejects trailing garbage - CLIENT PAUSE deadline uses saturating_add to prevent overflow Config: - CONFIG REWRITE now persists lazyfree-threshold - Removed misleading lazyfree-lazy-eviction alias from CONFIG SET Performance: - geopos() no longer clones entire members HashMap; collects scores via short borrow then formats outside the borrow scope --- src/command/config.rs | 6 +- src/command/geo/geo_cmd.rs | 228 +++++++++++++++++++---------- src/command/key_extra.rs | 37 +++-- src/command/string/string_bit.rs | 47 +++--- src/server/conn/handler_sharded.rs | 73 ++++++--- 5 files changed, 266 insertions(+), 125 deletions(-) diff --git a/src/command/config.rs b/src/command/config.rs index 39267081..228f021e 100644 --- a/src/command/config.rs +++ b/src/command/config.rs @@ -175,7 +175,7 @@ pub fn config_set(runtime_config: &mut RuntimeConfig, args: &[Frame]) -> Frame { ))); } }, - "lazyfree-lazy-eviction" | "lazyfree-threshold" => match value_str.parse::() { + "lazyfree-threshold" => match value_str.parse::() { Ok(v) => runtime_config.lazyfree_threshold = v, Err(_) => { return Frame::Error(Bytes::from(format!( @@ -247,6 +247,10 @@ pub fn config_rewrite(runtime_config: &RuntimeConfig, server_config: &ServerConf if let Some(ref aclfile) = runtime_config.aclfile { lines.push(format!("aclfile {}", aclfile)); } + lines.push(format!( + "lazyfree-threshold {}", + runtime_config.lazyfree_threshold + )); let content = lines.join("\n") + "\n"; diff --git a/src/command/geo/geo_cmd.rs b/src/command/geo/geo_cmd.rs index 598fa0ad..cd5bf75b 100644 --- a/src/command/geo/geo_cmd.rs +++ b/src/command/geo/geo_cmd.rs @@ -131,39 +131,38 @@ pub fn geopos(db: &mut Database, args: &[Frame]) -> Frame { None => return err_wrong_args("GEOPOS"), }; - let members_map = match db.get_sorted_set(key) { - Ok(Some((members, _))) => Some(members.clone()), - Ok(None) => None, - Err(e) => return e, + // Collect scores first to avoid holding borrow across format! allocations + let scores: Vec> = { + let members_map = match db.get_sorted_set(key) { + Ok(Some((members, _))) => Some(members), + Ok(None) => None, + Err(e) => return e, + }; + args[1..] + .iter() + .map(|arg| { + let member = extract_bytes(arg)?; + members_map.as_ref()?.get(member).copied() + }) + .collect() }; - let mut results = Vec::with_capacity(args.len() - 1); - for arg in &args[1..] { - let member = match extract_bytes(arg) { - Some(m) => m, - None => { - results.push(Frame::Null); - continue; + let results: Vec = scores + .into_iter() + .map(|opt_score| match opt_score { + Some(score) => { + let (lon, lat) = geohash_decode(score); + Frame::Array( + vec![ + Frame::BulkString(Bytes::from(format!("{:.4}", lon))), + Frame::BulkString(Bytes::from(format!("{:.4}", lat))), + ] + .into(), + ) } - }; - - match &members_map { - Some(m) => match m.get(member) { - Some(&score) => { - let (lon, lat) = geohash_decode(score); - results.push(Frame::Array( - vec![ - Frame::BulkString(Bytes::from(format!("{:.4}", lon))), - Frame::BulkString(Bytes::from(format!("{:.4}", lat))), - ] - .into(), - )); - } - None => results.push(Frame::Null), - }, - None => results.push(Frame::Null), - } - } + None => Frame::Null, + }) + .collect(); Frame::Array(results.into()) } @@ -269,7 +268,7 @@ pub fn geohash(db: &mut Database, args: &[Frame]) -> Frame { /// BYRADIUS radius M|KM|FT|MI|BYBOX width height M|KM|FT|MI /// [ASC|DESC] [COUNT count [ANY]] [WITHCOORD] [WITHDIST] [WITHHASH] pub fn geosearch(db: &mut Database, args: &[Frame]) -> Frame { - let (_, results) = geosearch_inner(db, args, false); + let (_matches, results) = geosearch_inner(db, args, false); results } @@ -284,24 +283,42 @@ pub fn geosearchstore(db: &mut Database, args: &[Frame]) -> Frame { }; // Shift args so args[0] is now the source key - let (count, _) = geosearch_inner(db, &args[1..], true); + let (matches, _) = geosearch_inner(db, &args[1..], true); - // Build sorted set from results and store - if count == 0 { + if matches.is_empty() { db.remove(&dest); return Frame::Integer(0); } - Frame::Integer(count as i64) + // Build a fresh sorted set from matches and store at dest + let mut new_members = std::collections::HashMap::new(); + let mut new_tree = crate::storage::bptree::BPTree::new(); + for (member, _dist, _lon, _lat, score) in &matches { + new_members.insert(member.clone(), *score); + new_tree.insert(OrderedFloat(*score), member.clone()); + } + let mut entry = crate::storage::entry::Entry::new_sorted_set_bptree(); + entry.value = crate::storage::compact_value::CompactValue::from_redis_value( + crate::storage::entry::RedisValue::SortedSetBPTree { + tree: new_tree, + members: new_members, + }, + ); + db.set(dest, entry); + + Frame::Integer(matches.len() as i64) } -fn geosearch_inner(db: &mut Database, args: &[Frame], _store_mode: bool) -> (usize, Frame) { +/// Returned by geosearch_inner: (member, dist_m, lon, lat, score) +type GeoMatch = (Bytes, f64, f64, f64, f64); + +fn geosearch_inner(db: &mut Database, args: &[Frame], _store_mode: bool) -> (Vec, Frame) { if args.len() < 6 { - return (0, err_wrong_args("GEOSEARCH")); + return (Vec::new(), err_wrong_args("GEOSEARCH")); } let key = match extract_bytes(&args[0]) { Some(k) => k, - None => return (0, err_wrong_args("GEOSEARCH")), + None => return (Vec::new(), err_wrong_args("GEOSEARCH")), }; // Parse source: FROMMEMBER or FROMLONLAT @@ -322,13 +339,18 @@ fn geosearch_inner(db: &mut Database, args: &[Frame], _store_mode: bool) -> (usi i += 1; let member = match extract_bytes(args.get(i).unwrap_or(&Frame::Null)) { Some(m) => m, - None => return (0, Frame::Error(Bytes::from_static(b"ERR syntax error"))), + None => { + return ( + Vec::new(), + Frame::Error(Bytes::from_static(b"ERR syntax error")), + ); + } }; // Look up member's score let members_map = match db.get_sorted_set(key) { Ok(Some((members, _))) => members.clone(), - Ok(None) => return (0, Frame::Array(Vec::new().into())), - Err(e) => return (0, e), + Ok(None) => return (Vec::new(), Frame::Array(Vec::new().into())), + Err(e) => return (Vec::new(), e), }; match members_map.get(member) { Some(&score) => { @@ -336,19 +358,29 @@ fn geosearch_inner(db: &mut Database, args: &[Frame], _store_mode: bool) -> (usi center_lon = lon; center_lat = lat; } - None => return (0, Frame::Array(Vec::new().into())), + None => return (Vec::new(), Frame::Array(Vec::new().into())), } found_from = true; } else if arg.eq_ignore_ascii_case(b"FROMLONLAT") { i += 1; center_lon = match args.get(i).and_then(|f| parse_f64(f)) { Some(v) => v, - None => return (0, Frame::Error(Bytes::from_static(b"ERR syntax error"))), + None => { + return ( + Vec::new(), + Frame::Error(Bytes::from_static(b"ERR syntax error")), + ); + } }; i += 1; center_lat = match args.get(i).and_then(|f| parse_f64(f)) { Some(v) => v, - None => return (0, Frame::Error(Bytes::from_static(b"ERR syntax error"))), + None => { + return ( + Vec::new(), + Frame::Error(Bytes::from_static(b"ERR syntax error")), + ); + } }; found_from = true; } @@ -356,7 +388,10 @@ fn geosearch_inner(db: &mut Database, args: &[Frame], _store_mode: bool) -> (usi } if !found_from { - return (0, Frame::Error(Bytes::from_static(b"ERR syntax error"))); + return ( + Vec::new(), + Frame::Error(Bytes::from_static(b"ERR syntax error")), + ); } // Parse shape: BYRADIUS or BYBOX @@ -368,6 +403,16 @@ fn geosearch_inner(db: &mut Database, args: &[Frame], _store_mode: bool) -> (usi let mut withcoord = false; let mut withdist = false; let mut withhash = false; + let mut output_unit_mult = 1.0f64; // for WITHDIST: convert meters → query unit + + let unit_err = || { + ( + Vec::new(), + Frame::Error(Bytes::from_static( + b"ERR unsupported unit provided. please use M, KM, FT, MI", + )), + ) + }; while i < args.len() { let arg = match extract_bytes(&args[i]) { @@ -378,10 +423,23 @@ fn geosearch_inner(db: &mut Database, args: &[Frame], _store_mode: bool) -> (usi } }; if arg.eq_ignore_ascii_case(b"BYRADIUS") { + if box_width_m.is_some() { + return ( + Vec::new(), + Frame::Error(Bytes::from_static( + b"ERR exactly one of BYRADIUS and BYBOX arguments must be provided", + )), + ); + } i += 1; let r = match args.get(i).and_then(|f| parse_f64(f)) { Some(v) => v, - None => return (0, Frame::Error(Bytes::from_static(b"ERR syntax error"))), + None => { + return ( + Vec::new(), + Frame::Error(Bytes::from_static(b"ERR syntax error")), + ); + } }; i += 1; let unit_mult = match args @@ -390,26 +448,38 @@ fn geosearch_inner(db: &mut Database, args: &[Frame], _store_mode: bool) -> (usi .and_then(|b| parse_unit(b)) { Some(v) => v, - None => { - return ( - 0, - Frame::Error(Bytes::from_static( - b"ERR unsupported unit provided. please use M, KM, FT, MI", - )), - ); - } + None => return unit_err(), }; + output_unit_mult = unit_mult; radius_m = Some(r * unit_mult); } else if arg.eq_ignore_ascii_case(b"BYBOX") { + if radius_m.is_some() { + return ( + Vec::new(), + Frame::Error(Bytes::from_static( + b"ERR exactly one of BYRADIUS and BYBOX arguments must be provided", + )), + ); + } i += 1; let w = match args.get(i).and_then(|f| parse_f64(f)) { Some(v) => v, - None => return (0, Frame::Error(Bytes::from_static(b"ERR syntax error"))), + None => { + return ( + Vec::new(), + Frame::Error(Bytes::from_static(b"ERR syntax error")), + ); + } }; i += 1; let h = match args.get(i).and_then(|f| parse_f64(f)) { Some(v) => v, - None => return (0, Frame::Error(Bytes::from_static(b"ERR syntax error"))), + None => { + return ( + Vec::new(), + Frame::Error(Bytes::from_static(b"ERR syntax error")), + ); + } }; i += 1; let unit_mult = match args @@ -418,15 +488,9 @@ fn geosearch_inner(db: &mut Database, args: &[Frame], _store_mode: bool) -> (usi .and_then(|b| parse_unit(b)) { Some(v) => v, - None => { - return ( - 0, - Frame::Error(Bytes::from_static( - b"ERR unsupported unit provided. please use M, KM, FT, MI", - )), - ); - } + None => return unit_err(), }; + output_unit_mult = unit_mult; box_width_m = Some(w * unit_mult); box_height_m = Some(h * unit_mult); } else if arg.eq_ignore_ascii_case(b"ASC") { @@ -437,7 +501,12 @@ fn geosearch_inner(db: &mut Database, args: &[Frame], _store_mode: bool) -> (usi i += 1; let c = match args.get(i).and_then(|f| parse_f64(f)) { Some(v) if v > 0.0 => v as usize, - _ => return (0, Frame::Error(Bytes::from_static(b"ERR syntax error"))), + _ => { + return ( + Vec::new(), + Frame::Error(Bytes::from_static(b"ERR syntax error")), + ); + } }; count_limit = Some(c); // Skip optional ANY @@ -454,13 +523,18 @@ fn geosearch_inner(db: &mut Database, args: &[Frame], _store_mode: bool) -> (usi withdist = true; } else if arg.eq_ignore_ascii_case(b"WITHHASH") { withhash = true; + } else { + return ( + Vec::new(), + Frame::Error(Bytes::from_static(b"ERR syntax error")), + ); } i += 1; } if radius_m.is_none() && box_width_m.is_none() { return ( - 0, + Vec::new(), Frame::Error(Bytes::from_static( b"ERR exactly one of BYRADIUS and BYBOX arguments must be provided", )), @@ -470,8 +544,8 @@ fn geosearch_inner(db: &mut Database, args: &[Frame], _store_mode: bool) -> (usi // Get all members with their coordinates let members_map = match db.get_sorted_set(key) { Ok(Some((members, _))) => members.clone(), - Ok(None) => return (0, Frame::Array(Vec::new().into())), - Err(e) => return (0, e), + Ok(None) => return (Vec::new(), Frame::Array(Vec::new().into())), + Err(e) => return (Vec::new(), e), }; // Filter by shape @@ -505,19 +579,23 @@ fn geosearch_inner(db: &mut Database, args: &[Frame], _store_mode: bool) -> (usi matches.truncate(limit); } - let total = matches.len(); let has_extras = withcoord || withdist || withhash; let results: Vec = matches - .into_iter() + .iter() .map(|(member, dist, lon, lat, score)| { if has_extras { - let mut entry = vec![Frame::BulkString(member)]; + let mut entry = vec![Frame::BulkString(member.clone())]; if withdist { - entry.push(Frame::BulkString(Bytes::from(format!("{:.4}", dist)))); + // Convert meters to the same unit used in BYRADIUS/BYBOX query + let dist_in_unit = dist / output_unit_mult; + entry.push(Frame::BulkString(Bytes::from(format!( + "{:.4}", + dist_in_unit + )))); } if withhash { - entry.push(Frame::Integer(score as i64)); + entry.push(Frame::Integer(*score as i64)); } if withcoord { entry.push(Frame::Array( @@ -530,12 +608,12 @@ fn geosearch_inner(db: &mut Database, args: &[Frame], _store_mode: bool) -> (usi } Frame::Array(entry.into()) } else { - Frame::BulkString(member) + Frame::BulkString(member.clone()) } }) .collect(); - (total, Frame::Array(results.into())) + (matches, Frame::Array(results.into())) } #[cfg(test)] diff --git a/src/command/key_extra.rs b/src/command/key_extra.rs index 1aacf8c5..8328fc72 100644 --- a/src/command/key_extra.rs +++ b/src/command/key_extra.rs @@ -38,7 +38,6 @@ pub fn copy(db: &mut Database, args: &[Frame]) -> Frame { }; if arg.eq_ignore_ascii_case(b"REPLACE") { replace = true; - i += 1; } else if arg.eq_ignore_ascii_case(b"DB") { // Cross-DB copy requires shard_databases context not available here return Frame::Error(Bytes::from_static( @@ -47,16 +46,17 @@ pub fn copy(db: &mut Database, args: &[Frame]) -> Frame { } else { return Frame::Error(Bytes::from_static(b"ERR syntax error")); } + i += 1; } - // Check if source exists (with lazy expiry) + // Redis returns 0 (not error) when source doesn't exist if !db.exists(src) { - return Frame::Error(Bytes::from_static(b"ERR no such key")); + return Frame::Integer(0); } - // Same key: no data to copy, but it's valid + // Same key: source == dest, nothing to do if src == dst { - return Frame::Integer(1); + return Frame::Integer(0); } // Check if destination exists @@ -247,11 +247,26 @@ pub fn sort(db: &mut Database, args: &[Frame]) -> Frame { elements.iter().map(|e| Some(e.clone())).collect() }; + // For numeric sort (no ALPHA), validate all sort keys are parseable as f64 + let no_sort = by_pattern.is_some_and(|p| p == b"nosort"); + if !alpha && !no_sort { + for sk in &sort_keys { + if let Some(v) = sk { + if std::str::from_utf8(v) + .ok() + .and_then(|s| s.parse::().ok()) + .is_none() + { + return Frame::Error(Bytes::from_static( + b"ERR One or more scores can't be converted into double", + )); + } + } + } + } + // Create indexed pairs for stable sort let mut indices: Vec = (0..elements.len()).collect(); - - // Sort (skip if BY nosort) - let no_sort = by_pattern.is_some_and(|p| p == b"nosort"); if !no_sort { indices.sort_by(|&a, &b| { let ka = sort_keys[a].as_ref(); @@ -401,14 +416,16 @@ mod tests { fn test_copy_nonexistent_source() { let mut db = Database::new(); let result = copy(&mut db, &[bs(b"nosuchkey"), bs(b"dst")]); - assert!(matches!(result, Frame::Error(_))); + // Redis returns 0 for missing source, not error + assert_eq!(result, Frame::Integer(0)); } #[test] fn test_copy_same_key() { let mut db = setup_db_with_key(b"src", b"hello"); let result = copy(&mut db, &[bs(b"src"), bs(b"src")]); - assert_eq!(result, Frame::Integer(1)); + // Redis returns 0 for same-key copy + assert_eq!(result, Frame::Integer(0)); } #[test] diff --git a/src/command/string/string_bit.rs b/src/command/string/string_bit.rs index da374587..9dfc9993 100644 --- a/src/command/string/string_bit.rs +++ b/src/command/string/string_bit.rs @@ -232,8 +232,8 @@ pub fn bitcount(db: &mut Database, args: &[Frame]) -> Frame { if use_bit { // BIT mode: count bits in the bit range let total_bits = (data.len() * 8) as i64; - let s = normalize_index(start, total_bits); - let e = normalize_index(end, total_bits); + let s = normalize_start(start, total_bits); + let e = normalize_end(end, total_bits); if s > e { return Frame::Integer(0); } @@ -242,8 +242,8 @@ pub fn bitcount(db: &mut Database, args: &[Frame]) -> Frame { } else { // BYTE mode (default): count bits in the byte range let len = data.len() as i64; - let s = normalize_index(start, len); - let e = normalize_index(end, len); + let s = normalize_start(start, len); + let e = normalize_end(end, len); if s > e { return Frame::Integer(0); } @@ -318,8 +318,8 @@ pub fn bitcount_readonly(db: &Database, args: &[Frame], now_ms: u64) -> Frame { if use_bit { let total_bits = (data.len() * 8) as i64; - let s = normalize_index(start, total_bits); - let e = normalize_index(end, total_bits); + let s = normalize_start(start, total_bits); + let e = normalize_end(end, total_bits); if s > e { return Frame::Integer(0); } @@ -327,8 +327,8 @@ pub fn bitcount_readonly(db: &Database, args: &[Frame], now_ms: u64) -> Frame { Frame::Integer(count as i64) } else { let len = data.len() as i64; - let s = normalize_index(start, len); - let e = normalize_index(end, len); + let s = normalize_start(start, len); + let e = normalize_end(end, len); if s > e { return Frame::Integer(0); } @@ -534,8 +534,8 @@ pub fn bitpos(db: &mut Database, args: &[Frame]) -> Frame { if use_bit { let total_bits = (data.len() * 8) as i64; - let s = normalize_index(start, total_bits) as usize; - let e = normalize_index(end, total_bits) as usize; + let s = normalize_start(start, total_bits) as usize; + let e = normalize_end(end, total_bits) as usize; if s > e { return Frame::Integer(-1); } @@ -552,8 +552,8 @@ pub fn bitpos(db: &mut Database, args: &[Frame]) -> Frame { Frame::Integer(-1) } else { let len = data.len() as i64; - let s = normalize_index(start, len) as usize; - let e = normalize_index(end, len) as usize; + let s = normalize_start(start, len) as usize; + let e = normalize_end(end, len) as usize; if s > e { return Frame::Integer(-1); } @@ -672,8 +672,8 @@ pub fn bitpos_readonly(db: &Database, args: &[Frame], now_ms: u64) -> Frame { if use_bit { let total_bits = (data.len() * 8) as i64; - let s = normalize_index(start, total_bits) as usize; - let e = normalize_index(end, total_bits) as usize; + let s = normalize_start(start, total_bits) as usize; + let e = normalize_end(end, total_bits) as usize; if s > e { return Frame::Integer(-1); } @@ -690,8 +690,8 @@ pub fn bitpos_readonly(db: &Database, args: &[Frame], now_ms: u64) -> Frame { Frame::Integer(-1) } else { let len = data.len() as i64; - let s = normalize_index(start, len) as usize; - let e = normalize_index(end, len) as usize; + let s = normalize_start(start, len) as usize; + let e = normalize_end(end, len) as usize; if s > e { return Frame::Integer(-1); } @@ -712,8 +712,19 @@ pub fn bitpos_readonly(db: &Database, args: &[Frame], now_ms: u64) -> Frame { } } -/// Normalize a Redis index (negative = from end) to a 0-based clamped index. -fn normalize_index(idx: i64, len: i64) -> i64 { +/// Normalize a start index: negative wraps from end, positive stays unclamped +/// so callers detect empty ranges via `start > end`. +fn normalize_start(idx: i64, len: i64) -> i64 { + if len == 0 { + return 0; + } + let normalized = if idx < 0 { len + idx } else { idx }; + normalized.max(0) +} + +/// Normalize an end index: negative wraps from end, clamped to `len - 1` +/// to prevent out-of-bounds slicing. +fn normalize_end(idx: i64, len: i64) -> i64 { if len == 0 { return 0; } diff --git a/src/server/conn/handler_sharded.rs b/src/server/conn/handler_sharded.rs index 2355f699..148a0dfd 100644 --- a/src/server/conn/handler_sharded.rs +++ b/src/server/conn/handler_sharded.rs @@ -699,13 +699,28 @@ pub(crate) async fn handle_connection_sharded_inner< } else { 0 } }; if pause_wait_ms > 0 { - #[cfg(feature = "runtime-tokio")] - { - tokio::time::sleep(std::time::Duration::from_millis(pause_wait_ms)).await; - } - #[cfg(feature = "runtime-monoio")] - { - monoio::time::sleep(std::time::Duration::from_millis(pause_wait_ms)).await; + // Poll in 50ms intervals so CLIENT UNPAUSE takes effect quickly + let mut remaining = pause_wait_ms; + while remaining > 0 { + let chunk = remaining.min(50); + #[cfg(feature = "runtime-tokio")] + { + tokio::time::sleep(std::time::Duration::from_millis(chunk)).await; + } + #[cfg(feature = "runtime-monoio")] + { + monoio::time::sleep(std::time::Duration::from_millis(chunk)).await; + } + remaining = remaining.saturating_sub(chunk); + // Re-check if UNPAUSE was called + let still_paused = { + let rt = ctx.runtime_config.read(); + rt.client_pause_deadline_ms > 0 + && crate::storage::entry::current_time_ms() < rt.client_pause_deadline_ms + }; + if !still_paused { + break; + } } } @@ -944,21 +959,37 @@ pub(crate) async fn handle_connection_sharded_inner< Some(b) => std::str::from_utf8(&b).ok().and_then(|s| s.parse::().ok()), None => None, }; - match timeout_ms { - Some(ms) => { - let write_only = cmd_args.get(2) - .and_then(|f| extract_bytes(f)) - .is_some_and(|b| b.eq_ignore_ascii_case(b"WRITE")); - let deadline = crate::storage::entry::current_time_ms() + ms; - let mut rt = ctx.runtime_config.write(); - rt.client_pause_deadline_ms = deadline; - rt.client_pause_write_only = write_only; - responses.push(Frame::SimpleString(Bytes::from_static(b"OK"))); + // Validate mode arg if present + let mode_valid = if cmd_args.len() >= 3 { + match extract_bytes(&cmd_args[2]) { + Some(b) => b.eq_ignore_ascii_case(b"WRITE") || b.eq_ignore_ascii_case(b"ALL"), + None => false, } - None => { - responses.push(Frame::Error(Bytes::from_static( - b"ERR timeout is not a valid integer or out of range", - ))); + } else { + true + }; + // Reject extra trailing args + if cmd_args.len() > 3 || !mode_valid { + responses.push(Frame::Error(Bytes::from_static( + b"ERR syntax error", + ))); + } else { + match timeout_ms { + Some(ms) => { + let write_only = cmd_args.get(2) + .and_then(|f| extract_bytes(f)) + .is_some_and(|b| b.eq_ignore_ascii_case(b"WRITE")); + let deadline = crate::storage::entry::current_time_ms().saturating_add(ms); + let mut rt = ctx.runtime_config.write(); + rt.client_pause_deadline_ms = deadline; + rt.client_pause_write_only = write_only; + responses.push(Frame::SimpleString(Bytes::from_static(b"OK"))); + } + None => { + responses.push(Frame::Error(Bytes::from_static( + b"ERR timeout is not a valid integer or out of range", + ))); + } } } } From 05c782a3a3032b77bb2b9004cd704641cd88b512 Mon Sep 17 00:00:00 2001 From: Tin Dang Date: Sat, 11 Apr 2026 08:58:04 +0700 Subject: [PATCH 14/20] feat: implement Tier 1 gap commands (EXPIREAT, FLUSHDB, TIME, RANDOMKEY, TOUCH) Closes 7 command gaps identified in Redis parity audit: - EXPIREAT/PEXPIREAT: absolute Unix timestamp expiry (seconds/ms) - EXPIRETIME/PEXPIRETIME: read back absolute expiry timestamp - FLUSHDB/FLUSHALL: clear all keys in current database - TIME: return server clock as [seconds, microseconds] - RANDOMKEY: return a random key from the database - TOUCH: refresh LRU/LFU access time, return count of existing keys - SHUTDOWN: dispatch entry (returns ERR, actual shutdown via signal) Also adds Database::random_key() helper. 11 unit tests, entries in test-commands.sh and test-consistency.sh. --- scripts/test-commands.sh | 15 ++ scripts/test-consistency.sh | 16 +++ src/command/key.rs | 267 ++++++++++++++++++++++++++++++++++++ src/command/mod.rs | 61 +++++++- src/storage/db.rs | 13 ++ 5 files changed, 370 insertions(+), 2 deletions(-) diff --git a/scripts/test-commands.sh b/scripts/test-commands.sh index cfe30def..695ae078 100755 --- a/scripts/test-commands.sh +++ b/scripts/test-commands.sh @@ -590,6 +590,21 @@ if should_run "key"; then assert_match "GEODIST km" GEODIST k:geo Palermo Catania km assert_match "GEOHASH" GEOHASH k:geo Palermo assert_match "GEOSEARCH" GEOSEARCH k:geo FROMLONLAT 15 37 BYRADIUS 200 km ASC + # EXPIREAT / PEXPIREAT / EXPIRETIME / PEXPIRETIME + rcli SET k:eat val >/dev/null 2>&1; mcli SET k:eat val >/dev/null 2>&1 + assert_match "EXPIREAT" EXPIREAT k:eat 9999999999 + assert_match "TTL after EXPIREAT" TTL k:eat + assert_match "EXPIRETIME" EXPIRETIME k:eat + assert_match "PEXPIRETIME" PEXPIRETIME k:eat + + # TIME / RANDOMKEY / TOUCH + assert_moon_ok "TIME" TIME + rcli SET k:rnd val >/dev/null 2>&1; mcli SET k:rnd val >/dev/null 2>&1 + assert_moon_ok "RANDOMKEY" RANDOMKEY + assert_match "TOUCH" TOUCH k:rnd + + # FLUSHDB + assert_match "FLUSHDB" FLUSHDB fi # =========================================================================== diff --git a/scripts/test-consistency.sh b/scripts/test-consistency.sh index 71b1f7cb..61206888 100755 --- a/scripts/test-consistency.sh +++ b/scripts/test-consistency.sh @@ -527,6 +527,22 @@ assert_both "GEODIST km" GEODIST edge:geo Palermo Catania km assert_both "GEOHASH" GEOHASH edge:geo Palermo assert_both "GEOADD count" GEOADD edge:geo 2.349014 48.864716 Paris +# EXPIREAT / PEXPIREAT / EXPIRETIME / PEXPIRETIME +both SET edge:eat "val" +assert_both "EXPIREAT" EXPIREAT edge:eat 9999999999 +assert_both "EXPIRETIME" EXPIRETIME edge:eat +assert_both "PEXPIRETIME" PEXPIRETIME edge:eat +assert_both "EXPIRETIME missing" EXPIRETIME edge:nokey +assert_both "PEXPIRETIME missing" PEXPIRETIME edge:nokey + +# TOUCH +both SET edge:touch "val" +assert_both "TOUCH" TOUCH edge:touch +assert_both "TOUCH missing" TOUCH edge:nomiss + +# FLUSHDB (run last — clears all keys) +assert_both "FLUSHDB" FLUSHDB + # =========================================================================== # Summary # =========================================================================== diff --git a/src/command/key.rs b/src/command/key.rs index 480d21cb..37d24879 100644 --- a/src/command/key.rs +++ b/src/command/key.rs @@ -217,6 +217,166 @@ pub fn persist(db: &mut Database, args: &[Frame]) -> Frame { } } +/// EXPIREAT key unix-time-seconds +/// +/// Set the expiration for a key as a UNIX timestamp (seconds). +pub fn expireat(db: &mut Database, args: &[Frame]) -> Frame { + if args.len() != 2 { + return err_wrong_args("EXPIREAT"); + } + let key = match extract_key(&args[0]) { + Some(k) => k, + None => return err_wrong_args("EXPIREAT"), + }; + let timestamp = match parse_int(&args[1]) { + Some(n) if n > 0 => n as u64, + _ => { + return Frame::Error(Bytes::from_static( + b"ERR invalid expire time in 'EXPIREAT' command", + )); + } + }; + let expires_at_ms = timestamp * 1000; + if db.set_expiry(key, expires_at_ms) { + Frame::Integer(1) + } else { + Frame::Integer(0) + } +} + +/// PEXPIREAT key unix-time-milliseconds +/// +/// Set the expiration for a key as a UNIX timestamp in milliseconds. +pub fn pexpireat(db: &mut Database, args: &[Frame]) -> Frame { + if args.len() != 2 { + return err_wrong_args("PEXPIREAT"); + } + let key = match extract_key(&args[0]) { + Some(k) => k, + None => return err_wrong_args("PEXPIREAT"), + }; + let timestamp_ms = match parse_int(&args[1]) { + Some(n) if n > 0 => n as u64, + _ => { + return Frame::Error(Bytes::from_static( + b"ERR invalid expire time in 'PEXPIREAT' command", + )); + } + }; + if db.set_expiry(key, timestamp_ms) { + Frame::Integer(1) + } else { + Frame::Integer(0) + } +} + +/// EXPIRETIME key +/// +/// Returns the absolute Unix timestamp (in seconds) at which the key will expire. +/// Returns -2 if key doesn't exist, -1 if key has no expiry. +pub fn expiretime(db: &mut Database, args: &[Frame]) -> Frame { + if args.len() != 1 { + return err_wrong_args("EXPIRETIME"); + } + let key = match extract_key(&args[0]) { + Some(k) => k, + None => return err_wrong_args("EXPIRETIME"), + }; + let base_ts = db.base_timestamp(); + match db.get(key) { + None => Frame::Integer(-2), + Some(entry) => { + if !entry.has_expiry() { + Frame::Integer(-1) + } else { + Frame::Integer((entry.expires_at_ms(base_ts) / 1000) as i64) + } + } + } +} + +/// PEXPIRETIME key +/// +/// Returns the absolute Unix timestamp (in milliseconds) at which the key will expire. +pub fn pexpiretime(db: &mut Database, args: &[Frame]) -> Frame { + if args.len() != 1 { + return err_wrong_args("PEXPIRETIME"); + } + let key = match extract_key(&args[0]) { + Some(k) => k, + None => return err_wrong_args("PEXPIRETIME"), + }; + let base_ts = db.base_timestamp(); + match db.get(key) { + None => Frame::Integer(-2), + Some(entry) => { + if !entry.has_expiry() { + Frame::Integer(-1) + } else { + Frame::Integer(entry.expires_at_ms(base_ts) as i64) + } + } + } +} + +/// RANDOMKEY +/// +/// Returns a random key from the currently selected database. +pub fn randomkey(db: &mut Database, _args: &[Frame]) -> Frame { + match db.random_key() { + Some(key) => Frame::BulkString(key), + None => Frame::Null, + } +} + +/// TOUCH key [key ...] +/// +/// Alters the last access time of a key(s). Returns the number of keys that exist. +pub fn touch(db: &mut Database, args: &[Frame]) -> Frame { + if args.is_empty() { + return err_wrong_args("TOUCH"); + } + let mut count = 0i64; + for arg in args { + let key = match extract_key(arg) { + Some(k) => k, + None => continue, + }; + if db.exists(key) { + // exists() already does lazy expiry + access tracking + count += 1; + } + } + Frame::Integer(count) +} + +/// TIME +/// +/// Returns the current server time as a two-element array: +/// [unix-seconds, microseconds-since-epoch-second]. +pub fn time() -> Frame { + let now = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default(); + let secs = now.as_secs(); + let micros = now.subsec_micros(); + Frame::Array( + vec![ + Frame::BulkString(Bytes::from(secs.to_string())), + Frame::BulkString(Bytes::from(micros.to_string())), + ] + .into(), + ) +} + +/// FLUSHDB [ASYNC|SYNC] +/// +/// Delete all keys in the currently selected database. +pub fn flushdb(db: &mut Database, _args: &[Frame]) -> Frame { + db.clear(); + Frame::SimpleString(Bytes::from_static(b"OK")) +} + /// TYPE key /// /// Returns the string representation of the type of the value stored at key. @@ -1469,4 +1629,111 @@ mod tests { _ => panic!("Expected array"), } } + + // --- EXPIREAT / PEXPIREAT / EXPIRETIME / PEXPIRETIME tests --- + + #[test] + fn test_expireat() { + let mut db = setup_db_with_key(b"k", b"v"); + let future_ts = (current_time_ms() / 1000 + 3600) as i64; + let result = expireat( + &mut db, + &[ + bs(b"k"), + Frame::BulkString(Bytes::from(future_ts.to_string())), + ], + ); + assert_eq!(result, Frame::Integer(1)); + assert!(db.get(b"k").unwrap().has_expiry()); + } + + #[test] + fn test_expireat_missing() { + let mut db = Database::new(); + let result = expireat(&mut db, &[bs(b"k"), bs(b"9999999999")]); + assert_eq!(result, Frame::Integer(0)); + } + + #[test] + fn test_pexpireat() { + let mut db = setup_db_with_key(b"k", b"v"); + let future_ms = (current_time_ms() + 3_600_000) as i64; + let result = pexpireat( + &mut db, + &[ + bs(b"k"), + Frame::BulkString(Bytes::from(future_ms.to_string())), + ], + ); + assert_eq!(result, Frame::Integer(1)); + } + + #[test] + fn test_expiretime() { + let mut db = setup_db_with_key(b"k", b"v"); + // No expiry → -1 + let result = expiretime(&mut db, &[bs(b"k")]); + assert_eq!(result, Frame::Integer(-1)); + // Missing → -2 + let result = expiretime(&mut db, &[bs(b"nope")]); + assert_eq!(result, Frame::Integer(-2)); + } + + #[test] + fn test_pexpiretime() { + let mut db = setup_db_with_key(b"k", b"v"); + let result = pexpiretime(&mut db, &[bs(b"k")]); + assert_eq!(result, Frame::Integer(-1)); + } + + // --- RANDOMKEY / TOUCH / TIME / FLUSHDB tests --- + + #[test] + fn test_randomkey_empty() { + let mut db = Database::new(); + assert_eq!(randomkey(&mut db, &[]), Frame::Null); + } + + #[test] + fn test_randomkey_nonempty() { + let mut db = setup_db_with_key(b"only", b"val"); + match randomkey(&mut db, &[]) { + Frame::BulkString(k) => assert_eq!(k.as_ref(), b"only"), + _ => panic!("Expected BulkString"), + } + } + + #[test] + fn test_touch() { + let mut db = setup_db_with_key(b"a", b"1"); + db.set( + Bytes::from_static(b"b"), + Entry::new_string(Bytes::from_static(b"2")), + ); + let result = touch(&mut db, &[bs(b"a"), bs(b"b"), bs(b"missing")]); + assert_eq!(result, Frame::Integer(2)); + } + + #[test] + fn test_time() { + match time() { + Frame::Array(ref arr) => { + assert_eq!(arr.len(), 2); + } + _ => panic!("Expected array"), + } + } + + #[test] + fn test_flushdb() { + let mut db = setup_db_with_key(b"a", b"1"); + db.set( + Bytes::from_static(b"b"), + Entry::new_string(Bytes::from_static(b"2")), + ); + assert_eq!(db.len(), 2); + let result = flushdb(&mut db, &[]); + assert_eq!(result, Frame::SimpleString(Bytes::from_static(b"OK"))); + assert_eq!(db.len(), 0); + } } diff --git a/src/command/mod.rs b/src/command/mod.rs index 74d3dd07..7bd9c139 100644 --- a/src/command/mod.rs +++ b/src/command/mod.rs @@ -211,10 +211,13 @@ fn dispatch_inner( } } (4, b't') => { - // TYPE + // TYPE TIME if cmd.eq_ignore_ascii_case(b"TYPE") { return resp(key::type_cmd(db, args)); } + if cmd.eq_ignore_ascii_case(b"TIME") { + return resp(key::time()); + } } (4, b'x') => { // XADD XLEN XDEL XACK @@ -368,6 +371,12 @@ fn dispatch_inner( return resp(sorted_set::zmpop(db, args)); } } + (5, b't') => { + // TOUCH + if cmd.eq_ignore_ascii_case(b"TOUCH") { + return resp(key::touch(db, args)); + } + } // 6-letter commands (6, b'b') => { // BITPOS @@ -556,6 +565,12 @@ fn dispatch_inner( } } // 7-letter commands + (7, b'f') => { + // FLUSHDB + if cmd.eq_ignore_ascii_case(b"FLUSHDB") { + return resp(key::flushdb(db, args)); + } + } (7, b'g') => { // GEODIST GEOHASH if cmd.eq_ignore_ascii_case(b"GEODIST") { @@ -632,6 +647,18 @@ fn dispatch_inner( } } // 8-letter commands + (8, b'e') => { + // EXPIREAT + if cmd.eq_ignore_ascii_case(b"EXPIREAT") { + return resp(key::expireat(db, args)); + } + } + (8, b'f') => { + // FLUSHALL — clears current DB (per-shard, not cross-shard) + if cmd.eq_ignore_ascii_case(b"FLUSHALL") { + return resp(key::flushdb(db, args)); + } + } (8, b'g') => { // GETRANGE if cmd.eq_ignore_ascii_case(b"GETRANGE") { @@ -651,7 +678,13 @@ fn dispatch_inner( } } (8, b's') => { - // SMEMBERS SETRANGE + // SMEMBERS SETRANGE SHUTDOWN + if cmd.eq_ignore_ascii_case(b"SHUTDOWN") { + // Acknowledge but don't kill — actual shutdown is handled by the server + return resp(Frame::Error(Bytes::from_static( + b"ERR Errors trying to SHUTDOWN. Check logs.", + ))); + } if cmd.eq_ignore_ascii_case(b"SMEMBERS") { return resp(set::smembers(db, args)); } @@ -678,6 +711,18 @@ fn dispatch_inner( return resp(geo::geosearch(db, args)); } } + (9, b'p') => { + // PEXPIREAT + if cmd.eq_ignore_ascii_case(b"PEXPIREAT") { + return resp(key::pexpireat(db, args)); + } + } + (9, b'r') => { + // RANDOMKEY + if cmd.eq_ignore_ascii_case(b"RANDOMKEY") { + return resp(key::randomkey(db, &[])); + } + } (9, b's') => { // SISMEMBER if cmd.eq_ignore_ascii_case(b"SISMEMBER") { @@ -706,6 +751,12 @@ fn dispatch_inner( return resp(hash::hrandfield(db, args)); } } + (10, b'e') => { + // EXPIRETIME + if cmd.eq_ignore_ascii_case(b"EXPIRETIME") { + return resp(key::expiretime(db, args)); + } + } (10, b's') => { // SMISMEMBER SDIFFSTORE SINTERCARD if cmd.eq_ignore_ascii_case(b"SMISMEMBER") { @@ -734,6 +785,12 @@ fn dispatch_inner( } } // 11-letter commands + (11, b'p') => { + // PEXPIRETIME + if cmd.eq_ignore_ascii_case(b"PEXPIRETIME") { + return resp(key::pexpiretime(db, args)); + } + } (11, b'i') => { // INCRBYFLOAT if cmd.eq_ignore_ascii_case(b"INCRBYFLOAT") { diff --git a/src/storage/db.rs b/src/storage/db.rs index 577f31d0..486ffaf3 100644 --- a/src/storage/db.rs +++ b/src/storage/db.rs @@ -290,6 +290,19 @@ impl Database { self.data.keys() } + /// Return a random key from the database, or None if empty. + pub fn random_key(&self) -> Option { + if self.data.is_empty() { + return None; + } + // Pick a random index via simple hash of current time + let idx = (current_time_ms() as usize) % self.data.len(); + self.data + .keys() + .nth(idx) + .map(|k| Bytes::copy_from_slice(k.as_ref())) + } + /// Set or remove expiration on an existing key. /// /// Performs lazy expiry check first. Returns `false` if the key does not From ce0113728fbc8297be660f23d4b5586da1c9cf6b Mon Sep 17 00:00:00 2001 From: Tin Dang Date: Sat, 11 Apr 2026 08:58:27 +0700 Subject: [PATCH 15/20] docs: add Tier 1 commands to CHANGELOG --- CHANGELOG.md | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 87cfdd44..f255e3e9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,6 +18,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - **Lazyfree threshold** — configurable via `CONFIG SET lazyfree-threshold N` (default 64). - **GETBIT/SETBIT metadata** — added to PHF command registry. - **GEOADD/GEOSEARCHSTORE** — added to AOF write commands test list. +- **EXPIREAT/PEXPIREAT** — absolute Unix timestamp expiry (seconds/milliseconds). +- **EXPIRETIME/PEXPIRETIME** — read back absolute expiry timestamp. +- **FLUSHDB/FLUSHALL** — clear all keys in current database. +- **TIME** — server clock as `[seconds, microseconds]`. +- **RANDOMKEY** — return a random key from the database. +- **TOUCH** — refresh LRU/LFU access time without reading value. +- **SHUTDOWN** — dispatch entry (graceful stop via signal handler). ### Fixed — Wave 0-4 Gap Closure (2026-04-09) From 0f5c2d4a383eb8751e53e2e843dd95e4c1d2b833 Mon Sep 17 00:00:00 2001 From: Tin Dang Date: Sat, 11 Apr 2026 09:54:13 +0700 Subject: [PATCH 16/20] feat: resolve all remaining Redis command gaps Implements 10 commands, closing every actionable gap: New commands: - BITFIELD: GET/SET/INCRBY with type specifiers (u8/i16/u32/...), OVERFLOW WRAP/SAT/FAIL, bit-level packed integer array - LCS: Longest Common Substring with LEN option (DP algorithm) - XSETID: set stream last-delivered ID without adding entries - GEORADIUS: deprecated wrapper translating to GEOSEARCH - GEORADIUSBYMEMBER: deprecated wrapper translating to GEOSEARCH - LOLWUT: Easter egg returning "Moon v0.1.0" - SWAPDB: dispatch entry (returns ERR in sharded mode) Enhanced commands: - OBJECT FREQ: return LFU access counter - OBJECT IDLETIME: return idle seconds (16-bit wraparound-safe) - OBJECT REFCOUNT: always returns 1 (no reference counting) Metadata added: LCS, XSETID --- src/command/geo/geo_cmd.rs | 39 +++++ src/command/key.rs | 55 ++++++ src/command/metadata.rs | 2 + src/command/mod.rs | 39 ++++- src/command/stream/stream_write.rs | 32 +++- src/command/string/string_bit.rs | 265 +++++++++++++++++++++++++++++ src/command/string/string_read.rs | 89 ++++++++++ 7 files changed, 517 insertions(+), 4 deletions(-) diff --git a/src/command/geo/geo_cmd.rs b/src/command/geo/geo_cmd.rs index cd5bf75b..c15a4af5 100644 --- a/src/command/geo/geo_cmd.rs +++ b/src/command/geo/geo_cmd.rs @@ -272,6 +272,45 @@ pub fn geosearch(db: &mut Database, args: &[Frame]) -> Frame { results } +/// GEORADIUS key longitude latitude radius M|KM|FT|MI [WITHCOORD] [WITHDIST] [WITHHASH] [COUNT n] [ASC|DESC] +/// +/// Deprecated since Redis 6.2 — translates to GEOSEARCH internally. +pub fn georadius(db: &mut Database, args: &[Frame]) -> Frame { + if args.len() < 5 { + return err_wrong_args("GEORADIUS"); + } + // Translate: GEORADIUS key lon lat radius unit [opts...] + // → GEOSEARCH key FROMLONLAT lon lat BYRADIUS radius unit [opts...] + let mut new_args = Vec::with_capacity(args.len() + 3); + new_args.push(args[0].clone()); // key + new_args.push(Frame::BulkString(Bytes::from_static(b"FROMLONLAT"))); + new_args.push(args[1].clone()); // lon + new_args.push(args[2].clone()); // lat + new_args.push(Frame::BulkString(Bytes::from_static(b"BYRADIUS"))); + new_args.push(args[3].clone()); // radius + new_args.push(args[4].clone()); // unit + new_args.extend_from_slice(&args[5..]); // remaining options + geosearch(db, &new_args) +} + +/// GEORADIUSBYMEMBER key member radius M|KM|FT|MI [opts...] +/// +/// Deprecated since Redis 6.2 — translates to GEOSEARCH internally. +pub fn georadiusbymember(db: &mut Database, args: &[Frame]) -> Frame { + if args.len() < 4 { + return err_wrong_args("GEORADIUSBYMEMBER"); + } + let mut new_args = Vec::with_capacity(args.len() + 3); + new_args.push(args[0].clone()); // key + new_args.push(Frame::BulkString(Bytes::from_static(b"FROMMEMBER"))); + new_args.push(args[1].clone()); // member + new_args.push(Frame::BulkString(Bytes::from_static(b"BYRADIUS"))); + new_args.push(args[2].clone()); // radius + new_args.push(args[3].clone()); // unit + new_args.extend_from_slice(&args[4..]); // remaining options + geosearch(db, &new_args) +} + /// GEOSEARCHSTORE destination source ... pub fn geosearchstore(db: &mut Database, args: &[Frame]) -> Frame { if args.len() < 2 { diff --git a/src/command/key.rs b/src/command/key.rs index 37d24879..e170c338 100644 --- a/src/command/key.rs +++ b/src/command/key.rs @@ -426,12 +426,67 @@ pub fn object(db: &mut Database, args: &[Frame]) -> Frame { } None => Frame::Null, } + } else if subcommand.eq_ignore_ascii_case(b"FREQ") { + if args.len() != 2 { + return err_wrong_args("OBJECT"); + } + let key = match extract_key(&args[1]) { + Some(k) => k, + None => return err_wrong_args("OBJECT"), + }; + match db.get(key) { + Some(entry) => Frame::Integer(entry.access_counter() as i64), + None => Frame::Error(Bytes::from_static(b"ERR no such key")), + } + } else if subcommand.eq_ignore_ascii_case(b"IDLETIME") { + if args.len() != 2 { + return err_wrong_args("OBJECT"); + } + let key = match extract_key(&args[1]) { + Some(k) => k, + None => return err_wrong_args("OBJECT"), + }; + let now = db.now(); + match db.get(key) { + Some(entry) => { + let last = entry.last_access(); + // Wraparound-safe delta in seconds (16-bit) + let idle = (now.wrapping_sub(last)) & 0xFFFF; + Frame::Integer(idle as i64) + } + None => Frame::Error(Bytes::from_static(b"ERR no such key")), + } + } else if subcommand.eq_ignore_ascii_case(b"REFCOUNT") { + if args.len() != 2 { + return err_wrong_args("OBJECT"); + } + let key = match extract_key(&args[1]) { + Some(k) => k, + None => return err_wrong_args("OBJECT"), + }; + match db.get(key) { + // Moon doesn't use reference counting — always return 1 + Some(_) => Frame::Integer(1), + None => Frame::Error(Bytes::from_static(b"ERR no such key")), + } } else if subcommand.eq_ignore_ascii_case(b"HELP") { Frame::Array(framevec![ Frame::BulkString(Bytes::from_static(b"OBJECT ENCODING ")), Frame::BulkString(Bytes::from_static( b" Return the encoding of the object stored at ." )), + Frame::BulkString(Bytes::from_static(b"OBJECT FREQ ")), + Frame::BulkString(Bytes::from_static( + b" Return the access frequency of the object at ." + )), + Frame::BulkString(Bytes::from_static(b"OBJECT IDLETIME ")), + Frame::BulkString(Bytes::from_static( + b" Return the idle time in seconds of the object at ." + )), + Frame::BulkString(Bytes::from_static(b"OBJECT REFCOUNT ")), + Frame::BulkString(Bytes::from_static( + b" Return the reference count of the object at ." + )), Frame::BulkString(Bytes::from_static(b"OBJECT HELP")), Frame::BulkString(Bytes::from_static(b" Return subcommand help.")), ]) diff --git a/src/command/metadata.rs b/src/command/metadata.rs index 0adb16bf..bdf566de 100644 --- a/src/command/metadata.rs +++ b/src/command/metadata.rs @@ -301,6 +301,8 @@ pub static COMMAND_META: phf::Map<&'static str, CommandMeta> = phf_map! { "BITOP" => CommandMeta { name: "BITOP", arity: -4, flags: W, first_key: 2, last_key: -1, step: 1, acl_categories: STR }, "BITFIELD" => CommandMeta { name: "BITFIELD", arity: -2, flags: W, first_key: 1, last_key: 1, step: 1, acl_categories: STR }, "BITPOS" => CommandMeta { name: "BITPOS", arity: -3, flags: R, first_key: 1, last_key: 1, step: 1, acl_categories: STR }, + "LCS" => CommandMeta { name: "LCS", arity: -3, flags: R, first_key: 1, last_key: 2, step: 1, acl_categories: STR }, + "XSETID" => CommandMeta { name: "XSETID", arity: -3, flags: W, first_key: 1, last_key: 1, step: 1, acl_categories: STR }, // ---- HyperLogLog commands ---- "PFADD" => CommandMeta { name: "PFADD", arity: -2, flags: WF, first_key: 1, last_key: 1, step: 1, acl_categories: STR }, diff --git a/src/command/mod.rs b/src/command/mod.rs index 7bd9c139..72114c51 100644 --- a/src/command/mod.rs +++ b/src/command/mod.rs @@ -71,6 +71,12 @@ fn dispatch_inner( match (len, b0) { // 3-letter commands + (3, b'l') => { + // LCS + if cmd.eq_ignore_ascii_case(b"LCS") { + return resp(string::lcs(db, args)); + } + } (3, b'd') => { // DEL if cmd.eq_ignore_ascii_case(b"DEL") { @@ -439,7 +445,7 @@ fn dispatch_inner( } } (6, b'l') => { - // LRANGE LINDEX + // LRANGE LINDEX LPUSHX LOLWUT if cmd.eq_ignore_ascii_case(b"LRANGE") { return resp(list::lrange(db, args)); } @@ -449,6 +455,9 @@ fn dispatch_inner( if cmd.eq_ignore_ascii_case(b"LPUSHX") { return resp(list::lpushx(db, args)); } + if cmd.eq_ignore_ascii_case(b"LOLWUT") { + return resp(Frame::BulkString(Bytes::from_static(b"Moon v0.1.0\n"))); + } } (6, b'm') => { // MEMORY @@ -525,6 +534,14 @@ fn dispatch_inner( return resp(string::getrange(db, args)); } } + b'w' => { + if cmd.eq_ignore_ascii_case(b"SWAPDB") { + // SWAPDB requires cross-database access not available in dispatch + return resp(Frame::Error(Bytes::from_static( + b"ERR SWAPDB is not supported in sharded mode", + ))); + } + } _ => {} } } @@ -535,7 +552,10 @@ fn dispatch_inner( } } (6, b'x') => { - // XRANGE XGROUP XCLAIM + // XRANGE XGROUP XCLAIM XSETID + if cmd.eq_ignore_ascii_case(b"XSETID") { + return resp(stream::xsetid(db, args)); + } if cmd.eq_ignore_ascii_case(b"XRANGE") { return resp(stream::xrange(db, args)); } @@ -670,6 +690,9 @@ fn dispatch_inner( if cmd.eq_ignore_ascii_case(b"BITCOUNT") { return resp(string::bitcount(db, args)); } + if cmd.eq_ignore_ascii_case(b"BITFIELD") { + return resp(string::bitfield(db, args)); + } } (8, b'r') => { // RENAMENX @@ -706,10 +729,13 @@ fn dispatch_inner( } // 9-letter commands (9, b'g') => { - // GEOSEARCH + // GEOSEARCH GEORADIUS if cmd.eq_ignore_ascii_case(b"GEOSEARCH") { return resp(geo::geosearch(db, args)); } + if cmd.eq_ignore_ascii_case(b"GEORADIUS") { + return resp(geo::georadius(db, args)); + } } (9, b'p') => { // PEXPIREAT @@ -852,6 +878,13 @@ fn dispatch_inner( return resp(sorted_set::zrevrangebyscore(db, args)); } } + // 17-letter commands + (17, b'g') => { + // GEORADIUSBYMEMBER + if cmd.eq_ignore_ascii_case(b"GEORADIUSBYMEMBER") { + return resp(geo::georadiusbymember(db, args)); + } + } _ => {} } diff --git a/src/command/stream/stream_write.rs b/src/command/stream/stream_write.rs index 970d599c..0cc9c6e4 100644 --- a/src/command/stream/stream_write.rs +++ b/src/command/stream/stream_write.rs @@ -1,4 +1,4 @@ -//! Stream write command handlers: XADD, XDEL, XTRIM, XACK, XCLAIM, XAUTOCLAIM, XGROUP, XREADGROUP. +//! Stream write command handlers: XADD, XDEL, XTRIM, XACK, XCLAIM, XAUTOCLAIM, XGROUP, XREADGROUP, XSETID. use bytes::Bytes; @@ -825,3 +825,33 @@ pub fn xautoclaim(db: &mut Database, args: &[Frame]) -> Frame { Err(e) => Frame::Error(Bytes::from(e)), } } + +/// XSETID key last-id [ENTRIESADDED entries-added] +/// +/// Sets the last delivered ID of a stream without adding entries. +pub fn xsetid(db: &mut Database, args: &[Frame]) -> Frame { + if args.len() < 2 { + return err_wrong_args("XSETID"); + } + let key = match extract_bytes(&args[0]) { + Some(k) => k, + None => return err_wrong_args("XSETID"), + }; + let id_str = match extract_bytes(&args[1]) { + Some(s) => s, + None => return err_wrong_args("XSETID"), + }; + let id = match StreamId::parse(id_str, 0) { + Ok(id) => id, + Err(e) => return Frame::Error(Bytes::from_static(e.as_bytes())), + }; + + match db.get_stream_mut(key) { + Ok(Some(stream)) => { + stream.last_id = id; + Frame::SimpleString(Bytes::from_static(b"OK")) + } + Ok(None) => Frame::Error(Bytes::from_static(b"ERR no such key")), + Err(e) => e, + } +} diff --git a/src/command/string/string_bit.rs b/src/command/string/string_bit.rs index 9dfc9993..ed604896 100644 --- a/src/command/string/string_bit.rs +++ b/src/command/string/string_bit.rs @@ -745,6 +745,271 @@ fn count_bits_in_range(data: &[u8], start_bit: usize, end_bit: usize) -> u32 { count } +/// BITFIELD key [GET encoding offset] [SET encoding offset value] +/// [INCRBY encoding offset increment] [OVERFLOW WRAP|SAT|FAIL] +/// +/// Treat a string as an array of packed integers of configurable width. +pub fn bitfield(db: &mut Database, args: &[Frame]) -> Frame { + if args.is_empty() { + return err_wrong_args("BITFIELD"); + } + let key = match extract_bytes(&args[0]) { + Some(k) => k.clone(), + None => return err_wrong_args("BITFIELD"), + }; + + let base_ts = db.base_timestamp(); + let (existing_data, existing_expiry_ms) = match db.get(&key) { + Some(entry) => { + let expiry = entry.expires_at_ms(base_ts); + match entry.value.as_bytes() { + Some(v) => (v.to_vec(), expiry), + None => { + return Frame::Error(Bytes::from_static( + b"WRONGTYPE Operation against a key holding the wrong kind of value", + )); + } + } + } + None => (Vec::new(), 0), + }; + + let mut buf = existing_data; + let mut results = Vec::new(); + let mut overflow = Overflow::Wrap; + let mut modified = false; + let mut i = 1; + + while i < args.len() { + let subcmd = match extract_bytes(&args[i]) { + Some(s) => s, + None => { + i += 1; + continue; + } + }; + + if subcmd.eq_ignore_ascii_case(b"OVERFLOW") { + i += 1; + let mode = match args.get(i).and_then(|f| extract_bytes(f)) { + Some(m) => m, + None => return Frame::Error(Bytes::from_static(b"ERR syntax error")), + }; + if mode.eq_ignore_ascii_case(b"WRAP") { + overflow = Overflow::Wrap; + } else if mode.eq_ignore_ascii_case(b"SAT") { + overflow = Overflow::Sat; + } else if mode.eq_ignore_ascii_case(b"FAIL") { + overflow = Overflow::Fail; + } else { + return Frame::Error(Bytes::from_static(b"ERR syntax error")); + } + i += 1; + continue; + } + + // Parse encoding and offset + let enc = match args.get(i + 1).and_then(|f| extract_bytes(f)) { + Some(e) => e, + None => return Frame::Error(Bytes::from_static(b"ERR syntax error")), + }; + let offset_arg = match args.get(i + 2).and_then(|f| extract_bytes(f)) { + Some(o) => o, + None => return Frame::Error(Bytes::from_static(b"ERR syntax error")), + }; + + let (signed, bits) = match parse_encoding(enc) { + Some(v) => v, + None => { + return Frame::Error(Bytes::from_static( + b"ERR Invalid bitfield type. Use something like i8 u8 i16 u16 ...", + )); + } + }; + let bit_offset = match parse_bit_offset(offset_arg, bits) { + Some(v) => v, + None => { + return Frame::Error(Bytes::from_static( + b"ERR bit offset is not an integer or out of range", + )); + } + }; + + if subcmd.eq_ignore_ascii_case(b"GET") { + let val = bf_get(&buf, bit_offset, bits, signed); + results.push(Frame::Integer(val)); + i += 3; + } else if subcmd.eq_ignore_ascii_case(b"SET") { + let value = match args.get(i + 3).and_then(|f| parse_i64(f)) { + Some(v) => v, + None => return Frame::Error(Bytes::from_static(b"ERR syntax error")), + }; + let old = bf_get(&buf, bit_offset, bits, signed); + bf_set(&mut buf, bit_offset, bits, value); + results.push(Frame::Integer(old)); + modified = true; + i += 4; + } else if subcmd.eq_ignore_ascii_case(b"INCRBY") { + let increment = match args.get(i + 3).and_then(|f| parse_i64(f)) { + Some(v) => v, + None => return Frame::Error(Bytes::from_static(b"ERR syntax error")), + }; + let old = bf_get(&buf, bit_offset, bits, signed); + let (new_val, overflowed) = bf_incr(old, increment, bits, signed); + if overflowed && matches!(overflow, Overflow::Fail) { + results.push(Frame::Null); + } else { + let clamped = if overflowed && matches!(overflow, Overflow::Sat) { + bf_saturate(old, increment, bits, signed) + } else { + new_val + }; + bf_set(&mut buf, bit_offset, bits, clamped); + results.push(Frame::Integer(clamped)); + modified = true; + } + i += 4; + } else { + return Frame::Error(Bytes::from_static(b"ERR syntax error")); + } + } + + if modified { + let new_val = Bytes::from(buf); + let mut entry = if existing_expiry_ms > 0 { + Entry::new_string_with_expiry(new_val, existing_expiry_ms, base_ts) + } else { + Entry::new_string(new_val) + }; + entry.set_last_access(db.now()); + entry.set_access_counter(5); + db.set(key, entry); + } + + Frame::Array(results.into()) +} + +#[derive(Clone, Copy)] +enum Overflow { + Wrap, + Sat, + Fail, +} + +/// Parse encoding like "u8", "i16", "u32" etc. Returns (signed, bits). +fn parse_encoding(enc: &[u8]) -> Option<(bool, u32)> { + if enc.is_empty() { + return None; + } + let signed = match enc[0] | 0x20 { + b'i' => true, + b'u' => false, + _ => return None, + }; + let bits: u32 = std::str::from_utf8(&enc[1..]).ok()?.parse().ok()?; + if bits == 0 || bits > 64 || (!signed && bits > 63) { + return None; + } + Some((signed, bits)) +} + +/// Parse bit offset: plain number or `#N` (type-multiplied). +fn parse_bit_offset(arg: &[u8], type_bits: u32) -> Option { + if arg.first() == Some(&b'#') { + let n: usize = std::str::from_utf8(&arg[1..]).ok()?.parse().ok()?; + Some(n * type_bits as usize) + } else { + let n: usize = std::str::from_utf8(arg).ok()?.parse().ok()?; + Some(n) + } +} + +/// Read `bits` bits starting at `bit_offset` from `data`, interpreting as signed/unsigned. +fn bf_get(data: &[u8], bit_offset: usize, bits: u32, signed: bool) -> i64 { + let mut val: u64 = 0; + for b in 0..bits as usize { + let pos = bit_offset + b; + let byte_idx = pos / 8; + let bit_idx = 7 - (pos % 8); + if byte_idx < data.len() && (data[byte_idx] >> bit_idx) & 1 == 1 { + val |= 1 << (bits as usize - 1 - b); + } + } + if signed && bits < 64 && val & (1 << (bits - 1)) != 0 { + // Sign extend + val |= !0u64 << bits; + } + val as i64 +} + +/// Write `bits` bits starting at `bit_offset` into `data`. +fn bf_set(data: &mut Vec, bit_offset: usize, bits: u32, value: i64) { + let val = value as u64; + let needed_bytes = (bit_offset + bits as usize + 7) / 8; + if data.len() < needed_bytes { + data.resize(needed_bytes, 0); + } + for b in 0..bits as usize { + let pos = bit_offset + b; + let byte_idx = pos / 8; + let bit_idx = 7 - (pos % 8); + if val & (1 << (bits as usize - 1 - b)) != 0 { + data[byte_idx] |= 1 << bit_idx; + } else { + data[byte_idx] &= !(1 << bit_idx); + } + } +} + +/// Increment with overflow detection. Returns (result, overflowed). +fn bf_incr(old: i64, incr: i64, bits: u32, signed: bool) -> (i64, bool) { + if signed { + let min = -(1i64 << (bits - 1)); + let max = (1i64 << (bits - 1)) - 1; + match old.checked_add(incr) { + Some(v) if v >= min && v <= max => (v, false), + Some(v) => { + // Wrap + let range = 1i64 << bits; + let wrapped = ((v - min) % range + range) % range + min; + (wrapped, true) + } + None => { + // Overflow in i64 itself — definitely overflowed + let wrapped = old.wrapping_add(incr); + (wrapped, true) + } + } + } else { + let max = if bits >= 64 { + u64::MAX + } else { + (1u64 << bits) - 1 + }; + let old_u = old as u64 & max; + let incr_u = incr as u64; + let sum = old_u.wrapping_add(incr_u); + let masked = sum & max; + (masked as i64, masked != sum || (incr < 0 && sum > old_u)) + } +} + +/// Saturate to min/max of the type. +fn bf_saturate(_old: i64, incr: i64, bits: u32, signed: bool) -> i64 { + if signed { + let min = -(1i64 << (bits - 1)); + let max = (1i64 << (bits - 1)) - 1; + if incr > 0 { max } else { min } + } else { + let max = if bits >= 64 { + i64::MAX + } else { + ((1u64 << bits) - 1) as i64 + }; + if incr > 0 { max } else { 0 } + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/src/command/string/string_read.rs b/src/command/string/string_read.rs index 9d5729e3..93cd7477 100644 --- a/src/command/string/string_read.rs +++ b/src/command/string/string_read.rs @@ -374,3 +374,92 @@ pub fn strlen_readonly(db: &Database, args: &[Frame], now_ms: u64) -> Frame { None => Frame::Integer(0), } } + +/// LCS key1 key2 [LEN] [IDX] [MINMATCHLEN len] [WITHMATCHLEN] +/// +/// Returns the longest common substring of two string values. +pub fn lcs(db: &mut Database, args: &[Frame]) -> Frame { + if args.len() < 2 { + return err_wrong_args("LCS"); + } + let key1 = match extract_bytes(&args[0]) { + Some(k) => k, + None => return err_wrong_args("LCS"), + }; + let key2 = match extract_bytes(&args[1]) { + Some(k) => k, + None => return err_wrong_args("LCS"), + }; + + let mut want_len = false; + let mut i = 2; + while i < args.len() { + let arg = match extract_bytes(&args[i]) { + Some(a) => a, + None => return Frame::Error(Bytes::from_static(b"ERR syntax error")), + }; + if arg.eq_ignore_ascii_case(b"LEN") { + want_len = true; + } + i += 1; + } + + let s1 = match db.get(key1) { + Some(e) => match e.value.as_bytes() { + Some(v) => v.to_vec(), + None => { + return Frame::Error(Bytes::from_static( + b"WRONGTYPE Operation against a key holding the wrong kind of value", + )); + } + }, + None => Vec::new(), + }; + let s2 = match db.get(key2) { + Some(e) => match e.value.as_bytes() { + Some(v) => v.to_vec(), + None => { + return Frame::Error(Bytes::from_static( + b"WRONGTYPE Operation against a key holding the wrong kind of value", + )); + } + }, + None => Vec::new(), + }; + + let n = s1.len(); + let m = s2.len(); + let mut dp = vec![vec![0u32; m + 1]; n + 1]; + for ii in 1..=n { + for jj in 1..=m { + if s1[ii - 1] == s2[jj - 1] { + dp[ii][jj] = dp[ii - 1][jj - 1] + 1; + } else { + dp[ii][jj] = dp[ii - 1][jj].max(dp[ii][jj - 1]); + } + } + } + let lcs_len = dp[n][m] as usize; + + if want_len { + return Frame::Integer(lcs_len as i64); + } + + // Backtrack to find LCS string + let mut lcs_bytes = Vec::with_capacity(lcs_len); + let mut ci = n; + let mut cj = m; + while ci > 0 && cj > 0 { + if s1[ci - 1] == s2[cj - 1] { + lcs_bytes.push(s1[ci - 1]); + ci -= 1; + cj -= 1; + } else if dp[ci - 1][cj] > dp[ci][cj - 1] { + ci -= 1; + } else { + cj -= 1; + } + } + lcs_bytes.reverse(); + Frame::BulkString(Bytes::from(lcs_bytes)) +} From 5d122f425cb8ee01853a62f6f424d3e8809cfc8b Mon Sep 17 00:00:00 2001 From: Tin Dang Date: Sat, 11 Apr 2026 09:54:34 +0700 Subject: [PATCH 17/20] docs: add remaining gap commands to CHANGELOG --- CHANGELOG.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index f255e3e9..c5d1519c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -25,6 +25,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - **RANDOMKEY** — return a random key from the database. - **TOUCH** — refresh LRU/LFU access time without reading value. - **SHUTDOWN** — dispatch entry (graceful stop via signal handler). +- **BITFIELD** — GET/SET/INCRBY with type specifiers (u8/i16/u32/...), OVERFLOW WRAP/SAT/FAIL. +- **LCS** — Longest Common Substring with LEN option. +- **XSETID** — set stream last-delivered ID without adding entries. +- **GEORADIUS/GEORADIUSBYMEMBER** — deprecated wrappers translating to GEOSEARCH. +- **OBJECT FREQ/IDLETIME/REFCOUNT** — LFU counter, idle seconds, reference count introspection. +- **LOLWUT** — Easter egg returning Moon version. ### Fixed — Wave 0-4 Gap Closure (2026-04-09) From 9649ec6971c2a313c223c634f66bb7eebd00de9d Mon Sep 17 00:00:00 2001 From: Tin Dang Date: Sat, 11 Apr 2026 12:14:49 +0700 Subject: [PATCH 18/20] fix: SORT STORE preserves nil GET results as empty strings MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Previously Frame::Null from GET pattern lookups were silently dropped, causing stored list length to diverge from returned count. Redis stores nil results as empty strings — now Moon does the same. --- src/command/key_extra.rs | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/command/key_extra.rs b/src/command/key_extra.rs index 8328fc72..873a7f75 100644 --- a/src/command/key_extra.rs +++ b/src/command/key_extra.rs @@ -333,8 +333,10 @@ pub fn sort(db: &mut Database, args: &[Frame]) -> Frame { let count = results.len() as i64; let mut list = std::collections::VecDeque::with_capacity(results.len()); for frame in results { - if let Frame::BulkString(b) = frame { - list.push_back(b); + match frame { + Frame::BulkString(b) => list.push_back(b), + // Redis stores nil GET results as empty strings in SORT...STORE + _ => list.push_back(Bytes::new()), } } let mut entry = crate::storage::entry::Entry::new_list(); From 813c06807cdb42615dde3412bff2dcc173ac5a20 Mon Sep 17 00:00:00 2001 From: Tin Dang Date: Sat, 11 Apr 2026 12:49:10 +0700 Subject: [PATCH 19/20] fix: CLIENT PAUSE/UNPAUSE, COPY same-key, RANDOMKEY expiry, MEMORY SAMPLES, geo test layout - Implement CLIENT PAUSE timeout [WRITE|ALL] and CLIENT UNPAUSE - Fix COPY same-key returning 0 instead of 1 with REPLACE flag - RANDOMKEY now filters expired keys instead of returning stale entries - MEMORY USAGE accepts optional SAMPLES count parameter - Fix XSETID acl category from STR to STM - Isolate config_rewrite test to unique temp dir - Move geo command tests from geo_cmd.rs to geo/mod.rs per split convention --- src/command/config.rs | 9 +- src/command/geo/geo_cmd.rs | 169 ------------------------------ src/command/geo/mod.rs | 166 +++++++++++++++++++++++++++++ src/command/key_extra.rs | 28 ++++- src/command/metadata.rs | 2 +- src/server/conn/handler_single.rs | 85 +++++++++++++++ src/storage/db.rs | 22 ++-- 7 files changed, 299 insertions(+), 182 deletions(-) diff --git a/src/command/config.rs b/src/command/config.rs index 228f021e..7c5cebd7 100644 --- a/src/command/config.rs +++ b/src/command/config.rs @@ -396,17 +396,20 @@ mod tests { #[test] fn test_config_rewrite() { + let tmp = std::env::temp_dir().join(format!("moon-test-{}", std::process::id())); + std::fs::create_dir_all(&tmp).unwrap(); + let mut rt = RuntimeConfig::default(); rt.maxmemory = 1_073_741_824; // 1GB rt.maxmemory_policy = "allkeys-lru".to_string(); - rt.dir = std::env::temp_dir().to_string_lossy().to_string(); + rt.dir = tmp.to_string_lossy().to_string(); let sc = default_server_config(); let result = config_rewrite(&rt, &sc); assert_eq!(result, Frame::SimpleString(Bytes::from_static(b"OK"))); // Verify file was created - let conf_path = std::path::Path::new(&rt.dir).join("moon.conf"); + let conf_path = tmp.join("moon.conf"); assert!(conf_path.exists()); let content = std::fs::read_to_string(&conf_path).unwrap(); assert!(content.contains("maxmemory 1073741824")); @@ -414,6 +417,6 @@ mod tests { assert!(content.contains("port 6379")); // Cleanup - let _ = std::fs::remove_file(conf_path); + let _ = std::fs::remove_dir_all(tmp); } } diff --git a/src/command/geo/geo_cmd.rs b/src/command/geo/geo_cmd.rs index c15a4af5..43a055e6 100644 --- a/src/command/geo/geo_cmd.rs +++ b/src/command/geo/geo_cmd.rs @@ -654,172 +654,3 @@ fn geosearch_inner(db: &mut Database, args: &[Frame], _store_mode: bool) -> (Vec (matches, Frame::Array(results.into())) } - -#[cfg(test)] -mod tests { - use super::*; - use crate::storage::Database; - - fn bs(s: &[u8]) -> Frame { - Frame::BulkString(Bytes::copy_from_slice(s)) - } - - #[test] - fn test_geoadd_and_geopos() { - let mut db = Database::new(); - let result = geoadd( - &mut db, - &[ - bs(b"mygeo"), - bs(b"13.361389"), - bs(b"38.115556"), - bs(b"Palermo"), - bs(b"15.087269"), - bs(b"37.502669"), - bs(b"Catania"), - ], - ); - assert_eq!(result, Frame::Integer(2)); - - let result = geopos(&mut db, &[bs(b"mygeo"), bs(b"Palermo"), bs(b"NonExistent")]); - match result { - Frame::Array(ref arr) => { - assert_eq!(arr.len(), 2); - assert!(matches!(&arr[0], Frame::Array(_))); - assert_eq!(arr[1], Frame::Null); - } - _ => panic!("Expected array"), - } - } - - #[test] - fn test_geodist() { - let mut db = Database::new(); - geoadd( - &mut db, - &[ - bs(b"mygeo"), - bs(b"13.361389"), - bs(b"38.115556"), - bs(b"Palermo"), - bs(b"15.087269"), - bs(b"37.502669"), - bs(b"Catania"), - ], - ); - let result = geodist( - &mut db, - &[bs(b"mygeo"), bs(b"Palermo"), bs(b"Catania"), bs(b"km")], - ); - match result { - Frame::BulkString(b) => { - let dist: f64 = std::str::from_utf8(&b).unwrap().parse().unwrap(); - assert!((dist - 166.2742).abs() < 1.0, "got {dist}"); - } - _ => panic!("Expected bulk string"), - } - } - - #[test] - fn test_geohash() { - let mut db = Database::new(); - geoadd( - &mut db, - &[ - bs(b"mygeo"), - bs(b"13.361389"), - bs(b"38.115556"), - bs(b"Palermo"), - ], - ); - let result = geohash(&mut db, &[bs(b"mygeo"), bs(b"Palermo")]); - match result { - Frame::Array(ref arr) => { - assert_eq!(arr.len(), 1); - match &arr[0] { - Frame::BulkString(b) => { - assert_eq!(b.len(), 11); - } - _ => panic!("Expected bulk string"), - } - } - _ => panic!("Expected array"), - } - } - - #[test] - fn test_geosearch_byradius() { - let mut db = Database::new(); - geoadd( - &mut db, - &[ - bs(b"mygeo"), - bs(b"13.361389"), - bs(b"38.115556"), - bs(b"Palermo"), - bs(b"15.087269"), - bs(b"37.502669"), - bs(b"Catania"), - bs(b"2.349014"), - bs(b"48.864716"), - bs(b"Paris"), - ], - ); - - let result = geosearch( - &mut db, - &[ - bs(b"mygeo"), - bs(b"FROMLONLAT"), - bs(b"15"), - bs(b"37"), - bs(b"BYRADIUS"), - bs(b"200"), - bs(b"km"), - bs(b"ASC"), - ], - ); - match result { - Frame::Array(ref arr) => { - // Should find Catania and Palermo (within 200km of 15,37), not Paris - assert_eq!(arr.len(), 2); - } - _ => panic!("Expected array, got {:?}", result), - } - } - - #[test] - fn test_geoadd_nx_xx() { - let mut db = Database::new(); - geoadd( - &mut db, - &[bs(b"g"), bs(b"10.0"), bs(b"20.0"), bs(b"member1")], - ); - - // NX should not update existing - let result = geoadd( - &mut db, - &[ - bs(b"g"), - bs(b"NX"), - bs(b"11.0"), - bs(b"21.0"), - bs(b"member1"), - ], - ); - assert_eq!(result, Frame::Integer(0)); - - // NX should add new - let result = geoadd( - &mut db, - &[ - bs(b"g"), - bs(b"NX"), - bs(b"12.0"), - bs(b"22.0"), - bs(b"member2"), - ], - ); - assert_eq!(result, Frame::Integer(1)); - } -} diff --git a/src/command/geo/mod.rs b/src/command/geo/mod.rs index 2bcc2791..60259bc9 100644 --- a/src/command/geo/mod.rs +++ b/src/command/geo/mod.rs @@ -131,6 +131,13 @@ pub(crate) fn parse_unit(unit: &[u8]) -> Option { #[cfg(test)] mod tests { use super::*; + use bytes::Bytes; + use crate::protocol::Frame; + use crate::storage::Database; + + fn bs(s: &[u8]) -> Frame { + Frame::BulkString(Bytes::copy_from_slice(s)) + } #[test] fn test_geohash_roundtrip() { @@ -163,4 +170,163 @@ mod tests { "invalid chars: {s}" ); } + + #[test] + fn test_geoadd_and_geopos() { + let mut db = Database::new(); + let result = geoadd( + &mut db, + &[ + bs(b"mygeo"), + bs(b"13.361389"), + bs(b"38.115556"), + bs(b"Palermo"), + bs(b"15.087269"), + bs(b"37.502669"), + bs(b"Catania"), + ], + ); + assert_eq!(result, Frame::Integer(2)); + + let result = geopos(&mut db, &[bs(b"mygeo"), bs(b"Palermo"), bs(b"NonExistent")]); + match result { + Frame::Array(ref arr) => { + assert_eq!(arr.len(), 2); + assert!(matches!(&arr[0], Frame::Array(_))); + assert_eq!(arr[1], Frame::Null); + } + _ => panic!("Expected array"), + } + } + + #[test] + fn test_geodist() { + let mut db = Database::new(); + geoadd( + &mut db, + &[ + bs(b"mygeo"), + bs(b"13.361389"), + bs(b"38.115556"), + bs(b"Palermo"), + bs(b"15.087269"), + bs(b"37.502669"), + bs(b"Catania"), + ], + ); + let result = geodist( + &mut db, + &[bs(b"mygeo"), bs(b"Palermo"), bs(b"Catania"), bs(b"km")], + ); + match result { + Frame::BulkString(b) => { + let dist: f64 = std::str::from_utf8(&b).unwrap().parse().unwrap(); + assert!((dist - 166.2742).abs() < 1.0, "got {dist}"); + } + _ => panic!("Expected bulk string"), + } + } + + #[test] + fn test_geohash() { + let mut db = Database::new(); + geoadd( + &mut db, + &[ + bs(b"mygeo"), + bs(b"13.361389"), + bs(b"38.115556"), + bs(b"Palermo"), + ], + ); + let result = geohash(&mut db, &[bs(b"mygeo"), bs(b"Palermo")]); + match result { + Frame::Array(ref arr) => { + assert_eq!(arr.len(), 1); + match &arr[0] { + Frame::BulkString(b) => { + assert_eq!(b.len(), 11); + } + _ => panic!("Expected bulk string"), + } + } + _ => panic!("Expected array"), + } + } + + #[test] + fn test_geosearch_byradius() { + let mut db = Database::new(); + geoadd( + &mut db, + &[ + bs(b"mygeo"), + bs(b"13.361389"), + bs(b"38.115556"), + bs(b"Palermo"), + bs(b"15.087269"), + bs(b"37.502669"), + bs(b"Catania"), + bs(b"2.349014"), + bs(b"48.864716"), + bs(b"Paris"), + ], + ); + + let result = geosearch( + &mut db, + &[ + bs(b"mygeo"), + bs(b"FROMLONLAT"), + bs(b"15"), + bs(b"37"), + bs(b"BYRADIUS"), + bs(b"200"), + bs(b"km"), + bs(b"ASC"), + ], + ); + match result { + Frame::Array(ref arr) => { + // Should find Catania and Palermo (within 200km of 15,37), not Paris + assert_eq!(arr.len(), 2); + } + _ => panic!("Expected array, got {:?}", result), + } + } + + #[test] + fn test_geoadd_nx_xx() { + let mut db = Database::new(); + geoadd( + &mut db, + &[bs(b"g"), bs(b"10.0"), bs(b"20.0"), bs(b"member1")], + ); + + // NX should not update existing + let result = geoadd( + &mut db, + &[ + bs(b"g"), + bs(b"NX"), + bs(b"11.0"), + bs(b"21.0"), + bs(b"member1"), + ], + ); + assert_eq!(result, Frame::Integer(0)); + + // NX should add new + let result = geoadd( + &mut db, + &[ + bs(b"g"), + bs(b"NX"), + bs(b"12.0"), + bs(b"22.0"), + bs(b"member2"), + ], + ); + assert_eq!(result, Frame::Integer(1)); + } } diff --git a/src/command/key_extra.rs b/src/command/key_extra.rs index 873a7f75..5295f9db 100644 --- a/src/command/key_extra.rs +++ b/src/command/key_extra.rs @@ -54,9 +54,9 @@ pub fn copy(db: &mut Database, args: &[Frame]) -> Frame { return Frame::Integer(0); } - // Same key: source == dest, nothing to do + // Same key: source == dest — Redis returns 1 only with REPLACE, else 0 if src == dst { - return Frame::Integer(0); + return Frame::Integer(if replace { 1 } else { 0 }); } // Check if destination exists @@ -87,6 +87,30 @@ pub fn memory_usage(db: &mut Database, args: &[Frame]) -> Frame { None => return err_wrong_args("MEMORY"), }; + // Parse optional SAMPLES count; reject unknown trailing args + let mut i = 1; + let mut _samples: usize = 5; // default sample count (like Redis) + while i < args.len() { + let arg = match extract_key(&args[i]) { + Some(a) => a, + None => return err_wrong_args("MEMORY"), + }; + if arg.eq_ignore_ascii_case(b"SAMPLES") { + i += 1; + let count_arg = match args.get(i).and_then(|f| extract_key(f)) { + Some(c) => c, + None => return err_wrong_args("MEMORY"), + }; + match std::str::from_utf8(count_arg).ok().and_then(|s| s.parse::().ok()) { + Some(c) if c > 0 => _samples = c, + _ => return err_wrong_args("MEMORY"), + } + } else { + return err_wrong_args("MEMORY"); + } + i += 1; + } + match db.get(key) { Some(entry) => { let mem = key.len() + entry.value.estimate_memory() + 128; // same as entry_overhead diff --git a/src/command/metadata.rs b/src/command/metadata.rs index bdf566de..919ce06b 100644 --- a/src/command/metadata.rs +++ b/src/command/metadata.rs @@ -302,7 +302,7 @@ pub static COMMAND_META: phf::Map<&'static str, CommandMeta> = phf_map! { "BITFIELD" => CommandMeta { name: "BITFIELD", arity: -2, flags: W, first_key: 1, last_key: 1, step: 1, acl_categories: STR }, "BITPOS" => CommandMeta { name: "BITPOS", arity: -3, flags: R, first_key: 1, last_key: 1, step: 1, acl_categories: STR }, "LCS" => CommandMeta { name: "LCS", arity: -3, flags: R, first_key: 1, last_key: 2, step: 1, acl_categories: STR }, - "XSETID" => CommandMeta { name: "XSETID", arity: -3, flags: W, first_key: 1, last_key: 1, step: 1, acl_categories: STR }, + "XSETID" => CommandMeta { name: "XSETID", arity: -3, flags: W, first_key: 1, last_key: 1, step: 1, acl_categories: STM }, // ---- HyperLogLog commands ---- "PFADD" => CommandMeta { name: "PFADD", arity: -2, flags: WF, first_key: 1, last_key: 1, step: 1, acl_categories: STR }, diff --git a/src/server/conn/handler_single.rs b/src/server/conn/handler_single.rs index 93b29207..7e36f431 100644 --- a/src/server/conn/handler_single.rs +++ b/src/server/conn/handler_single.rs @@ -498,6 +498,58 @@ pub async fn handle_connection( } } } + if sub_bytes.eq_ignore_ascii_case(b"PAUSE") { + // CLIENT PAUSE timeout [WRITE|ALL] + if cmd_args.len() < 2 { + responses.push(Frame::Error(Bytes::from_static( + b"ERR wrong number of arguments for 'CLIENT PAUSE' command", + ))); + } else { + let timeout_ms = match extract_bytes(&cmd_args[1]) { + Some(b) => std::str::from_utf8(&b).ok().and_then(|s| s.parse::().ok()), + None => None, + }; + let mode_valid = if cmd_args.len() >= 3 { + match extract_bytes(&cmd_args[2]) { + Some(b) => b.eq_ignore_ascii_case(b"WRITE") || b.eq_ignore_ascii_case(b"ALL"), + None => false, + } + } else { + true + }; + if cmd_args.len() > 3 || !mode_valid { + responses.push(Frame::Error(Bytes::from_static( + b"ERR syntax error", + ))); + } else { + match timeout_ms { + Some(ms) => { + let write_only = cmd_args.get(2) + .and_then(|f| extract_bytes(f)) + .is_some_and(|b| b.eq_ignore_ascii_case(b"WRITE")); + let deadline = crate::storage::entry::current_time_ms().saturating_add(ms); + let mut rt = runtime_config.write(); + rt.client_pause_deadline_ms = deadline; + rt.client_pause_write_only = write_only; + responses.push(Frame::SimpleString(Bytes::from_static(b"OK"))); + } + None => { + responses.push(Frame::Error(Bytes::from_static( + b"ERR timeout is not a valid integer or out of range", + ))); + } + } + } + } + continue; + } + if sub_bytes.eq_ignore_ascii_case(b"UNPAUSE") { + let mut rt = runtime_config.write(); + rt.client_pause_deadline_ms = 0; + rt.client_pause_write_only = false; + responses.push(Frame::SimpleString(Bytes::from_static(b"OK"))); + continue; + } // Unknown CLIENT subcommand responses.push(Frame::Error(Bytes::from(format!( "ERR unknown subcommand '{}'", @@ -942,6 +994,39 @@ pub async fn handle_connection( continue; } } + + // === CLIENT PAUSE check === + let pause_wait_ms = { + let rt = runtime_config.read(); + let deadline = rt.client_pause_deadline_ms; + if deadline > 0 { + let now = crate::storage::entry::current_time_ms(); + if now < deadline { + let should_pause = if rt.client_pause_write_only { + metadata::is_write(cmd) + } else { + true + }; + if should_pause { deadline.saturating_sub(now) } else { 0 } + } else { 0 } + } else { 0 } + }; + if pause_wait_ms > 0 { + let mut remaining = pause_wait_ms; + while remaining > 0 { + let chunk = remaining.min(50); + tokio::time::sleep(std::time::Duration::from_millis(chunk)).await; + remaining = remaining.saturating_sub(chunk); + let still_paused = { + let rt = runtime_config.read(); + rt.client_pause_deadline_ms > 0 + && crate::storage::entry::current_time_ms() < rt.client_pause_deadline_ms + }; + if !still_paused { + break; + } + } + } } // --- MULTI queue mode --- diff --git a/src/storage/db.rs b/src/storage/db.rs index 486ffaf3..ebf9f14f 100644 --- a/src/storage/db.rs +++ b/src/storage/db.rs @@ -290,17 +290,25 @@ impl Database { self.data.keys() } - /// Return a random key from the database, or None if empty. + /// Return a random non-expired key from the database, or None if empty. pub fn random_key(&self) -> Option { if self.data.is_empty() { return None; } - // Pick a random index via simple hash of current time - let idx = (current_time_ms() as usize) % self.data.len(); - self.data - .keys() - .nth(idx) - .map(|k| Bytes::copy_from_slice(k.as_ref())) + let now_ms = self.cached_now_ms; + let base_ts = self.base_timestamp; + // Collect non-expired keys (iterator is already O(n)) + let live: Vec<_> = self + .data + .iter() + .filter(|(_, e)| !e.is_expired_at(base_ts, now_ms)) + .map(|(k, _)| Bytes::copy_from_slice(k.as_ref())) + .collect(); + if live.is_empty() { + return None; + } + let idx = (current_time_ms() as usize) % live.len(); + Some(live.into_iter().nth(idx).unwrap_or_default()) } /// Set or remove expiration on an existing key. From 3572b1e00d4e020b81ef8d9a059c075d7adf6299 Mon Sep 17 00:00:00 2001 From: Tin Dang Date: Sun, 12 Apr 2026 01:28:09 +0700 Subject: [PATCH 20/20] fix: COPY same-key ERR, EXPIREAT accepts past timestamps, LCS OOM guard 1. COPY same-key: return ERR matching Redis 7.x behavior instead of returning 0/1 based on REPLACE flag. 2. EXPIREAT/PEXPIREAT: accept timestamp 0 and negative values as past-time expiry, deleting the key immediately. Previously rejected with ERR, diverging from Redis which uses these to conditionally delete keys. 3. LCS: add size guard (16M cells / 64MB) to prevent OOM from two large strings creating an unbounded O(n*m) DP table. --- src/command/key.rs | 28 ++++++++++++++++++++++------ src/command/key_extra.rs | 10 ++++++---- src/command/string/string_read.rs | 8 ++++++++ 3 files changed, 36 insertions(+), 10 deletions(-) diff --git a/src/command/key.rs b/src/command/key.rs index e170c338..7a373edb 100644 --- a/src/command/key.rs +++ b/src/command/key.rs @@ -229,14 +229,22 @@ pub fn expireat(db: &mut Database, args: &[Frame]) -> Frame { None => return err_wrong_args("EXPIREAT"), }; let timestamp = match parse_int(&args[1]) { - Some(n) if n > 0 => n as u64, - _ => { + Some(n) => n, + None => { return Frame::Error(Bytes::from_static( b"ERR invalid expire time in 'EXPIREAT' command", )); } }; - let expires_at_ms = timestamp * 1000; + // Redis accepts 0 and negative timestamps as past-time expiry (deletes key immediately) + if timestamp <= 0 { + return if db.remove(key).is_some() { + Frame::Integer(1) + } else { + Frame::Integer(0) + }; + } + let expires_at_ms = (timestamp as u64) * 1000; if db.set_expiry(key, expires_at_ms) { Frame::Integer(1) } else { @@ -256,14 +264,22 @@ pub fn pexpireat(db: &mut Database, args: &[Frame]) -> Frame { None => return err_wrong_args("PEXPIREAT"), }; let timestamp_ms = match parse_int(&args[1]) { - Some(n) if n > 0 => n as u64, - _ => { + Some(n) => n, + None => { return Frame::Error(Bytes::from_static( b"ERR invalid expire time in 'PEXPIREAT' command", )); } }; - if db.set_expiry(key, timestamp_ms) { + // Redis accepts 0 and negative timestamps as past-time expiry (deletes key immediately) + if timestamp_ms <= 0 { + return if db.remove(key).is_some() { + Frame::Integer(1) + } else { + Frame::Integer(0) + }; + } + if db.set_expiry(key, timestamp_ms as u64) { Frame::Integer(1) } else { Frame::Integer(0) diff --git a/src/command/key_extra.rs b/src/command/key_extra.rs index 5295f9db..6065757a 100644 --- a/src/command/key_extra.rs +++ b/src/command/key_extra.rs @@ -54,9 +54,11 @@ pub fn copy(db: &mut Database, args: &[Frame]) -> Frame { return Frame::Integer(0); } - // Same key: source == dest — Redis returns 1 only with REPLACE, else 0 + // Same key: Redis 7.x returns ERR for source == destination if src == dst { - return Frame::Integer(if replace { 1 } else { 0 }); + return Frame::Error(Bytes::from_static( + b"ERR source and destination objects are the same", + )); } // Check if destination exists @@ -450,8 +452,8 @@ mod tests { fn test_copy_same_key() { let mut db = setup_db_with_key(b"src", b"hello"); let result = copy(&mut db, &[bs(b"src"), bs(b"src")]); - // Redis returns 0 for same-key copy - assert_eq!(result, Frame::Integer(0)); + // Redis 7.x returns ERR for same-key copy + assert!(matches!(result, Frame::Error(_))); } #[test] diff --git a/src/command/string/string_read.rs b/src/command/string/string_read.rs index 93cd7477..422fdb31 100644 --- a/src/command/string/string_read.rs +++ b/src/command/string/string_read.rs @@ -429,6 +429,14 @@ pub fn lcs(db: &mut Database, args: &[Frame]) -> Frame { let n = s1.len(); let m = s2.len(); + // Guard against OOM: cap DP table size (n*m cells * 4 bytes each). + // 64MB cap ≈ 4096 * 4096 strings — generous for any real use case. + const MAX_LCS_CELLS: usize = 16 * 1024 * 1024; // 64MB at 4 bytes/cell + if n.saturating_mul(m) > MAX_LCS_CELLS { + return Frame::Error(Bytes::from_static( + b"ERR inputs too large for LCS computation", + )); + } let mut dp = vec![vec![0u32; m + 1]; n + 1]; for ii in 1..=n { for jj in 1..=m {