From afa1033cd6a1de575bf5af1036470e5e3617ffd2 Mon Sep 17 00:00:00 2001 From: Maxime Van Hees Date: Mon, 18 Aug 2025 16:13:46 +0200 Subject: [PATCH 1/8] implement BLPOP --- herodb/src/cmd.rs | 109 ++++++++++++++++++++++++++++++++++++++++++- herodb/src/server.rs | 90 +++++++++++++++++++++++++++++++++++ 2 files changed, 197 insertions(+), 2 deletions(-) diff --git a/herodb/src/cmd.rs b/herodb/src/cmd.rs index c036b4a..eef93c4 100644 --- a/herodb/src/cmd.rs +++ b/herodb/src/cmd.rs @@ -1,5 +1,7 @@ use crate::{error::DBError, protocol::Protocol, server::Server}; use serde::Serialize; +use tokio::time::{timeout, Duration}; +use futures::future::select_all; #[derive(Debug, Clone)] pub enum Cmd { @@ -43,6 +45,7 @@ pub enum Cmd { RPush(String, Vec), LPop(String, Option), RPop(String, Option), + BLPop(Vec, f64), LLen(String), LRem(String, i64, String), LTrim(String, i64, i64), @@ -376,6 +379,17 @@ impl Cmd { }; Cmd::RPop(cmd[1].clone(), count) } + "blpop" => { + if cmd.len() < 3 { + return Err(DBError(format!("wrong number of arguments for BLPOP command"))); + } + // keys are all but the last argument + let keys = cmd[1..cmd.len()-1].to_vec(); + let timeout_f = cmd[cmd.len()-1] + .parse::() + .map_err(|_| DBError("ERR timeout is not a number".to_string()))?; + Cmd::BLPop(keys, timeout_f) + } "llen" => { if cmd.len() != 2 { return Err(DBError(format!("wrong number of arguments for LLEN command"))); @@ -531,6 +545,7 @@ impl Cmd { Cmd::RPush(key, elements) => rpush_cmd(server, &key, &elements).await, Cmd::LPop(key, count) => lpop_cmd(server, &key, &count).await, Cmd::RPop(key, count) => rpop_cmd(server, &key, &count).await, + Cmd::BLPop(keys, timeout) => blpop_cmd(server, &keys, timeout).await, Cmd::LLen(key) => llen_cmd(server, &key).await, Cmd::LRem(key, count, element) => lrem_cmd(server, &key, count, &element).await, Cmd::LTrim(key, start, stop) => ltrim_cmd(server, &key, start, stop).await, @@ -661,16 +676,106 @@ async fn rpop_cmd(server: &Server, key: &str, count: &Option) -> Result Result { + // Immediate, non-blocking attempt in key order + for k in keys { + let elems = server.current_storage()?.lpop(k, 1)?; + if !elems.is_empty() { + return Ok(Protocol::Array(vec![ + Protocol::BulkString(k.clone()), + Protocol::BulkString(elems[0].clone()), + ])); + } + } + + // If timeout is zero, return immediately with Null + if timeout_secs <= 0.0 { + return Ok(Protocol::Null); + } + + // Register waiters for each key + let db_index = server.selected_db; + let mut ids: Vec = Vec::with_capacity(keys.len()); + let mut names: Vec = Vec::with_capacity(keys.len()); + let mut rxs: Vec> = Vec::with_capacity(keys.len()); + + for k in keys { + let (id, rx) = server.register_waiter(db_index, k).await; + ids.push(id); + names.push(k.clone()); + rxs.push(rx); + } + + // Wait for the first delivery or timeout + let wait_fut = async move { + let mut futures_vec = rxs; + loop { + if futures_vec.is_empty() { + return None; + } + let (res, idx, remaining) = select_all(futures_vec).await; + match res { + Ok((k, elem)) => { + return Some((k, elem, idx, remaining)); + } + Err(_canceled) => { + // That waiter was canceled; continue with the rest + futures_vec = remaining; + continue; + } + } + } + }; + + match timeout(Duration::from_secs_f64(timeout_secs), wait_fut).await { + Ok(Some((k, elem, idx, _remaining))) => { + // Unregister other waiters + for (i, key_name) in names.iter().enumerate() { + if i != idx { + server.unregister_waiter(db_index, key_name, ids[i]).await; + } + } + Ok(Protocol::Array(vec![ + Protocol::BulkString(k), + Protocol::BulkString(elem), + ])) + } + Ok(None) => { + // No futures left; unregister all waiters + for (i, key_name) in names.iter().enumerate() { + server.unregister_waiter(db_index, key_name, ids[i]).await; + } + Ok(Protocol::Null) + } + Err(_elapsed) => { + // Timeout: unregister all waiters + for (i, key_name) in names.iter().enumerate() { + server.unregister_waiter(db_index, key_name, ids[i]).await; + } + Ok(Protocol::Null) + } + } +} + async fn lpush_cmd(server: &Server, key: &str, elements: &[String]) -> Result { match server.current_storage()?.lpush(key, elements.to_vec()) { - Ok(len) => Ok(Protocol::SimpleString(len.to_string())), + Ok(len) => { + // Attempt to deliver to any blocked BLPOP waiters + let _ = server.drain_waiters_after_push(key).await; + Ok(Protocol::SimpleString(len.to_string())) + } Err(e) => Ok(Protocol::err(&e.0)), } } async fn rpush_cmd(server: &Server, key: &str, elements: &[String]) -> Result { match server.current_storage()?.rpush(key, elements.to_vec()) { - Ok(len) => Ok(Protocol::SimpleString(len.to_string())), + Ok(len) => { + // Attempt to deliver to any blocked BLPOP waiters + let _ = server.drain_waiters_after_push(key).await; + Ok(Protocol::SimpleString(len.to_string())) + } Err(e) => Ok(Protocol::err(&e.0)), } } diff --git a/herodb/src/server.rs b/herodb/src/server.rs index c286e21..68a0219 100644 --- a/herodb/src/server.rs +++ b/herodb/src/server.rs @@ -3,6 +3,9 @@ use std::collections::HashMap; use std::sync::Arc; use tokio::io::AsyncReadExt; use tokio::io::AsyncWriteExt; +use tokio::sync::{Mutex, oneshot}; + +use std::sync::atomic::{AtomicU64, Ordering}; use crate::cmd::Cmd; use crate::error::DBError; @@ -17,6 +20,15 @@ pub struct Server { pub client_name: Option, pub selected_db: u64, // Changed from usize to u64 pub queued_cmd: Option>, + + // BLPOP waiter registry: per (db_index, key) FIFO of waiters + pub list_waiters: Arc>>>>, + pub waiter_seq: Arc, +} + +pub struct Waiter { + pub id: u64, + pub tx: oneshot::Sender<(String, String)>, // (key, element) } impl Server { @@ -27,6 +39,9 @@ impl Server { client_name: None, selected_db: 0, queued_cmd: None, + + list_waiters: Arc::new(Mutex::new(HashMap::new())), + waiter_seq: Arc::new(AtomicU64::new(1)), } } @@ -66,6 +81,81 @@ impl Server { self.option.encrypt && db_index >= 10 } + // ----- BLPOP waiter helpers ----- + + pub async fn register_waiter(&self, db_index: u64, key: &str) -> (u64, oneshot::Receiver<(String, String)>) { + let id = self.waiter_seq.fetch_add(1, Ordering::Relaxed); + let (tx, rx) = oneshot::channel::<(String, String)>(); + + let mut guard = self.list_waiters.lock().await; + let per_db = guard.entry(db_index).or_insert_with(HashMap::new); + let q = per_db.entry(key.to_string()).or_insert_with(Vec::new); + q.push(Waiter { id, tx }); + (id, rx) + } + + pub async fn unregister_waiter(&self, db_index: u64, key: &str, id: u64) { + let mut guard = self.list_waiters.lock().await; + if let Some(per_db) = guard.get_mut(&db_index) { + if let Some(q) = per_db.get_mut(key) { + q.retain(|w| w.id != id); + if q.is_empty() { + per_db.remove(key); + } + } + if per_db.is_empty() { + guard.remove(&db_index); + } + } + } + + // Called after LPUSH/RPUSH to deliver to blocked BLPOP waiters. + pub async fn drain_waiters_after_push(&self, key: &str) -> Result<(), DBError> { + let db_index = self.selected_db; + + loop { + // Check if any waiter exists + let maybe_waiter = { + let mut guard = self.list_waiters.lock().await; + if let Some(per_db) = guard.get_mut(&db_index) { + if let Some(q) = per_db.get_mut(key) { + if !q.is_empty() { + // Pop FIFO + Some(q.remove(0)) + } else { + None + } + } else { + None + } + } else { + None + } + }; + + let waiter = if let Some(w) = maybe_waiter { w } else { break }; + + // Pop one element from the left + let elems = self.current_storage()?.lpop(key, 1)?; + if elems.is_empty() { + // Nothing to deliver; re-register waiter at the front to preserve order + let mut guard = self.list_waiters.lock().await; + let per_db = guard.entry(db_index).or_insert_with(HashMap::new); + let q = per_db.entry(key.to_string()).or_insert_with(Vec::new); + q.insert(0, waiter); + break; + } else { + let elem = elems[0].clone(); + // Send to waiter; if receiver dropped, just continue + let _ = waiter.tx.send((key.to_string(), elem)); + // Loop to try to satisfy more waiters if more elements remain + continue; + } + } + + Ok(()) + } + pub async fn handle( &mut self, mut stream: tokio::net::TcpStream, -- 2.40.1 From a306544a348d078e1e7f2f08c87169b3595e0879 Mon Sep 17 00:00:00 2001 From: Maxime Van Hees Date: Mon, 18 Aug 2025 16:21:49 +0200 Subject: [PATCH 2/8] implement COMMAND --- herodb/src/cmd.rs | 26 +++++++++++++++++--------- 1 file changed, 17 insertions(+), 9 deletions(-) diff --git a/herodb/src/cmd.rs b/herodb/src/cmd.rs index eef93c4..3403fc5 100644 --- a/herodb/src/cmd.rs +++ b/herodb/src/cmd.rs @@ -40,6 +40,7 @@ pub enum Cmd { Client(Vec), ClientSetName(String), ClientGetName, + Command(Vec), // List commands LPush(String, Vec), RPush(String, Vec), @@ -345,6 +346,10 @@ impl Cmd { Cmd::Client(vec![]) } } + "command" => { + let args = if cmd.len() > 1 { cmd[1..].to_vec() } else { vec![] }; + Cmd::Command(args) + } "lpush" => { if cmd.len() < 3 { return Err(DBError(format!("wrong number of arguments for LPUSH command"))); @@ -540,6 +545,7 @@ impl Cmd { Cmd::Client(_) => Ok(Protocol::SimpleString("OK".to_string())), Cmd::ClientSetName(name) => client_setname_cmd(server, &name).await, Cmd::ClientGetName => client_getname_cmd(server).await, + Cmd::Command(_) => Ok(Protocol::Array(vec![])), // List commands Cmd::LPush(key, elements) => lpush_cmd(server, &key, &elements).await, Cmd::RPush(key, elements) => rpush_cmd(server, &key, &elements).await, @@ -862,17 +868,19 @@ async fn info_cmd(server: &Server, section: &Option) -> Result match s.as_str() { - "replication" => Ok(Protocol::BulkString( - "role:master\nmaster_replid:8371b4fb1155b71f4a04d3e1bc3e18c4a990aeea\nmaster_repl_offset:0\n".to_string() - )), - _ => Err(DBError(format!("unsupported section {:?}", s))), - }, - None => { - Ok(Protocol::BulkString(info_string)) + Some(s) => { + let sl = s.to_lowercase(); + if sl == "replication" { + Ok(Protocol::BulkString( + "role:master\nmaster_replid:8371b4fb1155b71f4a04d3e1bc3e18c4a990aeea\nmaster_repl_offset:0\n".to_string() + )) + } else { + // Return general info for unknown sections (e.g., SERVER) + Ok(Protocol::BulkString(info_string)) + } } + None => Ok(Protocol::BulkString(info_string)), } } -- 2.40.1 From b644bf873fd5ea6fd36b76679d025a59f24975ed Mon Sep 17 00:00:00 2001 From: Maxime Van Hees Date: Tue, 19 Aug 2025 11:16:00 +0200 Subject: [PATCH 3/8] implement MGET/MSET and variadic DEL/EXISTS --- herodb/src/cmd.rs | 91 ++++++++++++++++++++++++++++-- herodb/src/storage/storage_hset.rs | 12 ---- 2 files changed, 86 insertions(+), 17 deletions(-) diff --git a/herodb/src/cmd.rs b/herodb/src/cmd.rs index 3403fc5..3ae8140 100644 --- a/herodb/src/cmd.rs +++ b/herodb/src/cmd.rs @@ -12,6 +12,8 @@ pub enum Cmd { Set(String, String), SetPx(String, String, u128), SetEx(String, String, u128), + MGet(Vec), + MSet(Vec<(String, String)>), Keys, ConfigGet(String), Info(Option), @@ -36,6 +38,8 @@ pub enum Cmd { Scan(u64, Option, Option), // cursor, pattern, count Ttl(String), Exists(String), + ExistsMulti(Vec), + DelMulti(Vec), Quit, Client(Vec), ClientSetName(String), @@ -110,6 +114,24 @@ impl Cmd { } Cmd::SetEx(cmd[1].clone(), cmd[3].clone(), cmd[2].parse().unwrap()) } + "mget" => { + if cmd.len() < 2 { + return Err(DBError("wrong number of arguments for MGET command".to_string())); + } + Cmd::MGet(cmd[1..].to_vec()) + } + "mset" => { + if cmd.len() < 3 || ((cmd.len() - 1) % 2 != 0) { + return Err(DBError("wrong number of arguments for MSET command".to_string())); + } + let mut pairs = Vec::new(); + let mut i = 1; + while i + 1 < cmd.len() { + pairs.push((cmd[i].clone(), cmd[i + 1].clone())); + i += 2; + } + Cmd::MSet(pairs) + } "config" => { if cmd.len() != 3 || cmd[1].to_lowercase() != "get" { return Err(DBError(format!("unsupported cmd {:?}", cmd))); @@ -133,10 +155,14 @@ impl Cmd { Cmd::Info(section) } "del" => { - if cmd.len() != 2 { - return Err(DBError(format!("unsupported cmd {:?}", cmd))); + if cmd.len() < 2 { + return Err(DBError(format!("wrong number of arguments for DEL command"))); + } + if cmd.len() == 2 { + Cmd::Del(cmd[1].clone()) + } else { + Cmd::DelMulti(cmd[1..].to_vec()) } - Cmd::Del(cmd[1].clone()) } "type" => { if cmd.len() != 2 { @@ -312,10 +338,14 @@ impl Cmd { Cmd::Ttl(cmd[1].clone()) } "exists" => { - if cmd.len() != 2 { + if cmd.len() < 2 { return Err(DBError(format!("wrong number of arguments for EXISTS command"))); } - Cmd::Exists(cmd[1].clone()) + if cmd.len() == 2 { + Cmd::Exists(cmd[1].clone()) + } else { + Cmd::ExistsMulti(cmd[1..].to_vec()) + } } "quit" => { if cmd.len() != 1 { @@ -507,7 +537,10 @@ impl Cmd { Cmd::Set(k, v) => set_cmd(server, &k, &v).await, Cmd::SetPx(k, v, x) => set_px_cmd(server, &k, &v, &x).await, Cmd::SetEx(k, v, x) => set_ex_cmd(server, &k, &v, &x).await, + Cmd::MGet(keys) => mget_cmd(server, &keys).await, + Cmd::MSet(pairs) => mset_cmd(server, &pairs).await, Cmd::Del(k) => del_cmd(server, &k).await, + Cmd::DelMulti(keys) => del_multi_cmd(server, &keys).await, Cmd::ConfigGet(name) => config_get_cmd(&name, server), Cmd::Keys => keys_cmd(server).await, Cmd::Info(section) => info_cmd(server, §ion).await, @@ -541,6 +574,7 @@ impl Cmd { Cmd::Scan(cursor, pattern, count) => scan_cmd(server, &cursor, pattern.as_deref(), &count).await, Cmd::Ttl(key) => ttl_cmd(server, &key).await, Cmd::Exists(key) => exists_cmd(server, &key).await, + Cmd::ExistsMulti(keys) => exists_multi_cmd(server, &keys).await, Cmd::Quit => Ok(Protocol::SimpleString("OK".to_string())), Cmd::Client(_) => Ok(Protocol::SimpleString("OK".to_string())), Cmd::ClientSetName(name) => client_setname_cmd(server, &name).await, @@ -921,6 +955,53 @@ async fn set_cmd(server: &Server, k: &str, v: &str) -> Result Ok(Protocol::SimpleString("OK".to_string())) } +// MGET: return array of bulk strings or Null for missing +async fn mget_cmd(server: &Server, keys: &[String]) -> Result { + let mut out: Vec = Vec::with_capacity(keys.len()); + let storage = server.current_storage()?; + for k in keys { + match storage.get(k)? { + Some(v) => out.push(Protocol::BulkString(v)), + None => out.push(Protocol::Null), + } + } + Ok(Protocol::Array(out)) +} + +// MSET: set multiple key/value pairs, return OK +async fn mset_cmd(server: &Server, pairs: &[(String, String)]) -> Result { + let storage = server.current_storage()?; + for (k, v) in pairs { + storage.set(k.clone(), v.clone())?; + } + Ok(Protocol::SimpleString("OK".to_string())) +} + +// DEL with multiple keys: return count of keys actually deleted +async fn del_multi_cmd(server: &Server, keys: &[String]) -> Result { + let storage = server.current_storage()?; + let mut deleted = 0i64; + for k in keys { + if storage.exists(k)? { + storage.del(k.clone())?; + deleted += 1; + } + } + Ok(Protocol::SimpleString(deleted.to_string())) +} + +// EXISTS with multiple keys: return count existing +async fn exists_multi_cmd(server: &Server, keys: &[String]) -> Result { + let storage = server.current_storage()?; + let mut count = 0i64; + for k in keys { + if storage.exists(k)? { + count += 1; + } + } + Ok(Protocol::SimpleString(count.to_string())) +} + async fn get_cmd(server: &Server, k: &str) -> Result { let v = server.current_storage()?.get(k)?; Ok(v.map_or(Protocol::Null, Protocol::BulkString)) diff --git a/herodb/src/storage/storage_hset.rs b/herodb/src/storage/storage_hset.rs index e2d3130..dfe9394 100644 --- a/herodb/src/storage/storage_hset.rs +++ b/herodb/src/storage/storage_hset.rs @@ -148,8 +148,6 @@ impl Storage { pub fn hexists(&self, key: &str, field: &str) -> Result { let read_txn = self.db.begin_read()?; - let types_table = read_txn.open_table(TYPES_TABLE)?; - let types_table = read_txn.open_table(TYPES_TABLE)?; let key_type = { let access_guard = types_table.get(key)?; @@ -168,8 +166,6 @@ impl Storage { pub fn hkeys(&self, key: &str) -> Result, DBError> { let read_txn = self.db.begin_read()?; - let types_table = read_txn.open_table(TYPES_TABLE)?; - let types_table = read_txn.open_table(TYPES_TABLE)?; let key_type = { let access_guard = types_table.get(key)?; @@ -200,8 +196,6 @@ impl Storage { // ✅ ENCRYPTION APPLIED: All values are decrypted after retrieval pub fn hvals(&self, key: &str) -> Result, DBError> { let read_txn = self.db.begin_read()?; - let types_table = read_txn.open_table(TYPES_TABLE)?; - let types_table = read_txn.open_table(TYPES_TABLE)?; let key_type = { let access_guard = types_table.get(key)?; @@ -233,8 +227,6 @@ impl Storage { pub fn hlen(&self, key: &str) -> Result { let read_txn = self.db.begin_read()?; - let types_table = read_txn.open_table(TYPES_TABLE)?; - let types_table = read_txn.open_table(TYPES_TABLE)?; let key_type = { let access_guard = types_table.get(key)?; @@ -265,8 +257,6 @@ impl Storage { // ✅ ENCRYPTION APPLIED: Values are decrypted after retrieval pub fn hmget(&self, key: &str, fields: Vec) -> Result>, DBError> { let read_txn = self.db.begin_read()?; - let types_table = read_txn.open_table(TYPES_TABLE)?; - let types_table = read_txn.open_table(TYPES_TABLE)?; let key_type = { let access_guard = types_table.get(key)?; @@ -334,8 +324,6 @@ impl Storage { // ✅ ENCRYPTION APPLIED: Values are decrypted after retrieval pub fn hscan(&self, key: &str, cursor: u64, pattern: Option<&str>, count: Option) -> Result<(u64, Vec<(String, String)>), DBError> { let read_txn = self.db.begin_read()?; - let types_table = read_txn.open_table(TYPES_TABLE)?; - let types_table = read_txn.open_table(TYPES_TABLE)?; let key_type = { let access_guard = types_table.get(key)?; -- 2.40.1 From 34808fc1c97668485e33dee92a7094051b00fa23 Mon Sep 17 00:00:00 2001 From: Maxime Van Hees Date: Tue, 19 Aug 2025 11:34:04 +0200 Subject: [PATCH 4/8] Implemented EXPIRE/PEXPIRE/PERSIST --- herodb/src/cmd.rs | 56 +++ herodb/src/storage/storage_extra.rs | 66 +++ herodb/tests/usage_suite.rs | 600 ++++++++++++++++++++++++++++ 3 files changed, 722 insertions(+) create mode 100644 herodb/tests/usage_suite.rs diff --git a/herodb/src/cmd.rs b/herodb/src/cmd.rs index 3ae8140..59021e7 100644 --- a/herodb/src/cmd.rs +++ b/herodb/src/cmd.rs @@ -37,6 +37,9 @@ pub enum Cmd { HScan(String, u64, Option, Option), // key, cursor, pattern, count Scan(u64, Option, Option), // cursor, pattern, count Ttl(String), + Expire(String, i64), + PExpire(String, i64), + Persist(String), Exists(String), ExistsMulti(Vec), DelMulti(Vec), @@ -337,6 +340,26 @@ impl Cmd { } Cmd::Ttl(cmd[1].clone()) } + "expire" => { + if cmd.len() != 3 { + return Err(DBError("wrong number of arguments for EXPIRE command".to_string())); + } + let secs = cmd[2].parse::().map_err(|_| DBError("ERR value is not an integer or out of range".to_string()))?; + Cmd::Expire(cmd[1].clone(), secs) + } + "pexpire" => { + if cmd.len() != 3 { + return Err(DBError("wrong number of arguments for PEXPIRE command".to_string())); + } + let ms = cmd[2].parse::().map_err(|_| DBError("ERR value is not an integer or out of range".to_string()))?; + Cmd::PExpire(cmd[1].clone(), ms) + } + "persist" => { + if cmd.len() != 2 { + return Err(DBError("wrong number of arguments for PERSIST command".to_string())); + } + Cmd::Persist(cmd[1].clone()) + } "exists" => { if cmd.len() < 2 { return Err(DBError(format!("wrong number of arguments for EXISTS command"))); @@ -573,6 +596,9 @@ impl Cmd { Cmd::HScan(key, cursor, pattern, count) => hscan_cmd(server, &key, &cursor, pattern.as_deref(), &count).await, Cmd::Scan(cursor, pattern, count) => scan_cmd(server, &cursor, pattern.as_deref(), &count).await, Cmd::Ttl(key) => ttl_cmd(server, &key).await, + Cmd::Expire(key, secs) => expire_cmd(server, &key, secs).await, + Cmd::PExpire(key, ms) => pexpire_cmd(server, &key, ms).await, + Cmd::Persist(key) => persist_cmd(server, &key).await, Cmd::Exists(key) => exists_cmd(server, &key).await, Cmd::ExistsMulti(keys) => exists_multi_cmd(server, &keys).await, Cmd::Quit => Ok(Protocol::SimpleString("OK".to_string())), @@ -1151,6 +1177,36 @@ async fn exists_cmd(server: &Server, key: &str) -> Result { } } +// EXPIRE key seconds -> 1 if timeout set, 0 otherwise +async fn expire_cmd(server: &Server, key: &str, secs: i64) -> Result { + if secs < 0 { + return Ok(Protocol::SimpleString("0".to_string())); + } + match server.current_storage()?.expire_seconds(key, secs as u64) { + Ok(applied) => Ok(Protocol::SimpleString(if applied { "1" } else { "0" }.to_string())), + Err(e) => Ok(Protocol::err(&e.0)), + } +} + +// PEXPIRE key milliseconds -> 1 if timeout set, 0 otherwise +async fn pexpire_cmd(server: &Server, key: &str, ms: i64) -> Result { + if ms < 0 { + return Ok(Protocol::SimpleString("0".to_string())); + } + match server.current_storage()?.pexpire_millis(key, ms as u128) { + Ok(applied) => Ok(Protocol::SimpleString(if applied { "1" } else { "0" }.to_string())), + Err(e) => Ok(Protocol::err(&e.0)), + } +} + +// PERSIST key -> 1 if timeout removed, 0 otherwise +async fn persist_cmd(server: &Server, key: &str) -> Result { + match server.current_storage()?.persist(key) { + Ok(removed) => Ok(Protocol::SimpleString(if removed { "1" } else { "0" }.to_string())), + Err(e) => Ok(Protocol::err(&e.0)), + } +} + async fn client_setname_cmd(server: &mut Server, name: &str) -> Result { server.client_name = Some(name.to_string()); Ok(Protocol::SimpleString("OK".to_string())) diff --git a/herodb/src/storage/storage_extra.rs b/herodb/src/storage/storage_extra.rs index cb8aa25..8a12674 100644 --- a/herodb/src/storage/storage_extra.rs +++ b/herodb/src/storage/storage_extra.rs @@ -98,6 +98,72 @@ impl Storage { None => Ok(false), // Key does not exist } } + + // -------- Expiration helpers (string keys only, consistent with TTL/EXISTS) -------- + + // Set expiry in seconds; returns true if applied (key exists and is string), false otherwise + pub fn expire_seconds(&self, key: &str, secs: u64) -> Result { + // Determine eligibility first to avoid holding borrows across commit + let mut applied = false; + let write_txn = self.db.begin_write()?; + { + let types_table = write_txn.open_table(TYPES_TABLE)?; + let is_string = types_table + .get(key)? + .map(|v| v.value() == "string") + .unwrap_or(false); + if is_string { + let mut expiration_table = write_txn.open_table(EXPIRATION_TABLE)?; + let expires_at = now_in_millis() + (secs as u128) * 1000; + expiration_table.insert(key, &(expires_at as u64))?; + applied = true; + } + } + write_txn.commit()?; + Ok(applied) + } + + // Set expiry in milliseconds; returns true if applied (key exists and is string), false otherwise + pub fn pexpire_millis(&self, key: &str, ms: u128) -> Result { + let mut applied = false; + let write_txn = self.db.begin_write()?; + { + let types_table = write_txn.open_table(TYPES_TABLE)?; + let is_string = types_table + .get(key)? + .map(|v| v.value() == "string") + .unwrap_or(false); + if is_string { + let mut expiration_table = write_txn.open_table(EXPIRATION_TABLE)?; + let expires_at = now_in_millis() + ms; + expiration_table.insert(key, &(expires_at as u64))?; + applied = true; + } + } + write_txn.commit()?; + Ok(applied) + } + + // Remove expiry if present; returns true if removed, false otherwise + pub fn persist(&self, key: &str) -> Result { + let mut removed = false; + let write_txn = self.db.begin_write()?; + { + let types_table = write_txn.open_table(TYPES_TABLE)?; + let is_string = types_table + .get(key)? + .map(|v| v.value() == "string") + .unwrap_or(false); + if is_string { + let mut expiration_table = write_txn.open_table(EXPIRATION_TABLE)?; + if expiration_table.remove(key)?.is_some() { + removed = true; + } + } + } + write_txn.commit()?; + Ok(removed) + } } // Utility function for glob pattern matching diff --git a/herodb/tests/usage_suite.rs b/herodb/tests/usage_suite.rs new file mode 100644 index 0000000..5ec554d --- /dev/null +++ b/herodb/tests/usage_suite.rs @@ -0,0 +1,600 @@ +use herodb::{options::DBOption, server::Server}; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::net::TcpStream; +use tokio::time::{sleep, Duration}; + +// ========================= +// Helpers +// ========================= + +async fn start_test_server(test_name: &str) -> (Server, u16) { + use std::sync::atomic::{AtomicU16, Ordering}; + static PORT_COUNTER: AtomicU16 = AtomicU16::new(17100); + let port = PORT_COUNTER.fetch_add(1, Ordering::SeqCst); + + let test_dir = format!("/tmp/herodb_usage_suite_{}", test_name); + let _ = std::fs::remove_dir_all(&test_dir); + std::fs::create_dir_all(&test_dir).unwrap(); + + let option = DBOption { + dir: test_dir, + port, + debug: false, + encrypt: false, + encryption_key: None, + }; + + let server = Server::new(option).await; + (server, port) +} + +async fn spawn_listener(server: Server, port: u16) { + tokio::spawn(async move { + let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port)) + .await + .expect("bind listener"); + loop { + match listener.accept().await { + Ok((stream, _)) => { + let mut s_clone = server.clone(); + tokio::spawn(async move { + let _ = s_clone.handle(stream).await; + }); + } + Err(_e) => break, + } + } + }); +} + +/// Build RESP array for args ["PING"] -> "*1\r\n$4\r\nPING\r\n" +fn build_resp(args: &[&str]) -> String { + let mut s = format!("*{}\r\n", args.len()); + for a in args { + s.push_str(&format!("${}\r\n{}\r\n", a.len(), a)); + } + s +} + +async fn connect(port: u16) -> TcpStream { + let mut attempts = 0; + loop { + match TcpStream::connect(format!("127.0.0.1:{}", port)).await { + Ok(s) => return s, + Err(_) if attempts < 30 => { + attempts += 1; + sleep(Duration::from_millis(100)).await; + } + Err(e) => panic!("Failed to connect: {}", e), + } + } +} + +async fn send_cmd(stream: &mut TcpStream, args: &[&str]) -> String { + let req = build_resp(args); + stream.write_all(req.as_bytes()).await.unwrap(); + + // Single read is enough for these small replies + let mut buf = vec![0u8; 8192]; + let n = stream.read(&mut buf).await.unwrap(); + String::from_utf8_lossy(&buf[..n]).to_string() +} + +// Assert helpers with clearer output +fn assert_contains(haystack: &str, needle: &str, ctx: &str) { + assert!( + haystack.contains(needle), + "ASSERT CONTAINS failed: '{}' not found in response.\nContext: {}\nResponse:\n{}", + needle, + ctx, + haystack + ); +} + +fn assert_eq_resp(actual: &str, expected: &str, ctx: &str) { + assert!( + actual == expected, + "ASSERT EQUAL failed.\nContext: {}\nExpected:\n{:?}\nActual:\n{:?}", + ctx, + expected, + actual + ); +} + +/// Extract the payload of a single RESP Bulk String reply. +/// Example input: +/// "$5\r\nhello\r\n" -> Some("hello".to_string()) +fn extract_bulk_payload(resp: &str) -> Option { + // find first CRLF after "$len" + let first = resp.find("\r\n")?; + let after = &resp[(first + 2)..]; + // find next CRLF ending payload + let second = after.find("\r\n")?; + Some(after[..second].to_string()) +} + +// ========================= +// Test suites +// ========================= + +#[tokio::test] +async fn test_01_connection_and_info() { + let (server, port) = start_test_server("conn_info").await; + spawn_listener(server, port).await; + sleep(Duration::from_millis(150)).await; + + let mut s = connect(port).await; + + // redis-cli may send COMMAND DOCS, our server replies empty array; harmless. + let pong = send_cmd(&mut s, &["PING"]).await; + assert_contains(&pong, "PONG", "PING should return PONG"); + + let echo = send_cmd(&mut s, &["ECHO", "hello"]).await; + assert_contains(&echo, "hello", "ECHO hello"); + + // INFO (general) + let info = send_cmd(&mut s, &["INFO"]).await; + assert_contains(&info, "redis_version", "INFO should include redis_version"); + + // INFO REPLICATION (static stub) + let repl = send_cmd(&mut s, &["INFO", "replication"]).await; + assert_contains(&repl, "role:master", "INFO replication role"); + + // CONFIG GET subset + let cfg = send_cmd(&mut s, &["CONFIG", "GET", "databases"]).await; + assert_contains(&cfg, "databases", "CONFIG GET databases"); + assert_contains(&cfg, "16", "CONFIG GET databases value"); + + // CLIENT name + let setname = send_cmd(&mut s, &["CLIENT", "SETNAME", "myapp"]).await; + assert_contains(&setname, "OK", "CLIENT SETNAME"); + + let getname = send_cmd(&mut s, &["CLIENT", "GETNAME"]).await; + assert_contains(&getname, "myapp", "CLIENT GETNAME"); + + // SELECT db + let sel = send_cmd(&mut s, &["SELECT", "0"]).await; + assert_contains(&sel, "OK", "SELECT 0"); + + // QUIT should close connection after sending OK + let quit = send_cmd(&mut s, &["QUIT"]).await; + assert_contains(&quit, "OK", "QUIT should return OK"); +} + +#[tokio::test] +async fn test_02_strings_and_expiry() { + let (server, port) = start_test_server("strings").await; + spawn_listener(server, port).await; + sleep(Duration::from_millis(150)).await; + + let mut s = connect(port).await; + + // SET / GET + let set = send_cmd(&mut s, &["SET", "user:1", "alice"]).await; + assert_contains(&set, "OK", "SET user:1 alice"); + + let get = send_cmd(&mut s, &["GET", "user:1"]).await; + assert_contains(&get, "alice", "GET user:1"); + + // EXISTS / DEL + let ex1 = send_cmd(&mut s, &["EXISTS", "user:1"]).await; + assert_contains(&ex1, "1", "EXISTS user:1"); + + let del = send_cmd(&mut s, &["DEL", "user:1"]).await; + assert_contains(&del, "1", "DEL user:1"); + + let ex0 = send_cmd(&mut s, &["EXISTS", "user:1"]).await; + assert_contains(&ex0, "0", "EXISTS after DEL"); + + // INCR behavior + let i1 = send_cmd(&mut s, &["INCR", "count"]).await; + assert_contains(&i1, "1", "INCR new key -> 1"); + let i2 = send_cmd(&mut s, &["INCR", "count"]).await; + assert_contains(&i2, "2", "INCR existing -> 2"); + let _ = send_cmd(&mut s, &["SET", "notnum", "abc"]).await; + let ierr = send_cmd(&mut s, &["INCR", "notnum"]).await; + assert_contains(&ierr, "ERR", "INCR on non-numeric should ERR"); + + // Expiration via SET EX + let setex = send_cmd(&mut s, &["SET", "tmp:1", "boom", "EX", "1"]).await; + assert_contains(&setex, "OK", "SET tmp:1 EX 1"); + + let g_immediate = send_cmd(&mut s, &["GET", "tmp:1"]).await; + assert_contains(&g_immediate, "boom", "GET tmp:1 immediately"); + + let ttl = send_cmd(&mut s, &["TTL", "tmp:1"]).await; + // Implementation returns a SimpleString, accept any numeric content + assert!( + ttl.contains("1") || ttl.contains("0"), + "TTL should be 1 or 0, got: {}", + ttl + ); + + sleep(Duration::from_millis(1100)).await; + let g_after = send_cmd(&mut s, &["GET", "tmp:1"]).await; + assert_contains(&g_after, "$-1", "GET tmp:1 after expiry -> Null"); + + // TYPE + let _ = send_cmd(&mut s, &["SET", "t", "v"]).await; + let ty = send_cmd(&mut s, &["TYPE", "t"]).await; + assert_contains(&ty, "string", "TYPE string key"); + let ty_none = send_cmd(&mut s, &["TYPE", "noexist"]).await; + assert_contains(&ty_none, "none", "TYPE nonexistent"); +} + +#[tokio::test] +async fn test_03_scan_and_keys() { + let (server, port) = start_test_server("scan").await; + spawn_listener(server, port).await; + sleep(Duration::from_millis(150)).await; + + let mut s = connect(port).await; + + for i in 0..5 { + let _ = send_cmd(&mut s, &["SET", &format!("key{}", i), &format!("value{}", i)]).await; + } + + let scan = send_cmd(&mut s, &["SCAN", "0", "MATCH", "key*", "COUNT", "10"]).await; + assert_contains(&scan, "key0", "SCAN should return keys with MATCH"); + assert_contains(&scan, "key4", "SCAN should return last key"); + + let keys = send_cmd(&mut s, &["KEYS", "*"]).await; + assert_contains(&keys, "key0", "KEYS * includes key0"); + assert_contains(&keys, "key4", "KEYS * includes key4"); +} + +#[tokio::test] +async fn test_04_hashes_suite() { + let (server, port) = start_test_server("hashes").await; + spawn_listener(server, port).await; + sleep(Duration::from_millis(150)).await; + + let mut s = connect(port).await; + + // HSET (single, returns number of new fields) + let h1 = send_cmd(&mut s, &["HSET", "profile:1", "name", "alice"]).await; + assert_contains(&h1, "1", "HSET new field -> 1"); + + // HGET + let hg = send_cmd(&mut s, &["HGET", "profile:1", "name"]).await; + assert_contains(&hg, "alice", "HGET existing field"); + + // HSET multiple + let h2 = send_cmd(&mut s, &["HSET", "profile:1", "age", "30", "city", "paris"]).await; + assert_contains(&h2, "2", "HSET added 2 new fields"); + + // HMGET + let hmg = send_cmd(&mut s, &["HMGET", "profile:1", "name", "age", "city", "nope"]).await; + assert_contains(&hmg, "alice", "HMGET name"); + assert_contains(&hmg, "30", "HMGET age"); + assert_contains(&hmg, "paris", "HMGET city"); + assert_contains(&hmg, "$-1", "HMGET non-existent -> Null"); + + // HGETALL + let hga = send_cmd(&mut s, &["HGETALL", "profile:1"]).await; + assert_contains(&hga, "name", "HGETALL contains name"); + assert_contains(&hga, "alice", "HGETALL contains alice"); + + // HLEN + let hlen = send_cmd(&mut s, &["HLEN", "profile:1"]).await; + assert_contains(&hlen, "3", "HLEN is 3"); + + // HEXISTS + let hex1 = send_cmd(&mut s, &["HEXISTS", "profile:1", "age"]).await; + assert_contains(&hex1, "1", "HEXISTS age true"); + let hex0 = send_cmd(&mut s, &["HEXISTS", "profile:1", "nope"]).await; + assert_contains(&hex0, "0", "HEXISTS nope false"); + + // HKEYS / HVALS + let hkeys = send_cmd(&mut s, &["HKEYS", "profile:1"]).await; + assert_contains(&hkeys, "name", "HKEYS includes name"); + let hvals = send_cmd(&mut s, &["HVALS", "profile:1"]).await; + assert_contains(&hvals, "alice", "HVALS includes alice"); + + // HSETNX + let hnx0 = send_cmd(&mut s, &["HSETNX", "profile:1", "name", "bob"]).await; + assert_contains(&hnx0, "0", "HSETNX existing field -> 0"); + let hnx1 = send_cmd(&mut s, &["HSETNX", "profile:1", "nickname", "ali"]).await; + assert_contains(&hnx1, "1", "HSETNX new field -> 1"); + + // HSCAN + let hscan = send_cmd(&mut s, &["HSCAN", "profile:1", "0", "MATCH", "n*", "COUNT", "10"]).await; + assert_contains(&hscan, "name", "HSCAN matches fields starting with n"); + assert_contains(&hscan, "nickname", "HSCAN nickname present"); + + // HDEL + let hdel = send_cmd(&mut s, &["HDEL", "profile:1", "city", "age"]).await; + assert_contains(&hdel, "2", "HDEL removed two fields"); +} + +#[tokio::test] +async fn test_05_lists_suite_including_blpop() { + let (server, port) = start_test_server("lists").await; + spawn_listener(server, port).await; + sleep(Duration::from_millis(150)).await; + + let mut a = connect(port).await; + + // LPUSH / RPUSH / LLEN + let lp = send_cmd(&mut a, &["LPUSH", "q:jobs", "a", "b"]).await; + assert_contains(&lp, "2", "LPUSH added 2, length 2"); + + let rp = send_cmd(&mut a, &["RPUSH", "q:jobs", "c"]).await; + assert_contains(&rp, "3", "RPUSH now length 3"); + + let llen = send_cmd(&mut a, &["LLEN", "q:jobs"]).await; + assert_contains(&llen, "3", "LLEN 3"); + + // LINDEX / LRANGE + let lidx = send_cmd(&mut a, &["LINDEX", "q:jobs", "0"]).await; + assert_eq_resp(&lidx, "$1\r\nb\r\n", "LINDEX q:jobs 0 should be b"); + + let lr = send_cmd(&mut a, &["LRANGE", "q:jobs", "0", "-1"]).await; + assert_eq_resp(&lr, "*3\r\n$1\r\nb\r\n$1\r\na\r\n$1\r\nc\r\n", "LRANGE q:jobs 0 -1 should be [b,a,c]"); + + // LTRIM + let ltrim = send_cmd(&mut a, &["LTRIM", "q:jobs", "0", "1"]).await; + assert_contains(<rim, "OK", "LTRIM OK"); + let lr_post = send_cmd(&mut a, &["LRANGE", "q:jobs", "0", "-1"]).await; + assert_eq_resp(&lr_post, "*2\r\n$1\r\nb\r\n$1\r\na\r\n", "After LTRIM, list [b,a]"); + + // LREM remove first occurrence of b + let lrem = send_cmd(&mut a, &["LREM", "q:jobs", "1", "b"]).await; + assert_contains(&lrem, "1", "LREM removed 1"); + + // LPOP and RPOP + let lpop1 = send_cmd(&mut a, &["LPOP", "q:jobs"]).await; + assert_contains(&lpop1, "$1\r\na\r\n", "LPOP returns a"); + let rpop_empty = send_cmd(&mut a, &["RPOP", "q:jobs"]).await; // empty now + assert_contains(&rpop_empty, "$-1", "RPOP on empty -> Null"); + + // LPOP with count on empty -> [] + let lpop0 = send_cmd(&mut a, &["LPOP", "q:jobs", "2"]).await; + assert_eq_resp(&lpop0, "*0\r\n", "LPOP with count on empty returns empty array"); + + // BLPOP: block on one client, push from another + let c1 = connect(port).await; + let mut c2 = connect(port).await; + + // Start BLPOP on c1 + let blpop_task = tokio::spawn(async move { + let mut c1_local = c1; + send_cmd(&mut c1_local, &["BLPOP", "q:block", "5"]).await + }); + + // Give it time to register waiter + sleep(Duration::from_millis(150)).await; + + // Push from c2 to wake BLPOP + let _ = send_cmd(&mut c2, &["LPUSH", "q:block", "x"]).await; + + // Await BLPOP result + let blpop_res = blpop_task.await.expect("BLPOP task join"); + assert_contains(&blpop_res, "q:block", "BLPOP returned key"); + assert_contains(&blpop_res, "x", "BLPOP returned element"); +} + +#[tokio::test] +async fn test_06_flushdb_suite() { + let (server, port) = start_test_server("flushdb").await; + spawn_listener(server, port).await; + sleep(Duration::from_millis(150)).await; + + let mut s = connect(port).await; + + let _ = send_cmd(&mut s, &["SET", "k1", "v1"]).await; + let _ = send_cmd(&mut s, &["HSET", "h1", "f", "v"]).await; + let _ = send_cmd(&mut s, &["LPUSH", "l1", "a"]).await; + + let keys_before = send_cmd(&mut s, &["KEYS", "*"]).await; + assert_contains(&keys_before, "k1", "have string key before FLUSHDB"); + assert_contains(&keys_before, "h1", "have hash key before FLUSHDB"); + assert_contains(&keys_before, "l1", "have list key before FLUSHDB"); + + let fl = send_cmd(&mut s, &["FLUSHDB"]).await; + assert_contains(&fl, "OK", "FLUSHDB OK"); + + let keys_after = send_cmd(&mut s, &["KEYS", "*"]).await; + assert_eq_resp(&keys_after, "*0\r\n", "DB should be empty after FLUSHDB"); +} + +#[tokio::test] +async fn test_07_age_stateless_suite() { + let (server, port) = start_test_server("age_stateless").await; + spawn_listener(server, port).await; + sleep(Duration::from_millis(150)).await; + + let mut s = connect(port).await; + + // GENENC -> [recipient, identity] + let gen = send_cmd(&mut s, &["AGE", "GENENC"]).await; + assert!( + gen.starts_with("*2\r\n$"), + "AGE GENENC should return array [recipient, identity], got:\n{}", + gen + ); + + // Parse simple RESP array of two bulk strings to extract keys + fn parse_two_bulk_array(resp: &str) -> (String, String) { + // naive parse for tests + let mut lines = resp.lines(); + let _ = lines.next(); // *2 + // $len + let _ = lines.next(); + let recip = lines.next().unwrap_or("").to_string(); + let _ = lines.next(); + let ident = lines.next().unwrap_or("").to_string(); + (recip, ident) + } + let (recipient, identity) = parse_two_bulk_array(&gen); + assert!( + recipient.starts_with("age1") && identity.starts_with("AGE-SECRET-KEY-1"), + "Unexpected AGE key formats.\nrecipient: {}\nidentity: {}", + recipient, + identity + ); + + // ENCRYPT / DECRYPT + let ct = send_cmd(&mut s, &["AGE", "ENCRYPT", &recipient, "hello world"]).await; + let ct_b64 = extract_bulk_payload(&ct).expect("Failed to parse bulk payload from ENCRYPT"); + let pt = send_cmd(&mut s, &["AGE", "DECRYPT", &identity, &ct_b64]).await; + assert_contains(&pt, "hello world", "AGE DECRYPT round-trip"); + + // GENSIGN -> [verify_pub_b64, sign_secret_b64] + let gensign = send_cmd(&mut s, &["AGE", "GENSIGN"]).await; + let (verify_pub, sign_secret) = parse_two_bulk_array(&gensign); + assert!( + !verify_pub.is_empty() && !sign_secret.is_empty(), + "GENSIGN returned empty keys" + ); + + // SIGN / VERIFY + let sig = send_cmd(&mut s, &["AGE", "SIGN", &sign_secret, "msg"]).await; + let sig_b64 = extract_bulk_payload(&sig).expect("Failed to parse bulk payload from SIGN"); + let v_ok = send_cmd(&mut s, &["AGE", "VERIFY", &verify_pub, "msg", &sig_b64]).await; + assert_contains(&v_ok, "1", "VERIFY should be 1 for valid signature"); + + let v_bad = send_cmd(&mut s, &["AGE", "VERIFY", &verify_pub, "tampered", &sig_b64]).await; + assert_contains(&v_bad, "0", "VERIFY should be 0 for invalid message/signature"); +} + +#[tokio::test] +async fn test_08_age_persistent_named_suite() { + let (server, port) = start_test_server("age_persistent").await; + spawn_listener(server, port).await; + sleep(Duration::from_millis(150)).await; + + let mut s = connect(port).await; + + // KEYGEN + ENCRYPTNAME/DECRYPTNAME + let kg = send_cmd(&mut s, &["AGE", "KEYGEN", "app1"]).await; + assert!( + kg.starts_with("*2\r\n"), + "AGE KEYGEN should return [recipient, identity], got:\n{}", + kg + ); + + let ct = send_cmd(&mut s, &["AGE", "ENCRYPTNAME", "app1", "hello"]).await; + let ct_b64 = extract_bulk_payload(&ct).expect("Failed to parse bulk payload from ENCRYPTNAME"); + let pt = send_cmd(&mut s, &["AGE", "DECRYPTNAME", "app1", &ct_b64]).await; + assert_contains(&pt, "hello", "DECRYPTNAME round-trip"); + + // SIGNKEYGEN + SIGNNAME/VERIFYNAME + let skg = send_cmd(&mut s, &["AGE", "SIGNKEYGEN", "app1"]).await; + assert!( + skg.starts_with("*2\r\n"), + "AGE SIGNKEYGEN should return [verify_pub, sign_secret], got:\n{}", + skg + ); + + let sig = send_cmd(&mut s, &["AGE", "SIGNNAME", "app1", "m"] ).await; + let sig_b64 = extract_bulk_payload(&sig).expect("Failed to parse bulk payload from SIGNNAME"); + let v1 = send_cmd(&mut s, &["AGE", "VERIFYNAME", "app1", "m", &sig_b64]).await; + assert_contains(&v1, "1", "VERIFYNAME valid => 1"); + + let v0 = send_cmd(&mut s, &["AGE", "VERIFYNAME", "app1", "bad", &sig_b64]).await; + assert_contains(&v0, "0", "VERIFYNAME invalid => 0"); + + // AGE LIST + let lst = send_cmd(&mut s, &["AGE", "LIST"]).await; + assert_contains(&lst, "encpub", "AGE LIST label encpub"); + assert_contains(&lst, "app1", "AGE LIST includes app1"); +} + +#[tokio::test] +async fn test_10_expire_pexpire_persist() { + let (server, port) = start_test_server("expire_suite").await; + spawn_listener(server, port).await; + sleep(Duration::from_millis(150)).await; + + let mut s = connect(port).await; + + // EXPIRE: seconds + let _ = send_cmd(&mut s, &["SET", "exp:s", "v"]).await; + let ex = send_cmd(&mut s, &["EXPIRE", "exp:s", "1"]).await; + assert_contains(&ex, "1", "EXPIRE exp:s 1 -> 1 (applied)"); + let ttl1 = send_cmd(&mut s, &["TTL", "exp:s"]).await; + assert!( + ttl1.contains("1") || ttl1.contains("0"), + "TTL exp:s should be 1 or 0, got: {}", + ttl1 + ); + sleep(Duration::from_millis(1100)).await; + let get_after = send_cmd(&mut s, &["GET", "exp:s"]).await; + assert_contains(&get_after, "$-1", "GET after expiry should be Null"); + let ttl_after = send_cmd(&mut s, &["TTL", "exp:s"]).await; + assert_contains(&ttl_after, "-2", "TTL after expiry -> -2"); + let exists_after = send_cmd(&mut s, &["EXISTS", "exp:s"]).await; + assert_contains(&exists_after, "0", "EXISTS after expiry -> 0"); + + // PEXPIRE: milliseconds + let _ = send_cmd(&mut s, &["SET", "exp:ms", "v"]).await; + let pex = send_cmd(&mut s, &["PEXPIRE", "exp:ms", "1500"]).await; + assert_contains(&pex, "1", "PEXPIRE exp:ms 1500 -> 1 (applied)"); + let ttl_ms1 = send_cmd(&mut s, &["TTL", "exp:ms"]).await; + assert!( + ttl_ms1.contains("1") || ttl_ms1.contains("0"), + "TTL exp:ms should be 1 or 0 soon after PEXPIRE, got: {}", + ttl_ms1 + ); + sleep(Duration::from_millis(1600)).await; + let exists_ms_after = send_cmd(&mut s, &["EXISTS", "exp:ms"]).await; + assert_contains(&exists_ms_after, "0", "EXISTS exp:ms after ms expiry -> 0"); + + // PERSIST: remove expiration + let _ = send_cmd(&mut s, &["SET", "exp:persist", "v"]).await; + let _ = send_cmd(&mut s, &["EXPIRE", "exp:persist", "5"]).await; + let ttl_pre = send_cmd(&mut s, &["TTL", "exp:persist"]).await; + assert!( + ttl_pre.contains("5") || ttl_pre.contains("4") || ttl_pre.contains("3") || ttl_pre.contains("2") || ttl_pre.contains("1") || ttl_pre.contains("0"), + "TTL exp:persist should be >=0 before persist, got: {}", + ttl_pre + ); + let persist1 = send_cmd(&mut s, &["PERSIST", "exp:persist"]).await; + assert_contains(&persist1, "1", "PERSIST should remove expiration"); + let ttl_post = send_cmd(&mut s, &["TTL", "exp:persist"]).await; + assert_contains(&ttl_post, "-1", "TTL after PERSIST -> -1 (no expiration)"); + // Second persist should return 0 (nothing to remove) + let persist2 = send_cmd(&mut s, &["PERSIST", "exp:persist"]).await; + assert_contains(&persist2, "0", "PERSIST again -> 0 (no expiration to remove)"); +} + +#[tokio::test] +async fn test_09_mget_mset_and_variadic_exists_del() { + let (server, port) = start_test_server("mget_mset_variadic").await; + spawn_listener(server, port).await; + sleep(Duration::from_millis(150)).await; + + let mut s = connect(port).await; + + // MSET multiple keys + let mset = send_cmd(&mut s, &["MSET", "k1", "v1", "k2", "v2", "k3", "v3"]).await; + assert_contains(&mset, "OK", "MSET k1 v1 k2 v2 k3 v3 -> OK"); + + // MGET should return values and Null for missing + let mget = send_cmd(&mut s, &["MGET", "k1", "k2", "nope", "k3"]).await; + // Expect an array with 4 entries; verify payloads + assert_contains(&mget, "v1", "MGET k1"); + assert_contains(&mget, "v2", "MGET k2"); + assert_contains(&mget, "v3", "MGET k3"); + assert_contains(&mget, "$-1", "MGET missing returns Null"); + + // EXISTS variadic: count how many exist + let exists_multi = send_cmd(&mut s, &["EXISTS", "k1", "nope", "k3"]).await; + // Server returns SimpleString numeric, e.g. +2 + assert_contains(&exists_multi, "2", "EXISTS k1 nope k3 -> 2"); + + // DEL variadic: delete multiple keys, return count deleted + let del_multi = send_cmd(&mut s, &["DEL", "k1", "k3", "nope"]).await; + assert_contains(&del_multi, "2", "DEL k1 k3 nope -> 2"); + + // Verify deletion + let exists_after = send_cmd(&mut s, &["EXISTS", "k1", "k3"]).await; + assert_contains(&exists_after, "0", "EXISTS k1 k3 after DEL -> 0"); + + // MGET after deletion should include Nulls for deleted keys + let mget_after = send_cmd(&mut s, &["MGET", "k1", "k2", "k3"]).await; + assert_contains(&mget_after, "$-1", "MGET k1 after DEL -> Null"); + assert_contains(&mget_after, "v2", "MGET k2 remains"); + assert_contains(&mget_after, "$-1", "MGET k3 after DEL -> Null"); +} \ No newline at end of file -- 2.40.1 From a92c90e9cbfbb117108435f3ebf105f01a2998c2 Mon Sep 17 00:00:00 2001 From: Maxime Van Hees Date: Tue, 19 Aug 2025 15:36:07 +0200 Subject: [PATCH 5/8] implemented HINCRBY/HINCRBYFLOAT + fixed partial-frame handling bug causing sporadic protocol parsing errors when sending or receiving large bulk strings (AGE ciphertext/signature) (happend because TCP segmentation can split a single RESP frame; both client & server assumed a single read would contain the whole frame) --- herodb/src/cmd.rs | 154 +++++++++++++++++++++++++++-- herodb/src/protocol.rs | 50 ++++++---- herodb/src/server.rs | 38 ++++--- herodb/tests/usage_suite.rs | 192 +++++++++++++++++++++++++++++++++++- 4 files changed, 391 insertions(+), 43 deletions(-) diff --git a/herodb/src/cmd.rs b/herodb/src/cmd.rs index 59021e7..8a9058c 100644 --- a/herodb/src/cmd.rs +++ b/herodb/src/cmd.rs @@ -12,6 +12,8 @@ pub enum Cmd { Set(String, String), SetPx(String, String, u128), SetEx(String, String, u128), + // Advanced SET with options: (key, value, ex_ms, nx, xx, get) + SetOpts(String, String, Option, bool, bool, bool), MGet(Vec), MSet(Vec<(String, String)>), Keys, @@ -34,6 +36,8 @@ pub enum Cmd { HLen(String), HMGet(String, Vec), HSetNx(String, String, String), + HIncrBy(String, String, i64), + HIncrByFloat(String, String, f64), HScan(String, u64, Option, Option), // key, cursor, pattern, count Scan(u64, Option, Option), // cursor, pattern, count Ttl(String), @@ -101,14 +105,51 @@ impl Cmd { "ping" => Cmd::Ping, "get" => Cmd::Get(cmd[1].clone()), "set" => { - if cmd.len() == 5 && cmd[3].to_lowercase() == "px" { - Cmd::SetPx(cmd[1].clone(), cmd[2].clone(), cmd[4].parse().unwrap()) - } else if cmd.len() == 5 && cmd[3].to_lowercase() == "ex" { - Cmd::SetEx(cmd[1].clone(), cmd[2].clone(), cmd[4].parse().unwrap()) - } else if cmd.len() == 3 { - Cmd::Set(cmd[1].clone(), cmd[2].clone()) + if cmd.len() < 3 { + return Err(DBError("wrong number of arguments for SET".to_string())); + } + let key = cmd[1].clone(); + let val = cmd[2].clone(); + + // Parse optional flags: EX sec | PX ms | NX | XX | GET + let mut ex_ms: Option = None; + let mut nx = false; + let mut xx = false; + let mut getflag = false; + + let mut i = 3; + while i < cmd.len() { + match cmd[i].to_lowercase().as_str() { + "ex" => { + if i + 1 >= cmd.len() { + return Err(DBError("ERR syntax error".to_string())); + } + let secs: u128 = cmd[i + 1].parse().map_err(|_| DBError("ERR value is not an integer or out of range".to_string()))?; + ex_ms = Some(secs * 1000); + i += 2; + } + "px" => { + if i + 1 >= cmd.len() { + return Err(DBError("ERR syntax error".to_string())); + } + let ms: u128 = cmd[i + 1].parse().map_err(|_| DBError("ERR value is not an integer or out of range".to_string()))?; + ex_ms = Some(ms); + i += 2; + } + "nx" => { nx = true; i += 1; } + "xx" => { xx = true; i += 1; } + "get" => { getflag = true; i += 1; } + _ => { + return Err(DBError(format!("unsupported cmd {:?}", cmd))); + } + } + } + + // If no options, keep legacy behavior + if ex_ms.is_none() && !nx && !xx && !getflag { + Cmd::Set(key, val) } else { - return Err(DBError(format!("unsupported cmd {:?}", cmd))); + Cmd::SetOpts(key, val, ex_ms, nx, xx, getflag) } } "setex" => { @@ -259,6 +300,20 @@ impl Cmd { } Cmd::HSetNx(cmd[1].clone(), cmd[2].clone(), cmd[3].clone()) } + "hincrby" => { + if cmd.len() != 4 { + return Err(DBError(format!("wrong number of arguments for HINCRBY command"))); + } + let delta = cmd[3].parse::().map_err(|_| DBError("ERR value is not an integer or out of range".to_string()))?; + Cmd::HIncrBy(cmd[1].clone(), cmd[2].clone(), delta) + } + "hincrbyfloat" => { + if cmd.len() != 4 { + return Err(DBError(format!("wrong number of arguments for HINCRBYFLOAT command"))); + } + let delta = cmd[3].parse::().map_err(|_| DBError("ERR value is not a valid float".to_string()))?; + Cmd::HIncrByFloat(cmd[1].clone(), cmd[2].clone(), delta) + } "hscan" => { if cmd.len() < 3 { return Err(DBError(format!("wrong number of arguments for HSCAN command"))); @@ -560,6 +615,7 @@ impl Cmd { Cmd::Set(k, v) => set_cmd(server, &k, &v).await, Cmd::SetPx(k, v, x) => set_px_cmd(server, &k, &v, &x).await, Cmd::SetEx(k, v, x) => set_ex_cmd(server, &k, &v, &x).await, + Cmd::SetOpts(k, v, ex_ms, nx, xx, getflag) => set_with_opts_cmd(server, &k, &v, ex_ms, nx, xx, getflag).await, Cmd::MGet(keys) => mget_cmd(server, &keys).await, Cmd::MSet(pairs) => mset_cmd(server, &pairs).await, Cmd::Del(k) => del_cmd(server, &k).await, @@ -593,6 +649,8 @@ impl Cmd { Cmd::HLen(key) => hlen_cmd(server, &key).await, Cmd::HMGet(key, fields) => hmget_cmd(server, &key, &fields).await, Cmd::HSetNx(key, field, value) => hsetnx_cmd(server, &key, &field, &value).await, + Cmd::HIncrBy(key, field, delta) => hincrby_cmd(server, &key, &field, delta).await, + Cmd::HIncrByFloat(key, field, delta) => hincrbyfloat_cmd(server, &key, &field, delta).await, Cmd::HScan(key, cursor, pattern, count) => hscan_cmd(server, &key, &cursor, pattern.as_deref(), &count).await, Cmd::Scan(cursor, pattern, count) => scan_cmd(server, &cursor, pattern.as_deref(), &count).await, Cmd::Ttl(key) => ttl_cmd(server, &key).await, @@ -981,6 +1039,62 @@ async fn set_cmd(server: &Server, k: &str, v: &str) -> Result Ok(Protocol::SimpleString("OK".to_string())) } +// Advanced SET with options: EX/PX/NX/XX/GET +async fn set_with_opts_cmd( + server: &Server, + key: &str, + value: &str, + ex_ms: Option, + nx: bool, + xx: bool, + get_old: bool, +) -> Result { + let storage = server.current_storage()?; + + // Determine existence (for NX/XX) + let exists = storage.exists(key)?; + + // If both NX and XX, condition can never be satisfied -> no-op + let mut should_set = true; + if nx && exists { + should_set = false; + } + if xx && !exists { + should_set = false; + } + + // Fetch old value if needed for GET + let old_val = if get_old { + storage.get(key)? + } else { + None + }; + + if should_set { + if let Some(ms) = ex_ms { + storage.setx(key.to_string(), value.to_string(), ms)?; + } else { + storage.set(key.to_string(), value.to_string())?; + } + } + + if get_old { + // Return previous value (or Null), regardless of NX/XX outcome only if set executed? + // We follow Redis semantics: return old value if set executed, else Null + if should_set { + Ok(old_val.map_or(Protocol::Null, Protocol::BulkString)) + } else { + Ok(Protocol::Null) + } + } else { + if should_set { + Ok(Protocol::SimpleString("OK".to_string())) + } else { + Ok(Protocol::Null) + } + } +} + // MGET: return array of bulk strings or Null for missing async fn mget_cmd(server: &Server, keys: &[String]) -> Result { let mut out: Vec = Vec::with_capacity(keys.len()); @@ -1120,6 +1234,32 @@ async fn hsetnx_cmd(server: &Server, key: &str, field: &str, value: &str) -> Res } } +async fn hincrby_cmd(server: &Server, key: &str, field: &str, delta: i64) -> Result { + let storage = server.current_storage()?; + let current = storage.hget(key, field)?; + let base: i64 = match current { + Some(v) => v.parse::().map_err(|_| DBError("ERR value is not an integer or out of range".to_string()))?, + None => 0, + }; + let new_val = base.checked_add(delta).ok_or_else(|| DBError("ERR increment or decrement would overflow".to_string()))?; + // Update the field + storage.hset(key, vec![(field.to_string(), new_val.to_string())])?; + Ok(Protocol::SimpleString(new_val.to_string())) +} + +async fn hincrbyfloat_cmd(server: &Server, key: &str, field: &str, delta: f64) -> Result { + let storage = server.current_storage()?; + let current = storage.hget(key, field)?; + let base: f64 = match current { + Some(v) => v.parse::().map_err(|_| DBError("ERR value is not a valid float".to_string()))?, + None => 0.0, + }; + let new_val = base + delta; + // Update the field + storage.hset(key, vec![(field.to_string(), new_val.to_string())])?; + Ok(Protocol::SimpleString(new_val.to_string())) +} + async fn scan_cmd( server: &Server, cursor: &u64, diff --git a/herodb/src/protocol.rs b/herodb/src/protocol.rs index c9a2255..6025074 100644 --- a/herodb/src/protocol.rs +++ b/herodb/src/protocol.rs @@ -19,6 +19,10 @@ impl fmt::Display for Protocol { impl Protocol { pub fn from(protocol: &str) -> Result<(Self, &str), DBError> { + if protocol.is_empty() { + // Incomplete frame; caller should read more bytes + return Err(DBError("[incomplete] empty".to_string())); + } let ret = match protocol.chars().nth(0) { Some('+') => Self::parse_simple_string_sfx(&protocol[1..]), Some('$') => Self::parse_bulk_string_sfx(&protocol[1..]), @@ -101,21 +105,20 @@ impl Protocol { let size = Self::parse_usize(&protocol[..len_end])?; let data_start = len_end + 2; let data_end = data_start + size; - let s = Self::parse_string(&protocol[data_start..data_end])?; - if protocol.len() < data_end + 2 || &protocol[data_end..data_end+2] != "\r\n" { - Err(DBError(format!( - "[new bulk string] unmatched string length in prototocl {:?}", - protocol, - ))) - } else { - Ok((Protocol::BulkString(s), &protocol[data_end + 2..])) + // If we don't yet have the full bulk payload + trailing CRLF, signal INCOMPLETE + if protocol.len() < data_end + 2 { + return Err(DBError("[incomplete] bulk body".to_string())); } + if &protocol[data_end..data_end + 2] != "\r\n" { + return Err(DBError("[incomplete] bulk terminator".to_string())); + } + + let s = Self::parse_string(&protocol[data_start..data_end])?; + Ok((Protocol::BulkString(s), &protocol[data_end + 2..])) } else { - Err(DBError(format!( - "[new bulk string] unsupported protocol: {:?}", - protocol - ))) + // No CRLF after bulk length header yet + Err(DBError("[incomplete] bulk header".to_string())) } } @@ -125,16 +128,25 @@ impl Protocol { let mut remaining = &s[len_end + 2..]; let mut vec = vec![]; for _ in 0..array_len { - let (p, rem) = Protocol::from(remaining)?; - vec.push(p); - remaining = rem; + match Protocol::from(remaining) { + Ok((p, rem)) => { + vec.push(p); + remaining = rem; + } + Err(e) => { + // Propagate incomplete so caller can read more bytes + if e.0.starts_with("[incomplete]") { + return Err(e); + } else { + return Err(e); + } + } + } } Ok((Protocol::Array(vec), remaining)) } else { - Err(DBError(format!( - "[new array] unsupported protocol: {:?}", - s - ))) + // No CRLF after array header yet + Err(DBError("[incomplete] array header".to_string())) } } diff --git a/herodb/src/server.rs b/herodb/src/server.rs index 68a0219..23a93af 100644 --- a/herodb/src/server.rs +++ b/herodb/src/server.rs @@ -160,31 +160,40 @@ impl Server { &mut self, mut stream: tokio::net::TcpStream, ) -> Result<(), DBError> { - let mut buf = [0; 512]; - + // Accumulate incoming bytes to handle partial RESP frames + let mut acc = String::new(); + let mut buf = vec![0u8; 8192]; + loop { - let len = match stream.read(&mut buf).await { + let n = match stream.read(&mut buf).await { Ok(0) => { println!("[handle] connection closed"); return Ok(()); } - Ok(len) => len, + Ok(n) => n, Err(e) => { println!("[handle] read error: {:?}", e); return Err(e.into()); } }; - let mut s = str::from_utf8(&buf[..len])?; - while !s.is_empty() { - let (cmd, protocol, remaining) = match Cmd::from(s) { + // Append to accumulator. RESP for our usage is ASCII-safe. + acc.push_str(str::from_utf8(&buf[..n])?); + + // Try to parse as many complete commands as are available in 'acc'. + loop { + let parsed = Cmd::from(&acc); + let (cmd, protocol, remaining) = match parsed { Ok((cmd, protocol, remaining)) => (cmd, protocol, remaining), - Err(e) => { - println!("\x1b[31;1mprotocol error: {:?}\x1b[0m", e); - (Cmd::Unknow("protocol_error".to_string()), Protocol::err(&format!("protocol error: {}", e.0)), "") + Err(_e) => { + // Incomplete or invalid frame; assume incomplete and wait for more data. + // This avoids emitting spurious protocol_error for split frames. + break; } }; - s = remaining; + + // Advance the accumulator to the unparsed remainder + acc = remaining.to_string(); if self.option.debug { println!("\x1b[34;1mgot command: {:?}, protocol: {:?}\x1b[0m", cmd, protocol); @@ -204,7 +213,7 @@ impl Server { Protocol::err(&format!("ERR {}", e.0)) } }; - + if self.option.debug { println!("\x1b[34;1mqueued cmd {:?}\x1b[0m", self.queued_cmd); println!("\x1b[32;1mgoing to send response {}\x1b[0m", res.encode()); @@ -220,6 +229,11 @@ impl Server { println!("[handle] QUIT command received, closing connection"); return Ok(()); } + + // Continue parsing any further complete commands already in 'acc' + if acc.is_empty() { + break; + } } } } diff --git a/herodb/tests/usage_suite.rs b/herodb/tests/usage_suite.rs index 5ec554d..c61fecf 100644 --- a/herodb/tests/usage_suite.rs +++ b/herodb/tests/usage_suite.rs @@ -70,14 +70,107 @@ async fn connect(port: u16) -> TcpStream { } } +fn find_crlf(buf: &[u8], start: usize) -> Option { + let mut i = start; + while i + 1 < buf.len() { + if buf[i] == b'\r' && buf[i + 1] == b'\n' { + return Some(i); + } + i += 1; + } + None +} + +fn parse_number_i64(buf: &[u8], start: usize, end: usize) -> Option { + let s = std::str::from_utf8(&buf[start..end]).ok()?; + s.parse::().ok() +} + +// Return number of bytes that make up a complete RESP element starting at 'i', or None if incomplete. +fn parse_elem(buf: &[u8], i: usize) -> Option { + if i >= buf.len() { + return None; + } + match buf[i] { + b'+' | b'-' | b':' => { + let end = find_crlf(buf, i + 1)?; + Some(end + 2 - i) + } + b'$' => { + let hdr_end = find_crlf(buf, i + 1)?; + let n = parse_number_i64(buf, i + 1, hdr_end)?; + if n < 0 { + // Null bulk string: only header + Some(hdr_end + 2 - i) + } else { + let need = hdr_end + 2 + (n as usize) + 2; + if need <= buf.len() { + Some(need - i) + } else { + None + } + } + } + b'*' => { + let hdr_end = find_crlf(buf, i + 1)?; + let n = parse_number_i64(buf, i + 1, hdr_end)?; + if n < 0 { + // Null array: only header + Some(hdr_end + 2 - i) + } else { + let mut j = hdr_end + 2; + for _ in 0..(n as usize) { + let consumed = parse_elem(buf, j)?; + j += consumed; + } + Some(j - i) + } + } + _ => None, + } +} + +fn resp_frame_len(buf: &[u8]) -> Option { + parse_elem(buf, 0) +} + +async fn read_full_resp(stream: &mut TcpStream) -> String { + let mut buf: Vec = Vec::with_capacity(8192); + let mut tmp = vec![0u8; 4096]; + + loop { + if let Some(total) = resp_frame_len(&buf) { + if buf.len() >= total { + return String::from_utf8_lossy(&buf[..total]).to_string(); + } + } + + match tokio::time::timeout(Duration::from_secs(2), stream.read(&mut tmp)).await { + Ok(Ok(n)) => { + if n == 0 { + if let Some(total) = resp_frame_len(&buf) { + if buf.len() >= total { + return String::from_utf8_lossy(&buf[..total]).to_string(); + } + } + return String::from_utf8_lossy(&buf).to_string(); + } + buf.extend_from_slice(&tmp[..n]); + } + Ok(Err(e)) => panic!("read error: {}", e), + Err(_) => panic!("timeout waiting for reply"), + } + + if buf.len() > 8 * 1024 * 1024 { + panic!("reply too large"); + } + } +} + async fn send_cmd(stream: &mut TcpStream, args: &[&str]) -> String { let req = build_resp(args); stream.write_all(req.as_bytes()).await.unwrap(); - - // Single read is enough for these small replies - let mut buf = vec![0u8; 8192]; - let n = stream.read(&mut buf).await.unwrap(); - String::from_utf8_lossy(&buf[..n]).to_string() + read_full_resp(stream).await } // Assert helpers with clearer output @@ -559,6 +652,58 @@ async fn test_10_expire_pexpire_persist() { assert_contains(&persist2, "0", "PERSIST again -> 0 (no expiration to remove)"); } +#[tokio::test] +async fn test_11_set_with_options() { + let (server, port) = start_test_server("set_opts").await; + spawn_listener(server, port).await; + sleep(Duration::from_millis(150)).await; + + let mut s = connect(port).await; + + // SET with GET on non-existing key -> returns Null, sets value + let set_get1 = send_cmd(&mut s, &["SET", "s1", "v1", "GET"]).await; + assert_contains(&set_get1, "$-1", "SET s1 v1 GET returns Null when key didn't exist"); + let g1 = send_cmd(&mut s, &["GET", "s1"]).await; + assert_contains(&g1, "v1", "GET s1 after first SET"); + + // SET with GET should return old value, then set to new + let set_get2 = send_cmd(&mut s, &["SET", "s1", "v2", "GET"]).await; + assert_contains(&set_get2, "v1", "SET s1 v2 GET returns previous value v1"); + let g2 = send_cmd(&mut s, &["GET", "s1"]).await; + assert_contains(&g2, "v2", "GET s1 now v2"); + + // NX prevents update when key exists; with GET should return Null and not change + let set_nx = send_cmd(&mut s, &["SET", "s1", "v3", "NX", "GET"]).await; + assert_contains(&set_nx, "$-1", "SET s1 v3 NX GET returns Null when not set"); + let g3 = send_cmd(&mut s, &["GET", "s1"]).await; + assert_contains(&g3, "v2", "GET s1 remains v2 after NX prevented write"); + + // NX allows set when key does not exist + let set_nx2 = send_cmd(&mut s, &["SET", "s2", "v10", "NX"]).await; + assert_contains(&set_nx2, "OK", "SET s2 v10 NX -> OK for new key"); + let g4 = send_cmd(&mut s, &["GET", "s2"]).await; + assert_contains(&g4, "v10", "GET s2 is v10"); + + // XX requires existing key; with GET returns old value and sets new + let set_xx = send_cmd(&mut s, &["SET", "s2", "v11", "XX", "GET"]).await; + assert_contains(&set_xx, "v10", "SET s2 v11 XX GET returns previous v10"); + let g5 = send_cmd(&mut s, &["GET", "s2"]).await; + assert_contains(&g5, "v11", "GET s2 is now v11"); + + // PX expiration path via SET options + let set_px = send_cmd(&mut s, &["SET", "s3", "vpx", "PX", "500"]).await; + assert_contains(&set_px, "OK", "SET s3 vpx PX 500 -> OK"); + let ttl_px1 = send_cmd(&mut s, &["TTL", "s3"]).await; + assert!( + ttl_px1.contains("0") || ttl_px1.contains("1"), + "TTL s3 immediately after PX should be 1 or 0, got: {}", + ttl_px1 + ); + sleep(Duration::from_millis(650)).await; + let g6 = send_cmd(&mut s, &["GET", "s3"]).await; + assert_contains(&g6, "$-1", "GET s3 after PX expiry -> Null"); +} + #[tokio::test] async fn test_09_mget_mset_and_variadic_exists_del() { let (server, port) = start_test_server("mget_mset_variadic").await; @@ -597,4 +742,41 @@ async fn test_09_mget_mset_and_variadic_exists_del() { assert_contains(&mget_after, "$-1", "MGET k1 after DEL -> Null"); assert_contains(&mget_after, "v2", "MGET k2 remains"); assert_contains(&mget_after, "$-1", "MGET k3 after DEL -> Null"); +} +#[tokio::test] +async fn test_12_hash_incr() { + let (server, port) = start_test_server("hash_incr").await; + spawn_listener(server, port).await; + sleep(Duration::from_millis(150)).await; + + let mut s = connect(port).await; + + // Integer increments + let _ = send_cmd(&mut s, &["HSET", "hinc", "a", "1"]).await; + let r1 = send_cmd(&mut s, &["HINCRBY", "hinc", "a", "2"]).await; + assert_contains(&r1, "3", "HINCRBY hinc a 2 -> 3"); + + let r2 = send_cmd(&mut s, &["HINCRBY", "hinc", "a", "-1"]).await; + assert_contains(&r2, "2", "HINCRBY hinc a -1 -> 2"); + + let r3 = send_cmd(&mut s, &["HINCRBY", "hinc", "b", "5"]).await; + assert_contains(&r3, "5", "HINCRBY hinc b 5 -> 5"); + + // HINCRBY error on non-integer field + let _ = send_cmd(&mut s, &["HSET", "hinc", "s", "x"]).await; + let r_err = send_cmd(&mut s, &["HINCRBY", "hinc", "s", "1"]).await; + assert_contains(&r_err, "ERR", "HINCRBY on non-integer field should ERR"); + + // Float increments + let r4 = send_cmd(&mut s, &["HINCRBYFLOAT", "hinc", "f", "1.5"]).await; + assert_contains(&r4, "1.5", "HINCRBYFLOAT hinc f 1.5 -> 1.5"); + + let r5 = send_cmd(&mut s, &["HINCRBYFLOAT", "hinc", "f", "2.5"]).await; + // Could be "4", "4.0", or "4.000000", accept "4" substring + assert_contains(&r5, "4", "HINCRBYFLOAT hinc f 2.5 -> 4"); + + // HINCRBYFLOAT error on non-float field + let _ = send_cmd(&mut s, &["HSET", "hinc", "notf", "abc"]).await; + let r6 = send_cmd(&mut s, &["HINCRBYFLOAT", "hinc", "notf", "1"]).await; + assert_contains(&r6, "ERR", "HINCRBYFLOAT on non-float field should ERR"); } \ No newline at end of file -- 2.40.1 From 463000c8f7ba3e9e29ac517c4b509b142be75c12 Mon Sep 17 00:00:00 2001 From: Maxime Van Hees Date: Tue, 19 Aug 2025 15:52:36 +0200 Subject: [PATCH 6/8] Implemented BRPOP and minimal COMMAND DOCS stub, and wired side-aware waiter delivery --- herodb/src/cmd.rs | 115 +++++++++++++++++++++++- herodb/src/server.rs | 18 +++- herodb/tests/redis_integration_tests.rs | 8 +- herodb/tests/usage_suite.rs | 36 ++++++++ 4 files changed, 167 insertions(+), 10 deletions(-) diff --git a/herodb/src/cmd.rs b/herodb/src/cmd.rs index 8a9058c..d4287ef 100644 --- a/herodb/src/cmd.rs +++ b/herodb/src/cmd.rs @@ -58,6 +58,7 @@ pub enum Cmd { LPop(String, Option), RPop(String, Option), BLPop(Vec, f64), + BRPop(Vec, f64), LLen(String), LRem(String, i64, String), LTrim(String, i64, i64), @@ -503,6 +504,17 @@ impl Cmd { .map_err(|_| DBError("ERR timeout is not a number".to_string()))?; Cmd::BLPop(keys, timeout_f) } + "brpop" => { + if cmd.len() < 3 { + return Err(DBError(format!("wrong number of arguments for BRPOP command"))); + } + // keys are all but the last argument + let keys = cmd[1..cmd.len()-1].to_vec(); + let timeout_f = cmd[cmd.len()-1] + .parse::() + .map_err(|_| DBError("ERR timeout is not a number".to_string()))?; + Cmd::BRPop(keys, timeout_f) + } "llen" => { if cmd.len() != 2 { return Err(DBError(format!("wrong number of arguments for LLEN command"))); @@ -663,13 +675,14 @@ impl Cmd { Cmd::Client(_) => Ok(Protocol::SimpleString("OK".to_string())), Cmd::ClientSetName(name) => client_setname_cmd(server, &name).await, Cmd::ClientGetName => client_getname_cmd(server).await, - Cmd::Command(_) => Ok(Protocol::Array(vec![])), + Cmd::Command(args) => command_cmd(&args), // List commands Cmd::LPush(key, elements) => lpush_cmd(server, &key, &elements).await, Cmd::RPush(key, elements) => rpush_cmd(server, &key, &elements).await, Cmd::LPop(key, count) => lpop_cmd(server, &key, &count).await, Cmd::RPop(key, count) => rpop_cmd(server, &key, &count).await, Cmd::BLPop(keys, timeout) => blpop_cmd(server, &keys, timeout).await, + Cmd::BRPop(keys, timeout) => brpop_cmd(server, &keys, timeout).await, Cmd::LLen(key) => llen_cmd(server, &key).await, Cmd::LRem(key, count, element) => lrem_cmd(server, &key, count, &element).await, Cmd::LTrim(key, start, stop) => ltrim_cmd(server, &key, start, stop).await, @@ -825,7 +838,89 @@ async fn blpop_cmd(server: &Server, keys: &[String], timeout_secs: f64) -> Resul let mut rxs: Vec> = Vec::with_capacity(keys.len()); for k in keys { - let (id, rx) = server.register_waiter(db_index, k).await; + let (id, rx) = server.register_waiter(db_index, k, crate::server::PopSide::Left).await; + ids.push(id); + names.push(k.clone()); + rxs.push(rx); + } + + // Wait for the first delivery or timeout + let wait_fut = async move { + let mut futures_vec = rxs; + loop { + if futures_vec.is_empty() { + return None; + } + let (res, idx, remaining) = select_all(futures_vec).await; + match res { + Ok((k, elem)) => { + return Some((k, elem, idx, remaining)); + } + Err(_canceled) => { + // That waiter was canceled; continue with the rest + futures_vec = remaining; + continue; + } + } + } + }; + + match timeout(Duration::from_secs_f64(timeout_secs), wait_fut).await { + Ok(Some((k, elem, idx, _remaining))) => { + // Unregister other waiters + for (i, key_name) in names.iter().enumerate() { + if i != idx { + server.unregister_waiter(db_index, key_name, ids[i]).await; + } + } + Ok(Protocol::Array(vec![ + Protocol::BulkString(k), + Protocol::BulkString(elem), + ])) + } + Ok(None) => { + // No futures left; unregister all waiters + for (i, key_name) in names.iter().enumerate() { + server.unregister_waiter(db_index, key_name, ids[i]).await; + } + Ok(Protocol::Null) + } + Err(_elapsed) => { + // Timeout: unregister all waiters + for (i, key_name) in names.iter().enumerate() { + server.unregister_waiter(db_index, key_name, ids[i]).await; + } + Ok(Protocol::Null) + } + } +} + +// BRPOP implementation (mirror of BLPOP, popping from the right) +async fn brpop_cmd(server: &Server, keys: &[String], timeout_secs: f64) -> Result { + // Immediate, non-blocking attempt in key order using RPOP + for k in keys { + let elems = server.current_storage()?.rpop(k, 1)?; + if !elems.is_empty() { + return Ok(Protocol::Array(vec![ + Protocol::BulkString(k.clone()), + Protocol::BulkString(elems[0].clone()), + ])); + } + } + + // If timeout is zero, return immediately with Null + if timeout_secs <= 0.0 { + return Ok(Protocol::Null); + } + + // Register waiters for each key (Right side) + let db_index = server.selected_db; + let mut ids: Vec = Vec::with_capacity(keys.len()); + let mut names: Vec = Vec::with_capacity(keys.len()); + let mut rxs: Vec> = Vec::with_capacity(keys.len()); + + for k in keys { + let (id, rx) = server.register_waiter(db_index, k, crate::server::PopSide::Right).await; ids.push(id); names.push(k.clone()); rxs.push(rx); @@ -1358,3 +1453,19 @@ async fn client_getname_cmd(server: &Server) -> Result { None => Ok(Protocol::Null), } } + +// Minimal COMMAND subcommands stub to satisfy redis-cli probes. +// - COMMAND DOCS ... => return empty array +// - COMMAND INFO ... => return empty array +// - Any other => empty array +fn command_cmd(args: &[String]) -> Result { + if args.is_empty() { + return Ok(Protocol::Array(vec![])); + } + let sub = args[0].to_lowercase(); + match sub.as_str() { + "docs" => Ok(Protocol::Array(vec![])), + "info" => Ok(Protocol::Array(vec![])), + _ => Ok(Protocol::Array(vec![])), + } +} diff --git a/herodb/src/server.rs b/herodb/src/server.rs index 23a93af..0c128c6 100644 --- a/herodb/src/server.rs +++ b/herodb/src/server.rs @@ -28,9 +28,16 @@ pub struct Server { pub struct Waiter { pub id: u64, + pub side: PopSide, pub tx: oneshot::Sender<(String, String)>, // (key, element) } +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum PopSide { + Left, + Right, +} + impl Server { pub async fn new(option: options::DBOption) -> Self { Server { @@ -83,14 +90,14 @@ impl Server { // ----- BLPOP waiter helpers ----- - pub async fn register_waiter(&self, db_index: u64, key: &str) -> (u64, oneshot::Receiver<(String, String)>) { + pub async fn register_waiter(&self, db_index: u64, key: &str, side: PopSide) -> (u64, oneshot::Receiver<(String, String)>) { let id = self.waiter_seq.fetch_add(1, Ordering::Relaxed); let (tx, rx) = oneshot::channel::<(String, String)>(); let mut guard = self.list_waiters.lock().await; let per_db = guard.entry(db_index).or_insert_with(HashMap::new); let q = per_db.entry(key.to_string()).or_insert_with(Vec::new); - q.push(Waiter { id, tx }); + q.push(Waiter { id, side, tx }); (id, rx) } @@ -135,8 +142,11 @@ impl Server { let waiter = if let Some(w) = maybe_waiter { w } else { break }; - // Pop one element from the left - let elems = self.current_storage()?.lpop(key, 1)?; + // Pop one element depending on waiter side + let elems = match waiter.side { + PopSide::Left => self.current_storage()?.lpop(key, 1)?, + PopSide::Right => self.current_storage()?.rpop(key, 1)?, + }; if elems.is_empty() { // Nothing to deliver; re-register waiter at the front to preserve order let mut guard = self.list_waiters.lock().await; diff --git a/herodb/tests/redis_integration_tests.rs b/herodb/tests/redis_integration_tests.rs index 16e1f64..47033e1 100644 --- a/herodb/tests/redis_integration_tests.rs +++ b/herodb/tests/redis_integration_tests.rs @@ -16,9 +16,9 @@ fn get_redis_connection(port: u16) -> Connection { } } Err(e) => { - if attempts >= 20 { + if attempts >= 120 { panic!( - "Failed to connect to Redis server after 20 attempts: {}", + "Failed to connect to Redis server after 120 attempts: {}", e ); } @@ -88,8 +88,8 @@ fn setup_server() -> (ServerProcessGuard, u16) { test_dir, }; - // Give the server a moment to start - std::thread::sleep(Duration::from_millis(500)); + // Give the server time to build and start (cargo run may compile first) + std::thread::sleep(Duration::from_millis(2500)); (guard, port) } diff --git a/herodb/tests/usage_suite.rs b/herodb/tests/usage_suite.rs index c61fecf..9591193 100644 --- a/herodb/tests/usage_suite.rs +++ b/herodb/tests/usage_suite.rs @@ -779,4 +779,40 @@ async fn test_12_hash_incr() { let _ = send_cmd(&mut s, &["HSET", "hinc", "notf", "abc"]).await; let r6 = send_cmd(&mut s, &["HINCRBYFLOAT", "hinc", "notf", "1"]).await; assert_contains(&r6, "ERR", "HINCRBYFLOAT on non-float field should ERR"); +} +#[tokio::test] +async fn test_05b_brpop_suite() { + let (server, port) = start_test_server("lists_brpop").await; + spawn_listener(server, port).await; + sleep(Duration::from_millis(150)).await; + + let mut a = connect(port).await; + + // RPUSH some initial data, BRPOP should take from the right + let _ = send_cmd(&mut a, &["RPUSH", "q:rjobs", "1", "2"]).await; + let br_nonblock = send_cmd(&mut a, &["BRPOP", "q:rjobs", "0"]).await; + // Should pop the rightmost element "2" + assert_contains(&br_nonblock, "q:rjobs", "BRPOP returns key"); + assert_contains(&br_nonblock, "2", "BRPOP returns rightmost element"); + + // Now test blocking BRPOP: start blocked client, then RPUSH from another client + let c1 = connect(port).await; + let mut c2 = connect(port).await; + + // Start BRPOP on c1 + let brpop_task = tokio::spawn(async move { + let mut c1_local = c1; + send_cmd(&mut c1_local, &["BRPOP", "q:blockr", "5"]).await + }); + + // Give it time to register waiter + sleep(Duration::from_millis(150)).await; + + // Push from right to wake BRPOP + let _ = send_cmd(&mut c2, &["RPUSH", "q:blockr", "X"]).await; + + // Await BRPOP result + let brpop_res = brpop_task.await.expect("BRPOP task join"); + assert_contains(&brpop_res, "q:blockr", "BRPOP returned key"); + assert_contains(&brpop_res, "X", "BRPOP returned element"); } \ No newline at end of file -- 2.40.1 From b9a9f3e6d610501597305d2319d69b85f26548df Mon Sep 17 00:00:00 2001 From: Maxime Van Hees Date: Tue, 19 Aug 2025 16:05:25 +0200 Subject: [PATCH 7/8] Implemented DBSIZE --- herodb/src/cmd.rs | 15 +++++++++++++ herodb/src/storage/storage_basic.rs | 27 +++++++++++++++++++++++ herodb/tests/usage_suite.rs | 34 +++++++++++++++++++++++++++++ 3 files changed, 76 insertions(+) diff --git a/herodb/src/cmd.rs b/herodb/src/cmd.rs index d4287ef..518aa8a 100644 --- a/herodb/src/cmd.rs +++ b/herodb/src/cmd.rs @@ -17,6 +17,7 @@ pub enum Cmd { MGet(Vec), MSet(Vec<(String, String)>), Keys, + DbSize, ConfigGet(String), Info(Option), Del(String), @@ -191,6 +192,12 @@ impl Cmd { Cmd::Keys } } + "dbsize" => { + if cmd.len() != 1 { + return Err(DBError(format!("wrong number of arguments for DBSIZE command"))); + } + Cmd::DbSize + } "info" => { let section = if cmd.len() == 2 { Some(cmd[1].clone()) @@ -634,6 +641,7 @@ impl Cmd { Cmd::DelMulti(keys) => del_multi_cmd(server, &keys).await, Cmd::ConfigGet(name) => config_get_cmd(&name, server), Cmd::Keys => keys_cmd(server).await, + Cmd::DbSize => dbsize_cmd(server).await, Cmd::Info(section) => info_cmd(server, §ion).await, Cmd::Type(k) => type_cmd(server, &k).await, Cmd::Incr(key) => incr_cmd(server, &key).await, @@ -1060,6 +1068,13 @@ async fn keys_cmd(server: &Server) -> Result { )) } +async fn dbsize_cmd(server: &Server) -> Result { + match server.current_storage()?.dbsize() { + Ok(n) => Ok(Protocol::SimpleString(n.to_string())), + Err(e) => Ok(Protocol::err(&e.0)), + } +} + #[derive(Serialize)] struct ServerInfo { redis_version: String, diff --git a/herodb/src/storage/storage_basic.rs b/herodb/src/storage/storage_basic.rs index a394cb7..1594b87 100644 --- a/herodb/src/storage/storage_basic.rs +++ b/herodb/src/storage/storage_basic.rs @@ -215,4 +215,31 @@ impl Storage { Ok(keys) } +} + +impl Storage { + pub fn dbsize(&self) -> Result { + let read_txn = self.db.begin_read()?; + let types_table = read_txn.open_table(TYPES_TABLE)?; + let expiration_table = read_txn.open_table(EXPIRATION_TABLE)?; + + let mut count: i64 = 0; + let mut iter = types_table.iter()?; + while let Some(entry) = iter.next() { + let entry = entry?; + let key = entry.0.value(); + let ty = entry.1.value(); + + if ty == "string" { + if let Some(expires_at) = expiration_table.get(key)? { + if now_in_millis() > expires_at.value() as u128 { + // Skip logically expired string keys + continue; + } + } + } + count += 1; + } + Ok(count) + } } \ No newline at end of file diff --git a/herodb/tests/usage_suite.rs b/herodb/tests/usage_suite.rs index 9591193..a77e9ae 100644 --- a/herodb/tests/usage_suite.rs +++ b/herodb/tests/usage_suite.rs @@ -815,4 +815,38 @@ async fn test_05b_brpop_suite() { let brpop_res = brpop_task.await.expect("BRPOP task join"); assert_contains(&brpop_res, "q:blockr", "BRPOP returned key"); assert_contains(&brpop_res, "X", "BRPOP returned element"); +} +#[tokio::test] +async fn test_13_dbsize() { + let (server, port) = start_test_server("dbsize").await; + spawn_listener(server, port).await; + sleep(Duration::from_millis(150)).await; + + let mut s = connect(port).await; + + // Initially empty + let n0 = send_cmd(&mut s, &["DBSIZE"]).await; + assert_contains(&n0, "0", "DBSIZE initial should be 0"); + + // Add a string, a hash, and a list -> dbsize = 3 + let _ = send_cmd(&mut s, &["SET", "s", "v"]).await; + let _ = send_cmd(&mut s, &["HSET", "h", "f", "v"]).await; + let _ = send_cmd(&mut s, &["LPUSH", "l", "a", "b"]).await; + + let n3 = send_cmd(&mut s, &["DBSIZE"]).await; + assert_contains(&n3, "3", "DBSIZE after adding s,h,l should be 3"); + + // Expire the string and wait, dbsize should drop to 2 + let _ = send_cmd(&mut s, &["PEXPIRE", "s", "400"]).await; + sleep(Duration::from_millis(500)).await; + + let n2 = send_cmd(&mut s, &["DBSIZE"]).await; + assert_contains(&n2, "2", "DBSIZE after string expiry should be 2"); + + // Delete remaining keys and confirm 0 + let _ = send_cmd(&mut s, &["DEL", "h"]).await; + let _ = send_cmd(&mut s, &["DEL", "l"]).await; + + let n_final = send_cmd(&mut s, &["DBSIZE"]).await; + assert_contains(&n_final, "0", "DBSIZE after deleting all keys should be 0"); } \ No newline at end of file -- 2.40.1 From 892e6e2b90ca498ab2e65a5f5ab37805c64b6c7f Mon Sep 17 00:00:00 2001 From: Maxime Van Hees Date: Tue, 19 Aug 2025 16:21:43 +0200 Subject: [PATCH 8/8] Implemented EXPIREAT and PEXPIREAT --- herodb/src/cmd.rs | 33 ++++++++++++++++++++++ herodb/src/storage/storage_extra.rs | 44 +++++++++++++++++++++++++++++ herodb/tests/usage_suite.rs | 40 ++++++++++++++++++++++++++ 3 files changed, 117 insertions(+) diff --git a/herodb/src/cmd.rs b/herodb/src/cmd.rs index 518aa8a..9817f89 100644 --- a/herodb/src/cmd.rs +++ b/herodb/src/cmd.rs @@ -44,6 +44,8 @@ pub enum Cmd { Ttl(String), Expire(String, i64), PExpire(String, i64), + ExpireAt(String, i64), + PExpireAt(String, i64), Persist(String), Exists(String), ExistsMulti(Vec), @@ -417,6 +419,20 @@ impl Cmd { let ms = cmd[2].parse::().map_err(|_| DBError("ERR value is not an integer or out of range".to_string()))?; Cmd::PExpire(cmd[1].clone(), ms) } + "expireat" => { + if cmd.len() != 3 { + return Err(DBError("wrong number of arguments for EXPIREAT command".to_string())); + } + let ts = cmd[2].parse::().map_err(|_| DBError("ERR value is not an integer or out of range".to_string()))?; + Cmd::ExpireAt(cmd[1].clone(), ts) + } + "pexpireat" => { + if cmd.len() != 3 { + return Err(DBError("wrong number of arguments for PEXPIREAT command".to_string())); + } + let ts_ms = cmd[2].parse::().map_err(|_| DBError("ERR value is not an integer or out of range".to_string()))?; + Cmd::PExpireAt(cmd[1].clone(), ts_ms) + } "persist" => { if cmd.len() != 2 { return Err(DBError("wrong number of arguments for PERSIST command".to_string())); @@ -676,6 +692,8 @@ impl Cmd { Cmd::Ttl(key) => ttl_cmd(server, &key).await, Cmd::Expire(key, secs) => expire_cmd(server, &key, secs).await, Cmd::PExpire(key, ms) => pexpire_cmd(server, &key, ms).await, + Cmd::ExpireAt(key, ts_secs) => expireat_cmd(server, &key, ts_secs).await, + Cmd::PExpireAt(key, ts_ms) => pexpireat_cmd(server, &key, ts_ms).await, Cmd::Persist(key) => persist_cmd(server, &key).await, Cmd::Exists(key) => exists_cmd(server, &key).await, Cmd::ExistsMulti(keys) => exists_multi_cmd(server, &keys).await, @@ -1456,6 +1474,21 @@ async fn persist_cmd(server: &Server, key: &str) -> Result { Err(e) => Ok(Protocol::err(&e.0)), } } +// EXPIREAT key timestamp-seconds -> 1 if timeout set, 0 otherwise +async fn expireat_cmd(server: &Server, key: &str, ts_secs: i64) -> Result { + match server.current_storage()?.expire_at_seconds(key, ts_secs) { + Ok(applied) => Ok(Protocol::SimpleString(if applied { "1" } else { "0" }.to_string())), + Err(e) => Ok(Protocol::err(&e.0)), + } +} + +// PEXPIREAT key timestamp-milliseconds -> 1 if timeout set, 0 otherwise +async fn pexpireat_cmd(server: &Server, key: &str, ts_ms: i64) -> Result { + match server.current_storage()?.pexpire_at_millis(key, ts_ms) { + Ok(applied) => Ok(Protocol::SimpleString(if applied { "1" } else { "0" }.to_string())), + Err(e) => Ok(Protocol::err(&e.0)), + } +} async fn client_setname_cmd(server: &mut Server, name: &str) -> Result { server.client_name = Some(name.to_string()); diff --git a/herodb/src/storage/storage_extra.rs b/herodb/src/storage/storage_extra.rs index 8a12674..4f2e8f7 100644 --- a/herodb/src/storage/storage_extra.rs +++ b/herodb/src/storage/storage_extra.rs @@ -164,6 +164,50 @@ impl Storage { write_txn.commit()?; Ok(removed) } + + // Absolute EXPIREAT in seconds since epoch + // Returns true if applied (key exists and is string), false otherwise + pub fn expire_at_seconds(&self, key: &str, ts_secs: i64) -> Result { + let mut applied = false; + let write_txn = self.db.begin_write()?; + { + let types_table = write_txn.open_table(TYPES_TABLE)?; + let is_string = types_table + .get(key)? + .map(|v| v.value() == "string") + .unwrap_or(false); + if is_string { + let mut expiration_table = write_txn.open_table(EXPIRATION_TABLE)?; + let expires_at_ms: u128 = if ts_secs <= 0 { 0 } else { (ts_secs as u128) * 1000 }; + expiration_table.insert(key, &((expires_at_ms as u64)))?; + applied = true; + } + } + write_txn.commit()?; + Ok(applied) + } + + // Absolute PEXPIREAT in milliseconds since epoch + // Returns true if applied (key exists and is string), false otherwise + pub fn pexpire_at_millis(&self, key: &str, ts_ms: i64) -> Result { + let mut applied = false; + let write_txn = self.db.begin_write()?; + { + let types_table = write_txn.open_table(TYPES_TABLE)?; + let is_string = types_table + .get(key)? + .map(|v| v.value() == "string") + .unwrap_or(false); + if is_string { + let mut expiration_table = write_txn.open_table(EXPIRATION_TABLE)?; + let expires_at_ms: u128 = if ts_ms <= 0 { 0 } else { ts_ms as u128 }; + expiration_table.insert(key, &((expires_at_ms as u64)))?; + applied = true; + } + } + write_txn.commit()?; + Ok(applied) + } } // Utility function for glob pattern matching diff --git a/herodb/tests/usage_suite.rs b/herodb/tests/usage_suite.rs index a77e9ae..a330a0b 100644 --- a/herodb/tests/usage_suite.rs +++ b/herodb/tests/usage_suite.rs @@ -849,4 +849,44 @@ async fn test_13_dbsize() { let n_final = send_cmd(&mut s, &["DBSIZE"]).await; assert_contains(&n_final, "0", "DBSIZE after deleting all keys should be 0"); +} +#[tokio::test] +async fn test_14_expireat_pexpireat() { + use std::time::{SystemTime, UNIX_EPOCH}; + + let (server, port) = start_test_server("expireat_suite").await; + spawn_listener(server, port).await; + sleep(Duration::from_millis(150)).await; + + let mut s = connect(port).await; + + // EXPIREAT: seconds since epoch + let now_secs = SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_secs() as i64; + let _ = send_cmd(&mut s, &["SET", "exp:at:s", "v"]).await; + let exat = send_cmd(&mut s, &["EXPIREAT", "exp:at:s", &format!("{}", now_secs + 1)]).await; + assert_contains(&exat, "1", "EXPIREAT exp:at:s now+1s -> 1 (applied)"); + let ttl1 = send_cmd(&mut s, &["TTL", "exp:at:s"]).await; + assert!( + ttl1.contains("1") || ttl1.contains("0"), + "TTL exp:at:s should be 1 or 0 shortly after EXPIREAT, got: {}", + ttl1 + ); + sleep(Duration::from_millis(1200)).await; + let exists_after_exat = send_cmd(&mut s, &["EXISTS", "exp:at:s"]).await; + assert_contains(&exists_after_exat, "0", "EXISTS exp:at:s after EXPIREAT expiry -> 0"); + + // PEXPIREAT: milliseconds since epoch + let now_ms = SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_millis() as i64; + let _ = send_cmd(&mut s, &["SET", "exp:at:ms", "v"]).await; + let pexat = send_cmd(&mut s, &["PEXPIREAT", "exp:at:ms", &format!("{}", now_ms + 450)]).await; + assert_contains(&pexat, "1", "PEXPIREAT exp:at:ms now+450ms -> 1 (applied)"); + let ttl2 = send_cmd(&mut s, &["TTL", "exp:at:ms"]).await; + assert!( + ttl2.contains("0") || ttl2.contains("1"), + "TTL exp:at:ms should be 0..1 soon after PEXPIREAT, got: {}", + ttl2 + ); + sleep(Duration::from_millis(600)).await; + let exists_after_pexat = send_cmd(&mut s, &["EXISTS", "exp:at:ms"]).await; + assert_contains(&exists_after_pexat, "0", "EXISTS exp:at:ms after PEXPIREAT expiry -> 0"); } \ No newline at end of file -- 2.40.1