Merge pull request 'BLPOP + COMMAND + MGET/MSET + DEL/EXISTS + EXPIRE/PEXPIRE/PERSIST + HINCRBY/HINCRBYFLOAT + BRPOP + DBSIZE + EXPIREAT/PEXIREAT implementations' (#1) from blpop into main
Reviewed-on: #1
This commit is contained in:
		| @@ -1,5 +1,7 @@ | ||||
| use crate::{error::DBError, protocol::Protocol, server::Server}; | ||||
| use serde::Serialize; | ||||
| use tokio::time::{timeout, Duration}; | ||||
| use futures::future::select_all; | ||||
|  | ||||
| #[derive(Debug, Clone)] | ||||
| pub enum Cmd { | ||||
| @@ -10,7 +12,12 @@ pub enum Cmd { | ||||
|     Set(String, String), | ||||
|     SetPx(String, String, u128), | ||||
|     SetEx(String, String, u128), | ||||
|     // Advanced SET with options: (key, value, ex_ms, nx, xx, get) | ||||
|     SetOpts(String, String, Option<u128>, bool, bool, bool), | ||||
|     MGet(Vec<String>), | ||||
|     MSet(Vec<(String, String)>), | ||||
|     Keys, | ||||
|     DbSize, | ||||
|     ConfigGet(String), | ||||
|     Info(Option<String>), | ||||
|     Del(String), | ||||
| @@ -30,19 +37,31 @@ pub enum Cmd { | ||||
|     HLen(String), | ||||
|     HMGet(String, Vec<String>), | ||||
|     HSetNx(String, String, String), | ||||
|     HIncrBy(String, String, i64), | ||||
|     HIncrByFloat(String, String, f64), | ||||
|     HScan(String, u64, Option<String>, Option<u64>), // key, cursor, pattern, count | ||||
|     Scan(u64, Option<String>, Option<u64>), // cursor, pattern, count | ||||
|     Ttl(String), | ||||
|     Expire(String, i64), | ||||
|     PExpire(String, i64), | ||||
|     ExpireAt(String, i64), | ||||
|     PExpireAt(String, i64), | ||||
|     Persist(String), | ||||
|     Exists(String), | ||||
|     ExistsMulti(Vec<String>), | ||||
|     DelMulti(Vec<String>), | ||||
|     Quit, | ||||
|     Client(Vec<String>), | ||||
|     ClientSetName(String), | ||||
|     ClientGetName, | ||||
|     Command(Vec<String>), | ||||
|     // List commands | ||||
|     LPush(String, Vec<String>), | ||||
|     RPush(String, Vec<String>), | ||||
|     LPop(String, Option<u64>), | ||||
|     RPop(String, Option<u64>), | ||||
|     BLPop(Vec<String>, f64), | ||||
|     BRPop(Vec<String>, f64), | ||||
|     LLen(String), | ||||
|     LRem(String, i64, String), | ||||
|     LTrim(String, i64, i64), | ||||
| @@ -90,14 +109,51 @@ impl Cmd { | ||||
|                         "ping" => Cmd::Ping, | ||||
|                         "get" => Cmd::Get(cmd[1].clone()), | ||||
|                         "set" => { | ||||
|                             if cmd.len() == 5 && cmd[3].to_lowercase() == "px" { | ||||
|                                 Cmd::SetPx(cmd[1].clone(), cmd[2].clone(), cmd[4].parse().unwrap()) | ||||
|                             } else if cmd.len() == 5 && cmd[3].to_lowercase() == "ex" { | ||||
|                                 Cmd::SetEx(cmd[1].clone(), cmd[2].clone(), cmd[4].parse().unwrap()) | ||||
|                             } else if cmd.len() == 3 { | ||||
|                                 Cmd::Set(cmd[1].clone(), cmd[2].clone()) | ||||
|                             if cmd.len() < 3 { | ||||
|                                 return Err(DBError("wrong number of arguments for SET".to_string())); | ||||
|                             } | ||||
|                             let key = cmd[1].clone(); | ||||
|                             let val = cmd[2].clone(); | ||||
|  | ||||
|                             // Parse optional flags: EX sec | PX ms | NX | XX | GET | ||||
|                             let mut ex_ms: Option<u128> = None; | ||||
|                             let mut nx = false; | ||||
|                             let mut xx = false; | ||||
|                             let mut getflag = false; | ||||
|  | ||||
|                             let mut i = 3; | ||||
|                             while i < cmd.len() { | ||||
|                                 match cmd[i].to_lowercase().as_str() { | ||||
|                                     "ex" => { | ||||
|                                         if i + 1 >= cmd.len() { | ||||
|                                             return Err(DBError("ERR syntax error".to_string())); | ||||
|                                         } | ||||
|                                         let secs: u128 = cmd[i + 1].parse().map_err(|_| DBError("ERR value is not an integer or out of range".to_string()))?; | ||||
|                                         ex_ms = Some(secs * 1000); | ||||
|                                         i += 2; | ||||
|                                     } | ||||
|                                     "px" => { | ||||
|                                         if i + 1 >= cmd.len() { | ||||
|                                             return Err(DBError("ERR syntax error".to_string())); | ||||
|                                         } | ||||
|                                         let ms: u128 = cmd[i + 1].parse().map_err(|_| DBError("ERR value is not an integer or out of range".to_string()))?; | ||||
|                                         ex_ms = Some(ms); | ||||
|                                         i += 2; | ||||
|                                     } | ||||
|                                     "nx" => { nx = true; i += 1; } | ||||
|                                     "xx" => { xx = true; i += 1; } | ||||
|                                     "get" => { getflag = true; i += 1; } | ||||
|                                     _ => { | ||||
|                                         return Err(DBError(format!("unsupported cmd {:?}", cmd))); | ||||
|                                     } | ||||
|                                 } | ||||
|                             } | ||||
|  | ||||
|                             // If no options, keep legacy behavior | ||||
|                             if ex_ms.is_none() && !nx && !xx && !getflag { | ||||
|                                 Cmd::Set(key, val) | ||||
|                             } else { | ||||
|                                 return Err(DBError(format!("unsupported cmd {:?}", cmd))); | ||||
|                                 Cmd::SetOpts(key, val, ex_ms, nx, xx, getflag) | ||||
|                             } | ||||
|                         } | ||||
|                         "setex" => { | ||||
| @@ -106,6 +162,24 @@ impl Cmd { | ||||
|                             } | ||||
|                             Cmd::SetEx(cmd[1].clone(), cmd[3].clone(), cmd[2].parse().unwrap()) | ||||
|                         } | ||||
|                         "mget" => { | ||||
|                             if cmd.len() < 2 { | ||||
|                                 return Err(DBError("wrong number of arguments for MGET command".to_string())); | ||||
|                             } | ||||
|                             Cmd::MGet(cmd[1..].to_vec()) | ||||
|                         } | ||||
|                         "mset" => { | ||||
|                             if cmd.len() < 3 || ((cmd.len() - 1) % 2 != 0) { | ||||
|                                 return Err(DBError("wrong number of arguments for MSET command".to_string())); | ||||
|                             } | ||||
|                             let mut pairs = Vec::new(); | ||||
|                             let mut i = 1; | ||||
|                             while i + 1 < cmd.len() { | ||||
|                                 pairs.push((cmd[i].clone(), cmd[i + 1].clone())); | ||||
|                                 i += 2; | ||||
|                             } | ||||
|                             Cmd::MSet(pairs) | ||||
|                         } | ||||
|                         "config" => { | ||||
|                             if cmd.len() != 3 || cmd[1].to_lowercase() != "get" { | ||||
|                                 return Err(DBError(format!("unsupported cmd {:?}", cmd))); | ||||
| @@ -120,6 +194,12 @@ impl Cmd { | ||||
|                                 Cmd::Keys | ||||
|                             } | ||||
|                         } | ||||
|                         "dbsize" => { | ||||
|                             if cmd.len() != 1 { | ||||
|                                 return Err(DBError(format!("wrong number of arguments for DBSIZE command"))); | ||||
|                             } | ||||
|                             Cmd::DbSize | ||||
|                         } | ||||
|                         "info" => { | ||||
|                             let section = if cmd.len() == 2 { | ||||
|                                 Some(cmd[1].clone()) | ||||
| @@ -129,10 +209,14 @@ impl Cmd { | ||||
|                             Cmd::Info(section) | ||||
|                         } | ||||
|                         "del" => { | ||||
|                             if cmd.len() != 2 { | ||||
|                                 return Err(DBError(format!("unsupported cmd {:?}", cmd))); | ||||
|                             if cmd.len() < 2 { | ||||
|                                 return Err(DBError(format!("wrong number of arguments for DEL command"))); | ||||
|                             } | ||||
|                             if cmd.len() == 2 { | ||||
|                                 Cmd::Del(cmd[1].clone()) | ||||
|                             } else { | ||||
|                                 Cmd::DelMulti(cmd[1..].to_vec()) | ||||
|                             } | ||||
|                             Cmd::Del(cmd[1].clone()) | ||||
|                         } | ||||
|                         "type" => { | ||||
|                             if cmd.len() != 2 { | ||||
| @@ -226,6 +310,20 @@ impl Cmd { | ||||
|                             } | ||||
|                             Cmd::HSetNx(cmd[1].clone(), cmd[2].clone(), cmd[3].clone()) | ||||
|                         } | ||||
|                         "hincrby" => { | ||||
|                             if cmd.len() != 4 { | ||||
|                                 return Err(DBError(format!("wrong number of arguments for HINCRBY command"))); | ||||
|                             } | ||||
|                             let delta = cmd[3].parse::<i64>().map_err(|_| DBError("ERR value is not an integer or out of range".to_string()))?; | ||||
|                             Cmd::HIncrBy(cmd[1].clone(), cmd[2].clone(), delta) | ||||
|                         } | ||||
|                         "hincrbyfloat" => { | ||||
|                             if cmd.len() != 4 { | ||||
|                                 return Err(DBError(format!("wrong number of arguments for HINCRBYFLOAT command"))); | ||||
|                             } | ||||
|                             let delta = cmd[3].parse::<f64>().map_err(|_| DBError("ERR value is not a valid float".to_string()))?; | ||||
|                             Cmd::HIncrByFloat(cmd[1].clone(), cmd[2].clone(), delta) | ||||
|                         } | ||||
|                         "hscan" => { | ||||
|                             if cmd.len() < 3 { | ||||
|                                 return Err(DBError(format!("wrong number of arguments for HSCAN command"))); | ||||
| @@ -307,11 +405,49 @@ impl Cmd { | ||||
|                             } | ||||
|                             Cmd::Ttl(cmd[1].clone()) | ||||
|                         } | ||||
|                         "exists" => { | ||||
|                         "expire" => { | ||||
|                             if cmd.len() != 3 { | ||||
|                                 return Err(DBError("wrong number of arguments for EXPIRE command".to_string())); | ||||
|                             } | ||||
|                             let secs = cmd[2].parse::<i64>().map_err(|_| DBError("ERR value is not an integer or out of range".to_string()))?; | ||||
|                             Cmd::Expire(cmd[1].clone(), secs) | ||||
|                         } | ||||
|                         "pexpire" => { | ||||
|                             if cmd.len() != 3 { | ||||
|                                 return Err(DBError("wrong number of arguments for PEXPIRE command".to_string())); | ||||
|                             } | ||||
|                             let ms = cmd[2].parse::<i64>().map_err(|_| DBError("ERR value is not an integer or out of range".to_string()))?; | ||||
|                             Cmd::PExpire(cmd[1].clone(), ms) | ||||
|                         } | ||||
|                         "expireat" => { | ||||
|                             if cmd.len() != 3 { | ||||
|                                 return Err(DBError("wrong number of arguments for EXPIREAT command".to_string())); | ||||
|                             } | ||||
|                             let ts = cmd[2].parse::<i64>().map_err(|_| DBError("ERR value is not an integer or out of range".to_string()))?; | ||||
|                             Cmd::ExpireAt(cmd[1].clone(), ts) | ||||
|                         } | ||||
|                         "pexpireat" => { | ||||
|                             if cmd.len() != 3 { | ||||
|                                 return Err(DBError("wrong number of arguments for PEXPIREAT command".to_string())); | ||||
|                             } | ||||
|                             let ts_ms = cmd[2].parse::<i64>().map_err(|_| DBError("ERR value is not an integer or out of range".to_string()))?; | ||||
|                             Cmd::PExpireAt(cmd[1].clone(), ts_ms) | ||||
|                         } | ||||
|                         "persist" => { | ||||
|                             if cmd.len() != 2 { | ||||
|                                 return Err(DBError("wrong number of arguments for PERSIST command".to_string())); | ||||
|                             } | ||||
|                             Cmd::Persist(cmd[1].clone()) | ||||
|                         } | ||||
|                         "exists" => { | ||||
|                             if cmd.len() < 2 { | ||||
|                                 return Err(DBError(format!("wrong number of arguments for EXISTS command"))); | ||||
|                             } | ||||
|                             Cmd::Exists(cmd[1].clone()) | ||||
|                             if cmd.len() == 2 { | ||||
|                                 Cmd::Exists(cmd[1].clone()) | ||||
|                             } else { | ||||
|                                 Cmd::ExistsMulti(cmd[1..].to_vec()) | ||||
|                             } | ||||
|                         } | ||||
|                         "quit" => { | ||||
|                             if cmd.len() != 1 { | ||||
| @@ -342,6 +478,10 @@ impl Cmd { | ||||
|                                 Cmd::Client(vec![]) | ||||
|                             } | ||||
|                         } | ||||
|                         "command" => { | ||||
|                             let args = if cmd.len() > 1 { cmd[1..].to_vec() } else { vec![] }; | ||||
|                             Cmd::Command(args) | ||||
|                         } | ||||
|                         "lpush" => { | ||||
|                             if cmd.len() < 3 { | ||||
|                                 return Err(DBError(format!("wrong number of arguments for LPUSH command"))); | ||||
| @@ -376,6 +516,28 @@ impl Cmd { | ||||
|                             }; | ||||
|                             Cmd::RPop(cmd[1].clone(), count) | ||||
|                         } | ||||
|                         "blpop" => { | ||||
|                             if cmd.len() < 3 { | ||||
|                                 return Err(DBError(format!("wrong number of arguments for BLPOP command"))); | ||||
|                             } | ||||
|                             // keys are all but the last argument | ||||
|                             let keys = cmd[1..cmd.len()-1].to_vec(); | ||||
|                             let timeout_f = cmd[cmd.len()-1] | ||||
|                                 .parse::<f64>() | ||||
|                                 .map_err(|_| DBError("ERR timeout is not a number".to_string()))?; | ||||
|                             Cmd::BLPop(keys, timeout_f) | ||||
|                         } | ||||
|                         "brpop" => { | ||||
|                             if cmd.len() < 3 { | ||||
|                                 return Err(DBError(format!("wrong number of arguments for BRPOP command"))); | ||||
|                             } | ||||
|                             // keys are all but the last argument | ||||
|                             let keys = cmd[1..cmd.len()-1].to_vec(); | ||||
|                             let timeout_f = cmd[cmd.len()-1] | ||||
|                                 .parse::<f64>() | ||||
|                                 .map_err(|_| DBError("ERR timeout is not a number".to_string()))?; | ||||
|                             Cmd::BRPop(keys, timeout_f) | ||||
|                         } | ||||
|                         "llen" => { | ||||
|                             if cmd.len() != 2 { | ||||
|                                 return Err(DBError(format!("wrong number of arguments for LLEN command"))); | ||||
| @@ -488,9 +650,14 @@ impl Cmd { | ||||
|             Cmd::Set(k, v) => set_cmd(server, &k, &v).await, | ||||
|             Cmd::SetPx(k, v, x) => set_px_cmd(server, &k, &v, &x).await, | ||||
|             Cmd::SetEx(k, v, x) => set_ex_cmd(server, &k, &v, &x).await, | ||||
|             Cmd::SetOpts(k, v, ex_ms, nx, xx, getflag) => set_with_opts_cmd(server, &k, &v, ex_ms, nx, xx, getflag).await, | ||||
|             Cmd::MGet(keys) => mget_cmd(server, &keys).await, | ||||
|             Cmd::MSet(pairs) => mset_cmd(server, &pairs).await, | ||||
|             Cmd::Del(k) => del_cmd(server, &k).await, | ||||
|             Cmd::DelMulti(keys) => del_multi_cmd(server, &keys).await, | ||||
|             Cmd::ConfigGet(name) => config_get_cmd(&name, server), | ||||
|             Cmd::Keys => keys_cmd(server).await, | ||||
|             Cmd::DbSize => dbsize_cmd(server).await, | ||||
|             Cmd::Info(section) => info_cmd(server, §ion).await, | ||||
|             Cmd::Type(k) => type_cmd(server, &k).await, | ||||
|             Cmd::Incr(key) => incr_cmd(server, &key).await, | ||||
| @@ -518,19 +685,30 @@ impl Cmd { | ||||
|             Cmd::HLen(key) => hlen_cmd(server, &key).await, | ||||
|             Cmd::HMGet(key, fields) => hmget_cmd(server, &key, &fields).await, | ||||
|             Cmd::HSetNx(key, field, value) => hsetnx_cmd(server, &key, &field, &value).await, | ||||
|             Cmd::HIncrBy(key, field, delta) => hincrby_cmd(server, &key, &field, delta).await, | ||||
|             Cmd::HIncrByFloat(key, field, delta) => hincrbyfloat_cmd(server, &key, &field, delta).await, | ||||
|             Cmd::HScan(key, cursor, pattern, count) => hscan_cmd(server, &key, &cursor, pattern.as_deref(), &count).await, | ||||
|             Cmd::Scan(cursor, pattern, count) => scan_cmd(server, &cursor, pattern.as_deref(), &count).await, | ||||
|             Cmd::Ttl(key) => ttl_cmd(server, &key).await, | ||||
|             Cmd::Expire(key, secs) => expire_cmd(server, &key, secs).await, | ||||
|             Cmd::PExpire(key, ms) => pexpire_cmd(server, &key, ms).await, | ||||
|             Cmd::ExpireAt(key, ts_secs) => expireat_cmd(server, &key, ts_secs).await, | ||||
|             Cmd::PExpireAt(key, ts_ms) => pexpireat_cmd(server, &key, ts_ms).await, | ||||
|             Cmd::Persist(key) => persist_cmd(server, &key).await, | ||||
|             Cmd::Exists(key) => exists_cmd(server, &key).await, | ||||
|             Cmd::ExistsMulti(keys) => exists_multi_cmd(server, &keys).await, | ||||
|             Cmd::Quit => Ok(Protocol::SimpleString("OK".to_string())), | ||||
|             Cmd::Client(_) => Ok(Protocol::SimpleString("OK".to_string())), | ||||
|             Cmd::ClientSetName(name) => client_setname_cmd(server, &name).await, | ||||
|             Cmd::ClientGetName => client_getname_cmd(server).await, | ||||
|             Cmd::Command(args) => command_cmd(&args), | ||||
|             // List commands | ||||
|             Cmd::LPush(key, elements) => lpush_cmd(server, &key, &elements).await, | ||||
|             Cmd::RPush(key, elements) => rpush_cmd(server, &key, &elements).await, | ||||
|             Cmd::LPop(key, count) => lpop_cmd(server, &key, &count).await, | ||||
|             Cmd::RPop(key, count) => rpop_cmd(server, &key, &count).await, | ||||
|             Cmd::BLPop(keys, timeout) => blpop_cmd(server, &keys, timeout).await, | ||||
|             Cmd::BRPop(keys, timeout) => brpop_cmd(server, &keys, timeout).await, | ||||
|             Cmd::LLen(key) => llen_cmd(server, &key).await, | ||||
|             Cmd::LRem(key, count, element) => lrem_cmd(server, &key, count, &element).await, | ||||
|             Cmd::LTrim(key, start, stop) => ltrim_cmd(server, &key, start, stop).await, | ||||
| @@ -661,16 +839,188 @@ async fn rpop_cmd(server: &Server, key: &str, count: &Option<u64>) -> Result<Pro | ||||
|     } | ||||
| } | ||||
|  | ||||
| // BLPOP implementation | ||||
| async fn blpop_cmd(server: &Server, keys: &[String], timeout_secs: f64) -> Result<Protocol, DBError> { | ||||
|     // Immediate, non-blocking attempt in key order | ||||
|     for k in keys { | ||||
|         let elems = server.current_storage()?.lpop(k, 1)?; | ||||
|         if !elems.is_empty() { | ||||
|             return Ok(Protocol::Array(vec![ | ||||
|                 Protocol::BulkString(k.clone()), | ||||
|                 Protocol::BulkString(elems[0].clone()), | ||||
|             ])); | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     // If timeout is zero, return immediately with Null | ||||
|     if timeout_secs <= 0.0 { | ||||
|         return Ok(Protocol::Null); | ||||
|     } | ||||
|  | ||||
|     // Register waiters for each key | ||||
|     let db_index = server.selected_db; | ||||
|     let mut ids: Vec<u64> = Vec::with_capacity(keys.len()); | ||||
|     let mut names: Vec<String> = Vec::with_capacity(keys.len()); | ||||
|     let mut rxs: Vec<tokio::sync::oneshot::Receiver<(String, String)>> = Vec::with_capacity(keys.len()); | ||||
|  | ||||
|     for k in keys { | ||||
|         let (id, rx) = server.register_waiter(db_index, k, crate::server::PopSide::Left).await; | ||||
|         ids.push(id); | ||||
|         names.push(k.clone()); | ||||
|         rxs.push(rx); | ||||
|     } | ||||
|  | ||||
|     // Wait for the first delivery or timeout | ||||
|     let wait_fut = async move { | ||||
|         let mut futures_vec = rxs; | ||||
|         loop { | ||||
|             if futures_vec.is_empty() { | ||||
|                 return None; | ||||
|             } | ||||
|             let (res, idx, remaining) = select_all(futures_vec).await; | ||||
|             match res { | ||||
|                 Ok((k, elem)) => { | ||||
|                     return Some((k, elem, idx, remaining)); | ||||
|                 } | ||||
|                 Err(_canceled) => { | ||||
|                     // That waiter was canceled; continue with the rest | ||||
|                     futures_vec = remaining; | ||||
|                     continue; | ||||
|                 } | ||||
|             } | ||||
|         } | ||||
|     }; | ||||
|  | ||||
|     match timeout(Duration::from_secs_f64(timeout_secs), wait_fut).await { | ||||
|         Ok(Some((k, elem, idx, _remaining))) => { | ||||
|             // Unregister other waiters | ||||
|             for (i, key_name) in names.iter().enumerate() { | ||||
|                 if i != idx { | ||||
|                     server.unregister_waiter(db_index, key_name, ids[i]).await; | ||||
|                 } | ||||
|             } | ||||
|             Ok(Protocol::Array(vec![ | ||||
|                 Protocol::BulkString(k), | ||||
|                 Protocol::BulkString(elem), | ||||
|             ])) | ||||
|         } | ||||
|         Ok(None) => { | ||||
|             // No futures left; unregister all waiters | ||||
|             for (i, key_name) in names.iter().enumerate() { | ||||
|                 server.unregister_waiter(db_index, key_name, ids[i]).await; | ||||
|             } | ||||
|             Ok(Protocol::Null) | ||||
|         } | ||||
|         Err(_elapsed) => { | ||||
|             // Timeout: unregister all waiters | ||||
|             for (i, key_name) in names.iter().enumerate() { | ||||
|                 server.unregister_waiter(db_index, key_name, ids[i]).await; | ||||
|             } | ||||
|             Ok(Protocol::Null) | ||||
|         } | ||||
|     } | ||||
| } | ||||
|  | ||||
| // BRPOP implementation (mirror of BLPOP, popping from the right) | ||||
| async fn brpop_cmd(server: &Server, keys: &[String], timeout_secs: f64) -> Result<Protocol, DBError> { | ||||
|     // Immediate, non-blocking attempt in key order using RPOP | ||||
|     for k in keys { | ||||
|         let elems = server.current_storage()?.rpop(k, 1)?; | ||||
|         if !elems.is_empty() { | ||||
|             return Ok(Protocol::Array(vec![ | ||||
|                 Protocol::BulkString(k.clone()), | ||||
|                 Protocol::BulkString(elems[0].clone()), | ||||
|             ])); | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     // If timeout is zero, return immediately with Null | ||||
|     if timeout_secs <= 0.0 { | ||||
|         return Ok(Protocol::Null); | ||||
|     } | ||||
|  | ||||
|     // Register waiters for each key (Right side) | ||||
|     let db_index = server.selected_db; | ||||
|     let mut ids: Vec<u64> = Vec::with_capacity(keys.len()); | ||||
|     let mut names: Vec<String> = Vec::with_capacity(keys.len()); | ||||
|     let mut rxs: Vec<tokio::sync::oneshot::Receiver<(String, String)>> = Vec::with_capacity(keys.len()); | ||||
|  | ||||
|     for k in keys { | ||||
|         let (id, rx) = server.register_waiter(db_index, k, crate::server::PopSide::Right).await; | ||||
|         ids.push(id); | ||||
|         names.push(k.clone()); | ||||
|         rxs.push(rx); | ||||
|     } | ||||
|  | ||||
|     // Wait for the first delivery or timeout | ||||
|     let wait_fut = async move { | ||||
|         let mut futures_vec = rxs; | ||||
|         loop { | ||||
|             if futures_vec.is_empty() { | ||||
|                 return None; | ||||
|             } | ||||
|             let (res, idx, remaining) = select_all(futures_vec).await; | ||||
|             match res { | ||||
|                 Ok((k, elem)) => { | ||||
|                     return Some((k, elem, idx, remaining)); | ||||
|                 } | ||||
|                 Err(_canceled) => { | ||||
|                     // That waiter was canceled; continue with the rest | ||||
|                     futures_vec = remaining; | ||||
|                     continue; | ||||
|                 } | ||||
|             } | ||||
|         } | ||||
|     }; | ||||
|  | ||||
|     match timeout(Duration::from_secs_f64(timeout_secs), wait_fut).await { | ||||
|         Ok(Some((k, elem, idx, _remaining))) => { | ||||
|             // Unregister other waiters | ||||
|             for (i, key_name) in names.iter().enumerate() { | ||||
|                 if i != idx { | ||||
|                     server.unregister_waiter(db_index, key_name, ids[i]).await; | ||||
|                 } | ||||
|             } | ||||
|             Ok(Protocol::Array(vec![ | ||||
|                 Protocol::BulkString(k), | ||||
|                 Protocol::BulkString(elem), | ||||
|             ])) | ||||
|         } | ||||
|         Ok(None) => { | ||||
|             // No futures left; unregister all waiters | ||||
|             for (i, key_name) in names.iter().enumerate() { | ||||
|                 server.unregister_waiter(db_index, key_name, ids[i]).await; | ||||
|             } | ||||
|             Ok(Protocol::Null) | ||||
|         } | ||||
|         Err(_elapsed) => { | ||||
|             // Timeout: unregister all waiters | ||||
|             for (i, key_name) in names.iter().enumerate() { | ||||
|                 server.unregister_waiter(db_index, key_name, ids[i]).await; | ||||
|             } | ||||
|             Ok(Protocol::Null) | ||||
|         } | ||||
|     } | ||||
| } | ||||
|  | ||||
| async fn lpush_cmd(server: &Server, key: &str, elements: &[String]) -> Result<Protocol, DBError> { | ||||
|     match server.current_storage()?.lpush(key, elements.to_vec()) { | ||||
|         Ok(len) => Ok(Protocol::SimpleString(len.to_string())), | ||||
|         Ok(len) => { | ||||
|             // Attempt to deliver to any blocked BLPOP waiters | ||||
|             let _ = server.drain_waiters_after_push(key).await; | ||||
|             Ok(Protocol::SimpleString(len.to_string())) | ||||
|         } | ||||
|         Err(e) => Ok(Protocol::err(&e.0)), | ||||
|     } | ||||
| } | ||||
|  | ||||
| async fn rpush_cmd(server: &Server, key: &str, elements: &[String]) -> Result<Protocol, DBError> { | ||||
|     match server.current_storage()?.rpush(key, elements.to_vec()) { | ||||
|         Ok(len) => Ok(Protocol::SimpleString(len.to_string())), | ||||
|         Ok(len) => { | ||||
|             // Attempt to deliver to any blocked BLPOP waiters | ||||
|             let _ = server.drain_waiters_after_push(key).await; | ||||
|             Ok(Protocol::SimpleString(len.to_string())) | ||||
|         } | ||||
|         Err(e) => Ok(Protocol::err(&e.0)), | ||||
|     } | ||||
| } | ||||
| @@ -736,6 +1086,13 @@ async fn keys_cmd(server: &Server) -> Result<Protocol, DBError> { | ||||
|     )) | ||||
| } | ||||
|  | ||||
| async fn dbsize_cmd(server: &Server) -> Result<Protocol, DBError> { | ||||
|     match server.current_storage()?.dbsize() { | ||||
|         Ok(n) => Ok(Protocol::SimpleString(n.to_string())), | ||||
|         Err(e) => Ok(Protocol::err(&e.0)), | ||||
|     } | ||||
| } | ||||
|  | ||||
| #[derive(Serialize)] | ||||
| struct ServerInfo { | ||||
|     redis_version: String, | ||||
| @@ -757,17 +1114,19 @@ async fn info_cmd(server: &Server, section: &Option<String>) -> Result<Protocol, | ||||
|     info_string.push_str(&format!("# Keyspace\n")); | ||||
|     info_string.push_str(&format!("db{}:keys=0,expires=0,avg_ttl=0\n", info.selected_db)); | ||||
|  | ||||
|  | ||||
|     match section { | ||||
|         Some(s) => match s.as_str() { | ||||
|             "replication" => Ok(Protocol::BulkString( | ||||
|                 "role:master\nmaster_replid:8371b4fb1155b71f4a04d3e1bc3e18c4a990aeea\nmaster_repl_offset:0\n".to_string() | ||||
|             )), | ||||
|             _ => Err(DBError(format!("unsupported section {:?}", s))), | ||||
|         }, | ||||
|         None => { | ||||
|             Ok(Protocol::BulkString(info_string)) | ||||
|         Some(s) => { | ||||
|             let sl = s.to_lowercase(); | ||||
|             if sl == "replication" { | ||||
|                 Ok(Protocol::BulkString( | ||||
|                     "role:master\nmaster_replid:8371b4fb1155b71f4a04d3e1bc3e18c4a990aeea\nmaster_repl_offset:0\n".to_string() | ||||
|                 )) | ||||
|             } else { | ||||
|                 // Return general info for unknown sections (e.g., SERVER) | ||||
|                 Ok(Protocol::BulkString(info_string)) | ||||
|             } | ||||
|         } | ||||
|         None => Ok(Protocol::BulkString(info_string)), | ||||
|     } | ||||
| } | ||||
|  | ||||
| @@ -808,6 +1167,109 @@ async fn set_cmd(server: &Server, k: &str, v: &str) -> Result<Protocol, DBError> | ||||
|     Ok(Protocol::SimpleString("OK".to_string())) | ||||
| } | ||||
|  | ||||
| // Advanced SET with options: EX/PX/NX/XX/GET | ||||
| async fn set_with_opts_cmd( | ||||
|     server: &Server, | ||||
|     key: &str, | ||||
|     value: &str, | ||||
|     ex_ms: Option<u128>, | ||||
|     nx: bool, | ||||
|     xx: bool, | ||||
|     get_old: bool, | ||||
| ) -> Result<Protocol, DBError> { | ||||
|     let storage = server.current_storage()?; | ||||
|  | ||||
|     // Determine existence (for NX/XX) | ||||
|     let exists = storage.exists(key)?; | ||||
|  | ||||
|     // If both NX and XX, condition can never be satisfied -> no-op | ||||
|     let mut should_set = true; | ||||
|     if nx && exists { | ||||
|         should_set = false; | ||||
|     } | ||||
|     if xx && !exists { | ||||
|         should_set = false; | ||||
|     } | ||||
|  | ||||
|     // Fetch old value if needed for GET | ||||
|     let old_val = if get_old { | ||||
|         storage.get(key)? | ||||
|     } else { | ||||
|         None | ||||
|     }; | ||||
|  | ||||
|     if should_set { | ||||
|         if let Some(ms) = ex_ms { | ||||
|             storage.setx(key.to_string(), value.to_string(), ms)?; | ||||
|         } else { | ||||
|             storage.set(key.to_string(), value.to_string())?; | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     if get_old { | ||||
|         // Return previous value (or Null), regardless of NX/XX outcome only if set executed? | ||||
|         // We follow Redis semantics: return old value if set executed, else Null | ||||
|         if should_set { | ||||
|             Ok(old_val.map_or(Protocol::Null, Protocol::BulkString)) | ||||
|         } else { | ||||
|             Ok(Protocol::Null) | ||||
|         } | ||||
|     } else { | ||||
|         if should_set { | ||||
|             Ok(Protocol::SimpleString("OK".to_string())) | ||||
|         } else { | ||||
|             Ok(Protocol::Null) | ||||
|         } | ||||
|     } | ||||
| } | ||||
|  | ||||
| // MGET: return array of bulk strings or Null for missing | ||||
| async fn mget_cmd(server: &Server, keys: &[String]) -> Result<Protocol, DBError> { | ||||
|     let mut out: Vec<Protocol> = Vec::with_capacity(keys.len()); | ||||
|     let storage = server.current_storage()?; | ||||
|     for k in keys { | ||||
|         match storage.get(k)? { | ||||
|             Some(v) => out.push(Protocol::BulkString(v)), | ||||
|             None => out.push(Protocol::Null), | ||||
|         } | ||||
|     } | ||||
|     Ok(Protocol::Array(out)) | ||||
| } | ||||
|  | ||||
| // MSET: set multiple key/value pairs, return OK | ||||
| async fn mset_cmd(server: &Server, pairs: &[(String, String)]) -> Result<Protocol, DBError> { | ||||
|     let storage = server.current_storage()?; | ||||
|     for (k, v) in pairs { | ||||
|         storage.set(k.clone(), v.clone())?; | ||||
|     } | ||||
|     Ok(Protocol::SimpleString("OK".to_string())) | ||||
| } | ||||
|  | ||||
| // DEL with multiple keys: return count of keys actually deleted | ||||
| async fn del_multi_cmd(server: &Server, keys: &[String]) -> Result<Protocol, DBError> { | ||||
|     let storage = server.current_storage()?; | ||||
|     let mut deleted = 0i64; | ||||
|     for k in keys { | ||||
|         if storage.exists(k)? { | ||||
|             storage.del(k.clone())?; | ||||
|             deleted += 1; | ||||
|         } | ||||
|     } | ||||
|     Ok(Protocol::SimpleString(deleted.to_string())) | ||||
| } | ||||
|  | ||||
| // EXISTS with multiple keys: return count existing | ||||
| async fn exists_multi_cmd(server: &Server, keys: &[String]) -> Result<Protocol, DBError> { | ||||
|     let storage = server.current_storage()?; | ||||
|     let mut count = 0i64; | ||||
|     for k in keys { | ||||
|         if storage.exists(k)? { | ||||
|             count += 1; | ||||
|         } | ||||
|     } | ||||
|     Ok(Protocol::SimpleString(count.to_string())) | ||||
| } | ||||
|  | ||||
| async fn get_cmd(server: &Server, k: &str) -> Result<Protocol, DBError> { | ||||
|     let v = server.current_storage()?.get(k)?; | ||||
|     Ok(v.map_or(Protocol::Null, Protocol::BulkString)) | ||||
| @@ -900,6 +1362,32 @@ async fn hsetnx_cmd(server: &Server, key: &str, field: &str, value: &str) -> Res | ||||
|     } | ||||
| } | ||||
|  | ||||
| async fn hincrby_cmd(server: &Server, key: &str, field: &str, delta: i64) -> Result<Protocol, DBError> { | ||||
|     let storage = server.current_storage()?; | ||||
|     let current = storage.hget(key, field)?; | ||||
|     let base: i64 = match current { | ||||
|         Some(v) => v.parse::<i64>().map_err(|_| DBError("ERR value is not an integer or out of range".to_string()))?, | ||||
|         None => 0, | ||||
|     }; | ||||
|     let new_val = base.checked_add(delta).ok_or_else(|| DBError("ERR increment or decrement would overflow".to_string()))?; | ||||
|     // Update the field | ||||
|     storage.hset(key, vec![(field.to_string(), new_val.to_string())])?; | ||||
|     Ok(Protocol::SimpleString(new_val.to_string())) | ||||
| } | ||||
|  | ||||
| async fn hincrbyfloat_cmd(server: &Server, key: &str, field: &str, delta: f64) -> Result<Protocol, DBError> { | ||||
|     let storage = server.current_storage()?; | ||||
|     let current = storage.hget(key, field)?; | ||||
|     let base: f64 = match current { | ||||
|         Some(v) => v.parse::<f64>().map_err(|_| DBError("ERR value is not a valid float".to_string()))?, | ||||
|         None => 0.0, | ||||
|     }; | ||||
|     let new_val = base + delta; | ||||
|     // Update the field | ||||
|     storage.hset(key, vec![(field.to_string(), new_val.to_string())])?; | ||||
|     Ok(Protocol::SimpleString(new_val.to_string())) | ||||
| } | ||||
|  | ||||
| async fn scan_cmd( | ||||
|     server: &Server, | ||||
|     cursor: &u64, | ||||
| @@ -957,6 +1445,51 @@ async fn exists_cmd(server: &Server, key: &str) -> Result<Protocol, DBError> { | ||||
|     } | ||||
| } | ||||
|  | ||||
| // EXPIRE key seconds -> 1 if timeout set, 0 otherwise | ||||
| async fn expire_cmd(server: &Server, key: &str, secs: i64) -> Result<Protocol, DBError> { | ||||
|     if secs < 0 { | ||||
|         return Ok(Protocol::SimpleString("0".to_string())); | ||||
|     } | ||||
|     match server.current_storage()?.expire_seconds(key, secs as u64) { | ||||
|         Ok(applied) => Ok(Protocol::SimpleString(if applied { "1" } else { "0" }.to_string())), | ||||
|         Err(e) => Ok(Protocol::err(&e.0)), | ||||
|     } | ||||
| } | ||||
|  | ||||
| // PEXPIRE key milliseconds -> 1 if timeout set, 0 otherwise | ||||
| async fn pexpire_cmd(server: &Server, key: &str, ms: i64) -> Result<Protocol, DBError> { | ||||
|     if ms < 0 { | ||||
|         return Ok(Protocol::SimpleString("0".to_string())); | ||||
|     } | ||||
|     match server.current_storage()?.pexpire_millis(key, ms as u128) { | ||||
|         Ok(applied) => Ok(Protocol::SimpleString(if applied { "1" } else { "0" }.to_string())), | ||||
|         Err(e) => Ok(Protocol::err(&e.0)), | ||||
|     } | ||||
| } | ||||
|  | ||||
| // PERSIST key -> 1 if timeout removed, 0 otherwise | ||||
| async fn persist_cmd(server: &Server, key: &str) -> Result<Protocol, DBError> { | ||||
|     match server.current_storage()?.persist(key) { | ||||
|         Ok(removed) => Ok(Protocol::SimpleString(if removed { "1" } else { "0" }.to_string())), | ||||
|         Err(e) => Ok(Protocol::err(&e.0)), | ||||
|     } | ||||
| } | ||||
| // EXPIREAT key timestamp-seconds -> 1 if timeout set, 0 otherwise | ||||
| async fn expireat_cmd(server: &Server, key: &str, ts_secs: i64) -> Result<Protocol, DBError> { | ||||
|     match server.current_storage()?.expire_at_seconds(key, ts_secs) { | ||||
|         Ok(applied) => Ok(Protocol::SimpleString(if applied { "1" } else { "0" }.to_string())), | ||||
|         Err(e) => Ok(Protocol::err(&e.0)), | ||||
|     } | ||||
| } | ||||
|  | ||||
| // PEXPIREAT key timestamp-milliseconds -> 1 if timeout set, 0 otherwise | ||||
| async fn pexpireat_cmd(server: &Server, key: &str, ts_ms: i64) -> Result<Protocol, DBError> { | ||||
|     match server.current_storage()?.pexpire_at_millis(key, ts_ms) { | ||||
|         Ok(applied) => Ok(Protocol::SimpleString(if applied { "1" } else { "0" }.to_string())), | ||||
|         Err(e) => Ok(Protocol::err(&e.0)), | ||||
|     } | ||||
| } | ||||
|  | ||||
| async fn client_setname_cmd(server: &mut Server, name: &str) -> Result<Protocol, DBError> { | ||||
|     server.client_name = Some(name.to_string()); | ||||
|     Ok(Protocol::SimpleString("OK".to_string())) | ||||
| @@ -968,3 +1501,19 @@ async fn client_getname_cmd(server: &Server) -> Result<Protocol, DBError> { | ||||
|         None => Ok(Protocol::Null), | ||||
|     } | ||||
| } | ||||
|  | ||||
| // Minimal COMMAND subcommands stub to satisfy redis-cli probes. | ||||
| // - COMMAND DOCS ... => return empty array | ||||
| // - COMMAND INFO ... => return empty array | ||||
| // - Any other => empty array | ||||
| fn command_cmd(args: &[String]) -> Result<Protocol, DBError> { | ||||
|     if args.is_empty() { | ||||
|         return Ok(Protocol::Array(vec![])); | ||||
|     } | ||||
|     let sub = args[0].to_lowercase(); | ||||
|     match sub.as_str() { | ||||
|         "docs" => Ok(Protocol::Array(vec![])), | ||||
|         "info" => Ok(Protocol::Array(vec![])), | ||||
|         _ => Ok(Protocol::Array(vec![])), | ||||
|     } | ||||
| } | ||||
|   | ||||
| @@ -19,6 +19,10 @@ impl fmt::Display for Protocol { | ||||
|  | ||||
| impl Protocol { | ||||
|     pub fn from(protocol: &str) -> Result<(Self, &str), DBError> { | ||||
|         if protocol.is_empty() { | ||||
|             // Incomplete frame; caller should read more bytes | ||||
|             return Err(DBError("[incomplete] empty".to_string())); | ||||
|         } | ||||
|         let ret = match protocol.chars().nth(0) { | ||||
|             Some('+') => Self::parse_simple_string_sfx(&protocol[1..]), | ||||
|             Some('$') => Self::parse_bulk_string_sfx(&protocol[1..]), | ||||
| @@ -101,21 +105,20 @@ impl Protocol { | ||||
|             let size = Self::parse_usize(&protocol[..len_end])?; | ||||
|             let data_start = len_end + 2; | ||||
|             let data_end = data_start + size; | ||||
|             let s = Self::parse_string(&protocol[data_start..data_end])?; | ||||
|  | ||||
|             if protocol.len() < data_end + 2 || &protocol[data_end..data_end+2] != "\r\n" { | ||||
|                  Err(DBError(format!( | ||||
|                     "[new bulk string] unmatched string length in prototocl {:?}", | ||||
|                     protocol, | ||||
|                 ))) | ||||
|             } else { | ||||
|                 Ok((Protocol::BulkString(s), &protocol[data_end + 2..])) | ||||
|             // If we don't yet have the full bulk payload + trailing CRLF, signal INCOMPLETE | ||||
|             if protocol.len() < data_end + 2 { | ||||
|                 return Err(DBError("[incomplete] bulk body".to_string())); | ||||
|             } | ||||
|             if &protocol[data_end..data_end + 2] != "\r\n" { | ||||
|                 return Err(DBError("[incomplete] bulk terminator".to_string())); | ||||
|             } | ||||
|  | ||||
|             let s = Self::parse_string(&protocol[data_start..data_end])?; | ||||
|             Ok((Protocol::BulkString(s), &protocol[data_end + 2..])) | ||||
|         } else { | ||||
|             Err(DBError(format!( | ||||
|                 "[new bulk string] unsupported protocol: {:?}", | ||||
|                 protocol | ||||
|             ))) | ||||
|             // No CRLF after bulk length header yet | ||||
|             Err(DBError("[incomplete] bulk header".to_string())) | ||||
|         } | ||||
|     } | ||||
|  | ||||
| @@ -125,16 +128,25 @@ impl Protocol { | ||||
|             let mut remaining = &s[len_end + 2..]; | ||||
|             let mut vec = vec![]; | ||||
|             for _ in 0..array_len { | ||||
|                 let (p, rem) = Protocol::from(remaining)?; | ||||
|                 vec.push(p); | ||||
|                 remaining = rem; | ||||
|                 match Protocol::from(remaining) { | ||||
|                     Ok((p, rem)) => { | ||||
|                         vec.push(p); | ||||
|                         remaining = rem; | ||||
|                     } | ||||
|                     Err(e) => { | ||||
|                         // Propagate incomplete so caller can read more bytes | ||||
|                         if e.0.starts_with("[incomplete]") { | ||||
|                             return Err(e); | ||||
|                         } else { | ||||
|                             return Err(e); | ||||
|                         } | ||||
|                     } | ||||
|                 } | ||||
|             } | ||||
|             Ok((Protocol::Array(vec), remaining)) | ||||
|         } else { | ||||
|             Err(DBError(format!( | ||||
|                 "[new array] unsupported protocol: {:?}", | ||||
|                 s | ||||
|             ))) | ||||
|             // No CRLF after array header yet | ||||
|             Err(DBError("[incomplete] array header".to_string())) | ||||
|         } | ||||
|     } | ||||
|  | ||||
|   | ||||
| @@ -3,6 +3,9 @@ use std::collections::HashMap; | ||||
| use std::sync::Arc; | ||||
| use tokio::io::AsyncReadExt; | ||||
| use tokio::io::AsyncWriteExt; | ||||
| use tokio::sync::{Mutex, oneshot}; | ||||
|  | ||||
| use std::sync::atomic::{AtomicU64, Ordering}; | ||||
|  | ||||
| use crate::cmd::Cmd; | ||||
| use crate::error::DBError; | ||||
| @@ -17,6 +20,22 @@ pub struct Server { | ||||
|     pub client_name: Option<String>, | ||||
|     pub selected_db: u64, // Changed from usize to u64 | ||||
|     pub queued_cmd: Option<Vec<(Cmd, Protocol)>>, | ||||
|  | ||||
|     // BLPOP waiter registry: per (db_index, key) FIFO of waiters | ||||
|     pub list_waiters: Arc<Mutex<HashMap<u64, HashMap<String, Vec<Waiter>>>>>, | ||||
|     pub waiter_seq: Arc<AtomicU64>, | ||||
| } | ||||
|  | ||||
| pub struct Waiter { | ||||
|     pub id: u64, | ||||
|     pub side: PopSide, | ||||
|     pub tx: oneshot::Sender<(String, String)>, // (key, element) | ||||
| } | ||||
|  | ||||
| #[derive(Clone, Copy, Debug, PartialEq, Eq)] | ||||
| pub enum PopSide { | ||||
|     Left, | ||||
|     Right, | ||||
| } | ||||
|  | ||||
| impl Server { | ||||
| @@ -27,6 +46,9 @@ impl Server { | ||||
|             client_name: None, | ||||
|             selected_db: 0, | ||||
|             queued_cmd: None, | ||||
|  | ||||
|             list_waiters: Arc::new(Mutex::new(HashMap::new())), | ||||
|             waiter_seq: Arc::new(AtomicU64::new(1)), | ||||
|         } | ||||
|     } | ||||
|  | ||||
| @@ -66,35 +88,122 @@ impl Server { | ||||
|         self.option.encrypt && db_index >= 10 | ||||
|     } | ||||
|  | ||||
|     // ----- BLPOP waiter helpers ----- | ||||
|  | ||||
|     pub async fn register_waiter(&self, db_index: u64, key: &str, side: PopSide) -> (u64, oneshot::Receiver<(String, String)>) { | ||||
|         let id = self.waiter_seq.fetch_add(1, Ordering::Relaxed); | ||||
|         let (tx, rx) = oneshot::channel::<(String, String)>(); | ||||
|  | ||||
|         let mut guard = self.list_waiters.lock().await; | ||||
|         let per_db = guard.entry(db_index).or_insert_with(HashMap::new); | ||||
|         let q = per_db.entry(key.to_string()).or_insert_with(Vec::new); | ||||
|         q.push(Waiter { id, side, tx }); | ||||
|         (id, rx) | ||||
|     } | ||||
|  | ||||
|     pub async fn unregister_waiter(&self, db_index: u64, key: &str, id: u64) { | ||||
|         let mut guard = self.list_waiters.lock().await; | ||||
|         if let Some(per_db) = guard.get_mut(&db_index) { | ||||
|             if let Some(q) = per_db.get_mut(key) { | ||||
|                 q.retain(|w| w.id != id); | ||||
|                 if q.is_empty() { | ||||
|                     per_db.remove(key); | ||||
|                 } | ||||
|             } | ||||
|             if per_db.is_empty() { | ||||
|                 guard.remove(&db_index); | ||||
|             } | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     // Called after LPUSH/RPUSH to deliver to blocked BLPOP waiters. | ||||
|     pub async fn drain_waiters_after_push(&self, key: &str) -> Result<(), DBError> { | ||||
|         let db_index = self.selected_db; | ||||
|  | ||||
|         loop { | ||||
|             // Check if any waiter exists | ||||
|             let maybe_waiter = { | ||||
|                 let mut guard = self.list_waiters.lock().await; | ||||
|                 if let Some(per_db) = guard.get_mut(&db_index) { | ||||
|                     if let Some(q) = per_db.get_mut(key) { | ||||
|                         if !q.is_empty() { | ||||
|                             // Pop FIFO | ||||
|                             Some(q.remove(0)) | ||||
|                         } else { | ||||
|                             None | ||||
|                         } | ||||
|                     } else { | ||||
|                         None | ||||
|                     } | ||||
|                 } else { | ||||
|                     None | ||||
|                 } | ||||
|             }; | ||||
|  | ||||
|             let waiter = if let Some(w) = maybe_waiter { w } else { break }; | ||||
|  | ||||
|             // Pop one element depending on waiter side | ||||
|             let elems = match waiter.side { | ||||
|                 PopSide::Left => self.current_storage()?.lpop(key, 1)?, | ||||
|                 PopSide::Right => self.current_storage()?.rpop(key, 1)?, | ||||
|             }; | ||||
|             if elems.is_empty() { | ||||
|                 // Nothing to deliver; re-register waiter at the front to preserve order | ||||
|                 let mut guard = self.list_waiters.lock().await; | ||||
|                 let per_db = guard.entry(db_index).or_insert_with(HashMap::new); | ||||
|                 let q = per_db.entry(key.to_string()).or_insert_with(Vec::new); | ||||
|                 q.insert(0, waiter); | ||||
|                 break; | ||||
|             } else { | ||||
|                 let elem = elems[0].clone(); | ||||
|                 // Send to waiter; if receiver dropped, just continue | ||||
|                 let _ = waiter.tx.send((key.to_string(), elem)); | ||||
|                 // Loop to try to satisfy more waiters if more elements remain | ||||
|                 continue; | ||||
|             } | ||||
|         } | ||||
|  | ||||
|         Ok(()) | ||||
|     } | ||||
|  | ||||
|     pub async fn handle( | ||||
|         &mut self, | ||||
|         mut stream: tokio::net::TcpStream, | ||||
|     ) -> Result<(), DBError> { | ||||
|         let mut buf = [0; 512]; | ||||
|          | ||||
|         // Accumulate incoming bytes to handle partial RESP frames | ||||
|         let mut acc = String::new(); | ||||
|         let mut buf = vec![0u8; 8192]; | ||||
|  | ||||
|         loop { | ||||
|             let len = match stream.read(&mut buf).await { | ||||
|             let n = match stream.read(&mut buf).await { | ||||
|                 Ok(0) => { | ||||
|                     println!("[handle] connection closed"); | ||||
|                     return Ok(()); | ||||
|                 } | ||||
|                 Ok(len) => len, | ||||
|                 Ok(n) => n, | ||||
|                 Err(e) => { | ||||
|                     println!("[handle] read error: {:?}", e); | ||||
|                     return Err(e.into()); | ||||
|                 } | ||||
|             }; | ||||
|  | ||||
|             let mut s = str::from_utf8(&buf[..len])?; | ||||
|             while !s.is_empty() { | ||||
|                 let (cmd, protocol, remaining) = match Cmd::from(s) { | ||||
|             // Append to accumulator. RESP for our usage is ASCII-safe. | ||||
|             acc.push_str(str::from_utf8(&buf[..n])?); | ||||
|  | ||||
|             // Try to parse as many complete commands as are available in 'acc'. | ||||
|             loop { | ||||
|                 let parsed = Cmd::from(&acc); | ||||
|                 let (cmd, protocol, remaining) = match parsed { | ||||
|                     Ok((cmd, protocol, remaining)) => (cmd, protocol, remaining), | ||||
|                     Err(e) => { | ||||
|                         println!("\x1b[31;1mprotocol error: {:?}\x1b[0m", e); | ||||
|                         (Cmd::Unknow("protocol_error".to_string()), Protocol::err(&format!("protocol error: {}", e.0)), "") | ||||
|                     Err(_e) => { | ||||
|                         // Incomplete or invalid frame; assume incomplete and wait for more data. | ||||
|                         // This avoids emitting spurious protocol_error for split frames. | ||||
|                         break; | ||||
|                     } | ||||
|                 }; | ||||
|                 s = remaining; | ||||
|  | ||||
|                 // Advance the accumulator to the unparsed remainder | ||||
|                 acc = remaining.to_string(); | ||||
|  | ||||
|                 if self.option.debug { | ||||
|                     println!("\x1b[34;1mgot command: {:?}, protocol: {:?}\x1b[0m", cmd, protocol); | ||||
| @@ -114,7 +223,7 @@ impl Server { | ||||
|                         Protocol::err(&format!("ERR {}", e.0)) | ||||
|                     } | ||||
|                 }; | ||||
|                  | ||||
|  | ||||
|                 if self.option.debug { | ||||
|                     println!("\x1b[34;1mqueued cmd {:?}\x1b[0m", self.queued_cmd); | ||||
|                     println!("\x1b[32;1mgoing to send response {}\x1b[0m", res.encode()); | ||||
| @@ -130,6 +239,11 @@ impl Server { | ||||
|                     println!("[handle] QUIT command received, closing connection"); | ||||
|                     return Ok(()); | ||||
|                 } | ||||
|  | ||||
|                 // Continue parsing any further complete commands already in 'acc' | ||||
|                 if acc.is_empty() { | ||||
|                     break; | ||||
|                 } | ||||
|             } | ||||
|         } | ||||
|     } | ||||
|   | ||||
| @@ -215,4 +215,31 @@ impl Storage { | ||||
|          | ||||
|         Ok(keys) | ||||
|     } | ||||
| } | ||||
|  | ||||
| impl Storage { | ||||
|     pub fn dbsize(&self) -> Result<i64, DBError> { | ||||
|         let read_txn = self.db.begin_read()?; | ||||
|         let types_table = read_txn.open_table(TYPES_TABLE)?; | ||||
|         let expiration_table = read_txn.open_table(EXPIRATION_TABLE)?; | ||||
|  | ||||
|         let mut count: i64 = 0; | ||||
|         let mut iter = types_table.iter()?; | ||||
|         while let Some(entry) = iter.next() { | ||||
|             let entry = entry?; | ||||
|             let key = entry.0.value(); | ||||
|             let ty = entry.1.value(); | ||||
|  | ||||
|             if ty == "string" { | ||||
|                 if let Some(expires_at) = expiration_table.get(key)? { | ||||
|                     if now_in_millis() > expires_at.value() as u128 { | ||||
|                         // Skip logically expired string keys | ||||
|                         continue; | ||||
|                     } | ||||
|                 } | ||||
|             } | ||||
|             count += 1; | ||||
|         } | ||||
|         Ok(count) | ||||
|     } | ||||
| } | ||||
| @@ -98,6 +98,116 @@ impl Storage { | ||||
|             None => Ok(false), // Key does not exist | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     // -------- Expiration helpers (string keys only, consistent with TTL/EXISTS) -------- | ||||
|  | ||||
|     // Set expiry in seconds; returns true if applied (key exists and is string), false otherwise | ||||
|     pub fn expire_seconds(&self, key: &str, secs: u64) -> Result<bool, DBError> { | ||||
|         // Determine eligibility first to avoid holding borrows across commit | ||||
|         let mut applied = false; | ||||
|         let write_txn = self.db.begin_write()?; | ||||
|         { | ||||
|             let types_table = write_txn.open_table(TYPES_TABLE)?; | ||||
|             let is_string = types_table | ||||
|                 .get(key)? | ||||
|                 .map(|v| v.value() == "string") | ||||
|                 .unwrap_or(false); | ||||
|             if is_string { | ||||
|                 let mut expiration_table = write_txn.open_table(EXPIRATION_TABLE)?; | ||||
|                 let expires_at = now_in_millis() + (secs as u128) * 1000; | ||||
|                 expiration_table.insert(key, &(expires_at as u64))?; | ||||
|                 applied = true; | ||||
|             } | ||||
|         } | ||||
|         write_txn.commit()?; | ||||
|         Ok(applied) | ||||
|     } | ||||
|  | ||||
|     // Set expiry in milliseconds; returns true if applied (key exists and is string), false otherwise | ||||
|     pub fn pexpire_millis(&self, key: &str, ms: u128) -> Result<bool, DBError> { | ||||
|         let mut applied = false; | ||||
|         let write_txn = self.db.begin_write()?; | ||||
|         { | ||||
|             let types_table = write_txn.open_table(TYPES_TABLE)?; | ||||
|             let is_string = types_table | ||||
|                 .get(key)? | ||||
|                 .map(|v| v.value() == "string") | ||||
|                 .unwrap_or(false); | ||||
|             if is_string { | ||||
|                 let mut expiration_table = write_txn.open_table(EXPIRATION_TABLE)?; | ||||
|                 let expires_at = now_in_millis() + ms; | ||||
|                 expiration_table.insert(key, &(expires_at as u64))?; | ||||
|                 applied = true; | ||||
|             } | ||||
|         } | ||||
|         write_txn.commit()?; | ||||
|         Ok(applied) | ||||
|     } | ||||
|  | ||||
|     // Remove expiry if present; returns true if removed, false otherwise | ||||
|     pub fn persist(&self, key: &str) -> Result<bool, DBError> { | ||||
|         let mut removed = false; | ||||
|         let write_txn = self.db.begin_write()?; | ||||
|         { | ||||
|             let types_table = write_txn.open_table(TYPES_TABLE)?; | ||||
|             let is_string = types_table | ||||
|                 .get(key)? | ||||
|                 .map(|v| v.value() == "string") | ||||
|                 .unwrap_or(false); | ||||
|             if is_string { | ||||
|                 let mut expiration_table = write_txn.open_table(EXPIRATION_TABLE)?; | ||||
|                 if expiration_table.remove(key)?.is_some() { | ||||
|                     removed = true; | ||||
|                 } | ||||
|             } | ||||
|         } | ||||
|         write_txn.commit()?; | ||||
|         Ok(removed) | ||||
|     } | ||||
|  | ||||
|     // Absolute EXPIREAT in seconds since epoch | ||||
|     // Returns true if applied (key exists and is string), false otherwise | ||||
|     pub fn expire_at_seconds(&self, key: &str, ts_secs: i64) -> Result<bool, DBError> { | ||||
|         let mut applied = false; | ||||
|         let write_txn = self.db.begin_write()?; | ||||
|         { | ||||
|             let types_table = write_txn.open_table(TYPES_TABLE)?; | ||||
|             let is_string = types_table | ||||
|                 .get(key)? | ||||
|                 .map(|v| v.value() == "string") | ||||
|                 .unwrap_or(false); | ||||
|             if is_string { | ||||
|                 let mut expiration_table = write_txn.open_table(EXPIRATION_TABLE)?; | ||||
|                 let expires_at_ms: u128 = if ts_secs <= 0 { 0 } else { (ts_secs as u128) * 1000 }; | ||||
|                 expiration_table.insert(key, &((expires_at_ms as u64)))?; | ||||
|                 applied = true; | ||||
|             } | ||||
|         } | ||||
|         write_txn.commit()?; | ||||
|         Ok(applied) | ||||
|     } | ||||
|  | ||||
|     // Absolute PEXPIREAT in milliseconds since epoch | ||||
|     // Returns true if applied (key exists and is string), false otherwise | ||||
|     pub fn pexpire_at_millis(&self, key: &str, ts_ms: i64) -> Result<bool, DBError> { | ||||
|         let mut applied = false; | ||||
|         let write_txn = self.db.begin_write()?; | ||||
|         { | ||||
|             let types_table = write_txn.open_table(TYPES_TABLE)?; | ||||
|             let is_string = types_table | ||||
|                 .get(key)? | ||||
|                 .map(|v| v.value() == "string") | ||||
|                 .unwrap_or(false); | ||||
|             if is_string { | ||||
|                 let mut expiration_table = write_txn.open_table(EXPIRATION_TABLE)?; | ||||
|                 let expires_at_ms: u128 = if ts_ms <= 0 { 0 } else { ts_ms as u128 }; | ||||
|                 expiration_table.insert(key, &((expires_at_ms as u64)))?; | ||||
|                 applied = true; | ||||
|             } | ||||
|         } | ||||
|         write_txn.commit()?; | ||||
|         Ok(applied) | ||||
|     } | ||||
| } | ||||
|  | ||||
| // Utility function for glob pattern matching | ||||
|   | ||||
| @@ -148,8 +148,6 @@ impl Storage { | ||||
|  | ||||
|     pub fn hexists(&self, key: &str, field: &str) -> Result<bool, DBError> { | ||||
|         let read_txn = self.db.begin_read()?; | ||||
|         let types_table = read_txn.open_table(TYPES_TABLE)?; | ||||
|          | ||||
|         let types_table = read_txn.open_table(TYPES_TABLE)?; | ||||
|         let key_type = { | ||||
|             let access_guard = types_table.get(key)?; | ||||
| @@ -168,8 +166,6 @@ impl Storage { | ||||
|  | ||||
|     pub fn hkeys(&self, key: &str) -> Result<Vec<String>, DBError> { | ||||
|         let read_txn = self.db.begin_read()?; | ||||
|         let types_table = read_txn.open_table(TYPES_TABLE)?; | ||||
|          | ||||
|         let types_table = read_txn.open_table(TYPES_TABLE)?; | ||||
|         let key_type = { | ||||
|             let access_guard = types_table.get(key)?; | ||||
| @@ -200,8 +196,6 @@ impl Storage { | ||||
|     // ✅ ENCRYPTION APPLIED: All values are decrypted after retrieval | ||||
|     pub fn hvals(&self, key: &str) -> Result<Vec<String>, DBError> { | ||||
|         let read_txn = self.db.begin_read()?; | ||||
|         let types_table = read_txn.open_table(TYPES_TABLE)?; | ||||
|          | ||||
|         let types_table = read_txn.open_table(TYPES_TABLE)?; | ||||
|         let key_type = { | ||||
|             let access_guard = types_table.get(key)?; | ||||
| @@ -233,8 +227,6 @@ impl Storage { | ||||
|  | ||||
|     pub fn hlen(&self, key: &str) -> Result<i64, DBError> { | ||||
|         let read_txn = self.db.begin_read()?; | ||||
|         let types_table = read_txn.open_table(TYPES_TABLE)?; | ||||
|          | ||||
|         let types_table = read_txn.open_table(TYPES_TABLE)?; | ||||
|         let key_type = { | ||||
|             let access_guard = types_table.get(key)?; | ||||
| @@ -265,8 +257,6 @@ impl Storage { | ||||
|     // ✅ ENCRYPTION APPLIED: Values are decrypted after retrieval | ||||
|     pub fn hmget(&self, key: &str, fields: Vec<String>) -> Result<Vec<Option<String>>, DBError> { | ||||
|         let read_txn = self.db.begin_read()?; | ||||
|         let types_table = read_txn.open_table(TYPES_TABLE)?; | ||||
|          | ||||
|         let types_table = read_txn.open_table(TYPES_TABLE)?; | ||||
|         let key_type = { | ||||
|             let access_guard = types_table.get(key)?; | ||||
| @@ -334,8 +324,6 @@ impl Storage { | ||||
|     // ✅ ENCRYPTION APPLIED: Values are decrypted after retrieval | ||||
|     pub fn hscan(&self, key: &str, cursor: u64, pattern: Option<&str>, count: Option<u64>) -> Result<(u64, Vec<(String, String)>), DBError> { | ||||
|         let read_txn = self.db.begin_read()?; | ||||
|         let types_table = read_txn.open_table(TYPES_TABLE)?; | ||||
|          | ||||
|         let types_table = read_txn.open_table(TYPES_TABLE)?; | ||||
|         let key_type = { | ||||
|             let access_guard = types_table.get(key)?; | ||||
|   | ||||
| @@ -16,9 +16,9 @@ fn get_redis_connection(port: u16) -> Connection { | ||||
|                 } | ||||
|             } | ||||
|             Err(e) => { | ||||
|                 if attempts >= 20 { | ||||
|                 if attempts >= 120 { | ||||
|                     panic!( | ||||
|                         "Failed to connect to Redis server after 20 attempts: {}", | ||||
|                         "Failed to connect to Redis server after 120 attempts: {}", | ||||
|                         e | ||||
|                     ); | ||||
|                 } | ||||
| @@ -88,8 +88,8 @@ fn setup_server() -> (ServerProcessGuard, u16) { | ||||
|         test_dir, | ||||
|     }; | ||||
|      | ||||
|     // Give the server a moment to start | ||||
|     std::thread::sleep(Duration::from_millis(500)); | ||||
|     // Give the server time to build and start (cargo run may compile first) | ||||
|     std::thread::sleep(Duration::from_millis(2500)); | ||||
|  | ||||
|     (guard, port) | ||||
| } | ||||
|   | ||||
							
								
								
									
										892
									
								
								herodb/tests/usage_suite.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										892
									
								
								herodb/tests/usage_suite.rs
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,892 @@ | ||||
| use herodb::{options::DBOption, server::Server}; | ||||
| use tokio::io::{AsyncReadExt, AsyncWriteExt}; | ||||
| use tokio::net::TcpStream; | ||||
| use tokio::time::{sleep, Duration}; | ||||
|  | ||||
| // ========================= | ||||
| // Helpers | ||||
| // ========================= | ||||
|  | ||||
| async fn start_test_server(test_name: &str) -> (Server, u16) { | ||||
|     use std::sync::atomic::{AtomicU16, Ordering}; | ||||
|     static PORT_COUNTER: AtomicU16 = AtomicU16::new(17100); | ||||
|     let port = PORT_COUNTER.fetch_add(1, Ordering::SeqCst); | ||||
|  | ||||
|     let test_dir = format!("/tmp/herodb_usage_suite_{}", test_name); | ||||
|     let _ = std::fs::remove_dir_all(&test_dir); | ||||
|     std::fs::create_dir_all(&test_dir).unwrap(); | ||||
|  | ||||
|     let option = DBOption { | ||||
|         dir: test_dir, | ||||
|         port, | ||||
|         debug: false, | ||||
|         encrypt: false, | ||||
|         encryption_key: None, | ||||
|     }; | ||||
|  | ||||
|     let server = Server::new(option).await; | ||||
|     (server, port) | ||||
| } | ||||
|  | ||||
| async fn spawn_listener(server: Server, port: u16) { | ||||
|     tokio::spawn(async move { | ||||
|         let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port)) | ||||
|             .await | ||||
|             .expect("bind listener"); | ||||
|         loop { | ||||
|             match listener.accept().await { | ||||
|                 Ok((stream, _)) => { | ||||
|                     let mut s_clone = server.clone(); | ||||
|                     tokio::spawn(async move { | ||||
|                         let _ = s_clone.handle(stream).await; | ||||
|                     }); | ||||
|                 } | ||||
|                 Err(_e) => break, | ||||
|             } | ||||
|         } | ||||
|     }); | ||||
| } | ||||
|  | ||||
| /// Build RESP array for args ["PING"] -> "*1\r\n$4\r\nPING\r\n" | ||||
| fn build_resp(args: &[&str]) -> String { | ||||
|     let mut s = format!("*{}\r\n", args.len()); | ||||
|     for a in args { | ||||
|         s.push_str(&format!("${}\r\n{}\r\n", a.len(), a)); | ||||
|     } | ||||
|     s | ||||
| } | ||||
|  | ||||
| async fn connect(port: u16) -> TcpStream { | ||||
|     let mut attempts = 0; | ||||
|     loop { | ||||
|         match TcpStream::connect(format!("127.0.0.1:{}", port)).await { | ||||
|             Ok(s) => return s, | ||||
|             Err(_) if attempts < 30 => { | ||||
|                 attempts += 1; | ||||
|                 sleep(Duration::from_millis(100)).await; | ||||
|             } | ||||
|             Err(e) => panic!("Failed to connect: {}", e), | ||||
|         } | ||||
|     } | ||||
| } | ||||
|  | ||||
| fn find_crlf(buf: &[u8], start: usize) -> Option<usize> { | ||||
|     let mut i = start; | ||||
|     while i + 1 < buf.len() { | ||||
|         if buf[i] == b'\r' && buf[i + 1] == b'\n' { | ||||
|             return Some(i); | ||||
|         } | ||||
|         i += 1; | ||||
|     } | ||||
|     None | ||||
| } | ||||
|  | ||||
| fn parse_number_i64(buf: &[u8], start: usize, end: usize) -> Option<i64> { | ||||
|     let s = std::str::from_utf8(&buf[start..end]).ok()?; | ||||
|     s.parse::<i64>().ok() | ||||
| } | ||||
|  | ||||
| // Return number of bytes that make up a complete RESP element starting at 'i', or None if incomplete. | ||||
| fn parse_elem(buf: &[u8], i: usize) -> Option<usize> { | ||||
|     if i >= buf.len() { | ||||
|         return None; | ||||
|     } | ||||
|     match buf[i] { | ||||
|         b'+' | b'-' | b':' => { | ||||
|             let end = find_crlf(buf, i + 1)?; | ||||
|             Some(end + 2 - i) | ||||
|         } | ||||
|         b'$' => { | ||||
|             let hdr_end = find_crlf(buf, i + 1)?; | ||||
|             let n = parse_number_i64(buf, i + 1, hdr_end)?; | ||||
|             if n < 0 { | ||||
|                 // Null bulk string: only header | ||||
|                 Some(hdr_end + 2 - i) | ||||
|             } else { | ||||
|                 let need = hdr_end + 2 + (n as usize) + 2; | ||||
|                 if need <= buf.len() { | ||||
|                     Some(need - i) | ||||
|                 } else { | ||||
|                     None | ||||
|                 } | ||||
|             } | ||||
|         } | ||||
|         b'*' => { | ||||
|             let hdr_end = find_crlf(buf, i + 1)?; | ||||
|             let n = parse_number_i64(buf, i + 1, hdr_end)?; | ||||
|             if n < 0 { | ||||
|                 // Null array: only header | ||||
|                 Some(hdr_end + 2 - i) | ||||
|             } else { | ||||
|                 let mut j = hdr_end + 2; | ||||
|                 for _ in 0..(n as usize) { | ||||
|                     let consumed = parse_elem(buf, j)?; | ||||
|                     j += consumed; | ||||
|                 } | ||||
|                 Some(j - i) | ||||
|             } | ||||
|         } | ||||
|         _ => None, | ||||
|     } | ||||
| } | ||||
|  | ||||
| fn resp_frame_len(buf: &[u8]) -> Option<usize> { | ||||
|     parse_elem(buf, 0) | ||||
| } | ||||
|  | ||||
| async fn read_full_resp(stream: &mut TcpStream) -> String { | ||||
|     let mut buf: Vec<u8> = Vec::with_capacity(8192); | ||||
|     let mut tmp = vec![0u8; 4096]; | ||||
|  | ||||
|     loop { | ||||
|         if let Some(total) = resp_frame_len(&buf) { | ||||
|             if buf.len() >= total { | ||||
|                 return String::from_utf8_lossy(&buf[..total]).to_string(); | ||||
|             } | ||||
|         } | ||||
|  | ||||
|         match tokio::time::timeout(Duration::from_secs(2), stream.read(&mut tmp)).await { | ||||
|             Ok(Ok(n)) => { | ||||
|                 if n == 0 { | ||||
|                     if let Some(total) = resp_frame_len(&buf) { | ||||
|                         if buf.len() >= total { | ||||
|                             return String::from_utf8_lossy(&buf[..total]).to_string(); | ||||
|                         } | ||||
|                     } | ||||
|                     return String::from_utf8_lossy(&buf).to_string(); | ||||
|                 } | ||||
|                 buf.extend_from_slice(&tmp[..n]); | ||||
|             } | ||||
|             Ok(Err(e)) => panic!("read error: {}", e), | ||||
|             Err(_) => panic!("timeout waiting for reply"), | ||||
|         } | ||||
|  | ||||
|         if buf.len() > 8 * 1024 * 1024 { | ||||
|             panic!("reply too large"); | ||||
|         } | ||||
|     } | ||||
| } | ||||
|  | ||||
| async fn send_cmd(stream: &mut TcpStream, args: &[&str]) -> String { | ||||
|     let req = build_resp(args); | ||||
|     stream.write_all(req.as_bytes()).await.unwrap(); | ||||
|     read_full_resp(stream).await | ||||
| } | ||||
|  | ||||
| // Assert helpers with clearer output | ||||
| fn assert_contains(haystack: &str, needle: &str, ctx: &str) { | ||||
|     assert!( | ||||
|         haystack.contains(needle), | ||||
|         "ASSERT CONTAINS failed: '{}' not found in response.\nContext: {}\nResponse:\n{}", | ||||
|         needle, | ||||
|         ctx, | ||||
|         haystack | ||||
|     ); | ||||
| } | ||||
|  | ||||
| fn assert_eq_resp(actual: &str, expected: &str, ctx: &str) { | ||||
|     assert!( | ||||
|         actual == expected, | ||||
|         "ASSERT EQUAL failed.\nContext: {}\nExpected:\n{:?}\nActual:\n{:?}", | ||||
|         ctx, | ||||
|         expected, | ||||
|         actual | ||||
|     ); | ||||
| } | ||||
|  | ||||
| /// Extract the payload of a single RESP Bulk String reply. | ||||
| /// Example input: | ||||
| ///   "$5\r\nhello\r\n" -> Some("hello".to_string()) | ||||
| fn extract_bulk_payload(resp: &str) -> Option<String> { | ||||
|     // find first CRLF after "$len" | ||||
|     let first = resp.find("\r\n")?; | ||||
|     let after = &resp[(first + 2)..]; | ||||
|     // find next CRLF ending payload | ||||
|     let second = after.find("\r\n")?; | ||||
|     Some(after[..second].to_string()) | ||||
| } | ||||
|  | ||||
| // ========================= | ||||
| // Test suites | ||||
| // ========================= | ||||
|  | ||||
| #[tokio::test] | ||||
| async fn test_01_connection_and_info() { | ||||
|     let (server, port) = start_test_server("conn_info").await; | ||||
|     spawn_listener(server, port).await; | ||||
|     sleep(Duration::from_millis(150)).await; | ||||
|  | ||||
|     let mut s = connect(port).await; | ||||
|  | ||||
|     // redis-cli may send COMMAND DOCS, our server replies empty array; harmless. | ||||
|     let pong = send_cmd(&mut s, &["PING"]).await; | ||||
|     assert_contains(&pong, "PONG", "PING should return PONG"); | ||||
|  | ||||
|     let echo = send_cmd(&mut s, &["ECHO", "hello"]).await; | ||||
|     assert_contains(&echo, "hello", "ECHO hello"); | ||||
|  | ||||
|     // INFO (general) | ||||
|     let info = send_cmd(&mut s, &["INFO"]).await; | ||||
|     assert_contains(&info, "redis_version", "INFO should include redis_version"); | ||||
|  | ||||
|     // INFO REPLICATION (static stub) | ||||
|     let repl = send_cmd(&mut s, &["INFO", "replication"]).await; | ||||
|     assert_contains(&repl, "role:master", "INFO replication role"); | ||||
|  | ||||
|     // CONFIG GET subset | ||||
|     let cfg = send_cmd(&mut s, &["CONFIG", "GET", "databases"]).await; | ||||
|     assert_contains(&cfg, "databases", "CONFIG GET databases"); | ||||
|     assert_contains(&cfg, "16", "CONFIG GET databases value"); | ||||
|  | ||||
|     // CLIENT name | ||||
|     let setname = send_cmd(&mut s, &["CLIENT", "SETNAME", "myapp"]).await; | ||||
|     assert_contains(&setname, "OK", "CLIENT SETNAME"); | ||||
|  | ||||
|     let getname = send_cmd(&mut s, &["CLIENT", "GETNAME"]).await; | ||||
|     assert_contains(&getname, "myapp", "CLIENT GETNAME"); | ||||
|  | ||||
|     // SELECT db | ||||
|     let sel = send_cmd(&mut s, &["SELECT", "0"]).await; | ||||
|     assert_contains(&sel, "OK", "SELECT 0"); | ||||
|  | ||||
|     // QUIT should close connection after sending OK | ||||
|     let quit = send_cmd(&mut s, &["QUIT"]).await; | ||||
|     assert_contains(&quit, "OK", "QUIT should return OK"); | ||||
| } | ||||
|  | ||||
| #[tokio::test] | ||||
| async fn test_02_strings_and_expiry() { | ||||
|     let (server, port) = start_test_server("strings").await; | ||||
|     spawn_listener(server, port).await; | ||||
|     sleep(Duration::from_millis(150)).await; | ||||
|  | ||||
|     let mut s = connect(port).await; | ||||
|  | ||||
|     // SET / GET | ||||
|     let set = send_cmd(&mut s, &["SET", "user:1", "alice"]).await; | ||||
|     assert_contains(&set, "OK", "SET user:1 alice"); | ||||
|  | ||||
|     let get = send_cmd(&mut s, &["GET", "user:1"]).await; | ||||
|     assert_contains(&get, "alice", "GET user:1"); | ||||
|  | ||||
|     // EXISTS / DEL | ||||
|     let ex1 = send_cmd(&mut s, &["EXISTS", "user:1"]).await; | ||||
|     assert_contains(&ex1, "1", "EXISTS user:1"); | ||||
|  | ||||
|     let del = send_cmd(&mut s, &["DEL", "user:1"]).await; | ||||
|     assert_contains(&del, "1", "DEL user:1"); | ||||
|  | ||||
|     let ex0 = send_cmd(&mut s, &["EXISTS", "user:1"]).await; | ||||
|     assert_contains(&ex0, "0", "EXISTS after DEL"); | ||||
|  | ||||
|     // INCR behavior | ||||
|     let i1 = send_cmd(&mut s, &["INCR", "count"]).await; | ||||
|     assert_contains(&i1, "1", "INCR new key -> 1"); | ||||
|     let i2 = send_cmd(&mut s, &["INCR", "count"]).await; | ||||
|     assert_contains(&i2, "2", "INCR existing -> 2"); | ||||
|     let _ = send_cmd(&mut s, &["SET", "notnum", "abc"]).await; | ||||
|     let ierr = send_cmd(&mut s, &["INCR", "notnum"]).await; | ||||
|     assert_contains(&ierr, "ERR", "INCR on non-numeric should ERR"); | ||||
|  | ||||
|     // Expiration via SET EX | ||||
|     let setex = send_cmd(&mut s, &["SET", "tmp:1", "boom", "EX", "1"]).await; | ||||
|     assert_contains(&setex, "OK", "SET tmp:1 EX 1"); | ||||
|  | ||||
|     let g_immediate = send_cmd(&mut s, &["GET", "tmp:1"]).await; | ||||
|     assert_contains(&g_immediate, "boom", "GET tmp:1 immediately"); | ||||
|  | ||||
|     let ttl = send_cmd(&mut s, &["TTL", "tmp:1"]).await; | ||||
|     // Implementation returns a SimpleString, accept any numeric content | ||||
|     assert!( | ||||
|         ttl.contains("1") || ttl.contains("0"), | ||||
|         "TTL should be 1 or 0, got: {}", | ||||
|         ttl | ||||
|     ); | ||||
|  | ||||
|     sleep(Duration::from_millis(1100)).await; | ||||
|     let g_after = send_cmd(&mut s, &["GET", "tmp:1"]).await; | ||||
|     assert_contains(&g_after, "$-1", "GET tmp:1 after expiry -> Null"); | ||||
|  | ||||
|     // TYPE | ||||
|     let _ = send_cmd(&mut s, &["SET", "t", "v"]).await; | ||||
|     let ty = send_cmd(&mut s, &["TYPE", "t"]).await; | ||||
|     assert_contains(&ty, "string", "TYPE string key"); | ||||
|     let ty_none = send_cmd(&mut s, &["TYPE", "noexist"]).await; | ||||
|     assert_contains(&ty_none, "none", "TYPE nonexistent"); | ||||
| } | ||||
|  | ||||
| #[tokio::test] | ||||
| async fn test_03_scan_and_keys() { | ||||
|     let (server, port) = start_test_server("scan").await; | ||||
|     spawn_listener(server, port).await; | ||||
|     sleep(Duration::from_millis(150)).await; | ||||
|  | ||||
|     let mut s = connect(port).await; | ||||
|  | ||||
|     for i in 0..5 { | ||||
|         let _ = send_cmd(&mut s, &["SET", &format!("key{}", i), &format!("value{}", i)]).await; | ||||
|     } | ||||
|  | ||||
|     let scan = send_cmd(&mut s, &["SCAN", "0", "MATCH", "key*", "COUNT", "10"]).await; | ||||
|     assert_contains(&scan, "key0", "SCAN should return keys with MATCH"); | ||||
|     assert_contains(&scan, "key4", "SCAN should return last key"); | ||||
|  | ||||
|     let keys = send_cmd(&mut s, &["KEYS", "*"]).await; | ||||
|     assert_contains(&keys, "key0", "KEYS * includes key0"); | ||||
|     assert_contains(&keys, "key4", "KEYS * includes key4"); | ||||
| } | ||||
|  | ||||
| #[tokio::test] | ||||
| async fn test_04_hashes_suite() { | ||||
|     let (server, port) = start_test_server("hashes").await; | ||||
|     spawn_listener(server, port).await; | ||||
|     sleep(Duration::from_millis(150)).await; | ||||
|  | ||||
|     let mut s = connect(port).await; | ||||
|  | ||||
|     // HSET (single, returns number of new fields) | ||||
|     let h1 = send_cmd(&mut s, &["HSET", "profile:1", "name", "alice"]).await; | ||||
|     assert_contains(&h1, "1", "HSET new field -> 1"); | ||||
|  | ||||
|     // HGET | ||||
|     let hg = send_cmd(&mut s, &["HGET", "profile:1", "name"]).await; | ||||
|     assert_contains(&hg, "alice", "HGET existing field"); | ||||
|  | ||||
|     // HSET multiple | ||||
|     let h2 = send_cmd(&mut s, &["HSET", "profile:1", "age", "30", "city", "paris"]).await; | ||||
|     assert_contains(&h2, "2", "HSET added 2 new fields"); | ||||
|  | ||||
|     // HMGET | ||||
|     let hmg = send_cmd(&mut s, &["HMGET", "profile:1", "name", "age", "city", "nope"]).await; | ||||
|     assert_contains(&hmg, "alice", "HMGET name"); | ||||
|     assert_contains(&hmg, "30", "HMGET age"); | ||||
|     assert_contains(&hmg, "paris", "HMGET city"); | ||||
|     assert_contains(&hmg, "$-1", "HMGET non-existent -> Null"); | ||||
|  | ||||
|     // HGETALL | ||||
|     let hga = send_cmd(&mut s, &["HGETALL", "profile:1"]).await; | ||||
|     assert_contains(&hga, "name", "HGETALL contains name"); | ||||
|     assert_contains(&hga, "alice", "HGETALL contains alice"); | ||||
|  | ||||
|     // HLEN | ||||
|     let hlen = send_cmd(&mut s, &["HLEN", "profile:1"]).await; | ||||
|     assert_contains(&hlen, "3", "HLEN is 3"); | ||||
|  | ||||
|     // HEXISTS | ||||
|     let hex1 = send_cmd(&mut s, &["HEXISTS", "profile:1", "age"]).await; | ||||
|     assert_contains(&hex1, "1", "HEXISTS age true"); | ||||
|     let hex0 = send_cmd(&mut s, &["HEXISTS", "profile:1", "nope"]).await; | ||||
|     assert_contains(&hex0, "0", "HEXISTS nope false"); | ||||
|  | ||||
|     // HKEYS / HVALS | ||||
|     let hkeys = send_cmd(&mut s, &["HKEYS", "profile:1"]).await; | ||||
|     assert_contains(&hkeys, "name", "HKEYS includes name"); | ||||
|     let hvals = send_cmd(&mut s, &["HVALS", "profile:1"]).await; | ||||
|     assert_contains(&hvals, "alice", "HVALS includes alice"); | ||||
|  | ||||
|     // HSETNX | ||||
|     let hnx0 = send_cmd(&mut s, &["HSETNX", "profile:1", "name", "bob"]).await; | ||||
|     assert_contains(&hnx0, "0", "HSETNX existing field -> 0"); | ||||
|     let hnx1 = send_cmd(&mut s, &["HSETNX", "profile:1", "nickname", "ali"]).await; | ||||
|     assert_contains(&hnx1, "1", "HSETNX new field -> 1"); | ||||
|  | ||||
|     // HSCAN | ||||
|     let hscan = send_cmd(&mut s, &["HSCAN", "profile:1", "0", "MATCH", "n*", "COUNT", "10"]).await; | ||||
|     assert_contains(&hscan, "name", "HSCAN matches fields starting with n"); | ||||
|     assert_contains(&hscan, "nickname", "HSCAN nickname present"); | ||||
|  | ||||
|     // HDEL | ||||
|     let hdel = send_cmd(&mut s, &["HDEL", "profile:1", "city", "age"]).await; | ||||
|     assert_contains(&hdel, "2", "HDEL removed two fields"); | ||||
| } | ||||
|  | ||||
| #[tokio::test] | ||||
| async fn test_05_lists_suite_including_blpop() { | ||||
|     let (server, port) = start_test_server("lists").await; | ||||
|     spawn_listener(server, port).await; | ||||
|     sleep(Duration::from_millis(150)).await; | ||||
|  | ||||
|     let mut a = connect(port).await; | ||||
|  | ||||
|     // LPUSH / RPUSH / LLEN | ||||
|     let lp = send_cmd(&mut a, &["LPUSH", "q:jobs", "a", "b"]).await; | ||||
|     assert_contains(&lp, "2", "LPUSH added 2, length 2"); | ||||
|  | ||||
|     let rp = send_cmd(&mut a, &["RPUSH", "q:jobs", "c"]).await; | ||||
|     assert_contains(&rp, "3", "RPUSH now length 3"); | ||||
|  | ||||
|     let llen = send_cmd(&mut a, &["LLEN", "q:jobs"]).await; | ||||
|     assert_contains(&llen, "3", "LLEN 3"); | ||||
|  | ||||
|     // LINDEX / LRANGE | ||||
|     let lidx = send_cmd(&mut a, &["LINDEX", "q:jobs", "0"]).await; | ||||
|     assert_eq_resp(&lidx, "$1\r\nb\r\n", "LINDEX q:jobs 0 should be b"); | ||||
|  | ||||
|     let lr = send_cmd(&mut a, &["LRANGE", "q:jobs", "0", "-1"]).await; | ||||
|     assert_eq_resp(&lr, "*3\r\n$1\r\nb\r\n$1\r\na\r\n$1\r\nc\r\n", "LRANGE q:jobs 0 -1 should be [b,a,c]"); | ||||
|  | ||||
|     // LTRIM | ||||
|     let ltrim = send_cmd(&mut a, &["LTRIM", "q:jobs", "0", "1"]).await; | ||||
|     assert_contains(<rim, "OK", "LTRIM OK"); | ||||
|     let lr_post = send_cmd(&mut a, &["LRANGE", "q:jobs", "0", "-1"]).await; | ||||
|     assert_eq_resp(&lr_post, "*2\r\n$1\r\nb\r\n$1\r\na\r\n", "After LTRIM, list [b,a]"); | ||||
|  | ||||
|     // LREM remove first occurrence of b | ||||
|     let lrem = send_cmd(&mut a, &["LREM", "q:jobs", "1", "b"]).await; | ||||
|     assert_contains(&lrem, "1", "LREM removed 1"); | ||||
|  | ||||
|     // LPOP and RPOP | ||||
|     let lpop1 = send_cmd(&mut a, &["LPOP", "q:jobs"]).await; | ||||
|     assert_contains(&lpop1, "$1\r\na\r\n", "LPOP returns a"); | ||||
|     let rpop_empty = send_cmd(&mut a, &["RPOP", "q:jobs"]).await; // empty now | ||||
|     assert_contains(&rpop_empty, "$-1", "RPOP on empty -> Null"); | ||||
|  | ||||
|     // LPOP with count on empty -> [] | ||||
|     let lpop0 = send_cmd(&mut a, &["LPOP", "q:jobs", "2"]).await; | ||||
|     assert_eq_resp(&lpop0, "*0\r\n", "LPOP with count on empty returns empty array"); | ||||
|  | ||||
|     // BLPOP: block on one client, push from another | ||||
|     let c1 = connect(port).await; | ||||
|     let mut c2 = connect(port).await; | ||||
|  | ||||
|     // Start BLPOP on c1 | ||||
|     let blpop_task = tokio::spawn(async move { | ||||
|         let mut c1_local = c1; | ||||
|         send_cmd(&mut c1_local, &["BLPOP", "q:block", "5"]).await | ||||
|     }); | ||||
|  | ||||
|     // Give it time to register waiter | ||||
|     sleep(Duration::from_millis(150)).await; | ||||
|  | ||||
|     // Push from c2 to wake BLPOP | ||||
|     let _ = send_cmd(&mut c2, &["LPUSH", "q:block", "x"]).await; | ||||
|  | ||||
|     // Await BLPOP result | ||||
|     let blpop_res = blpop_task.await.expect("BLPOP task join"); | ||||
|     assert_contains(&blpop_res, "q:block", "BLPOP returned key"); | ||||
|     assert_contains(&blpop_res, "x", "BLPOP returned element"); | ||||
| } | ||||
|  | ||||
| #[tokio::test] | ||||
| async fn test_06_flushdb_suite() { | ||||
|     let (server, port) = start_test_server("flushdb").await; | ||||
|     spawn_listener(server, port).await; | ||||
|     sleep(Duration::from_millis(150)).await; | ||||
|  | ||||
|     let mut s = connect(port).await; | ||||
|  | ||||
|     let _ = send_cmd(&mut s, &["SET", "k1", "v1"]).await; | ||||
|     let _ = send_cmd(&mut s, &["HSET", "h1", "f", "v"]).await; | ||||
|     let _ = send_cmd(&mut s, &["LPUSH", "l1", "a"]).await; | ||||
|  | ||||
|     let keys_before = send_cmd(&mut s, &["KEYS", "*"]).await; | ||||
|     assert_contains(&keys_before, "k1", "have string key before FLUSHDB"); | ||||
|     assert_contains(&keys_before, "h1", "have hash key before FLUSHDB"); | ||||
|     assert_contains(&keys_before, "l1", "have list key before FLUSHDB"); | ||||
|  | ||||
|     let fl = send_cmd(&mut s, &["FLUSHDB"]).await; | ||||
|     assert_contains(&fl, "OK", "FLUSHDB OK"); | ||||
|  | ||||
|     let keys_after = send_cmd(&mut s, &["KEYS", "*"]).await; | ||||
|     assert_eq_resp(&keys_after, "*0\r\n", "DB should be empty after FLUSHDB"); | ||||
| } | ||||
|  | ||||
| #[tokio::test] | ||||
| async fn test_07_age_stateless_suite() { | ||||
|     let (server, port) = start_test_server("age_stateless").await; | ||||
|     spawn_listener(server, port).await; | ||||
|     sleep(Duration::from_millis(150)).await; | ||||
|  | ||||
|     let mut s = connect(port).await; | ||||
|  | ||||
|     // GENENC -> [recipient, identity] | ||||
|     let gen = send_cmd(&mut s, &["AGE", "GENENC"]).await; | ||||
|     assert!( | ||||
|         gen.starts_with("*2\r\n$"), | ||||
|         "AGE GENENC should return array [recipient, identity], got:\n{}", | ||||
|         gen | ||||
|     ); | ||||
|  | ||||
|     // Parse simple RESP array of two bulk strings to extract keys | ||||
|     fn parse_two_bulk_array(resp: &str) -> (String, String) { | ||||
|         // naive parse for tests | ||||
|         let mut lines = resp.lines(); | ||||
|         let _ = lines.next(); // *2 | ||||
|         // $len | ||||
|         let _ = lines.next(); | ||||
|         let recip = lines.next().unwrap_or("").to_string(); | ||||
|         let _ = lines.next(); | ||||
|         let ident = lines.next().unwrap_or("").to_string(); | ||||
|         (recip, ident) | ||||
|     } | ||||
|     let (recipient, identity) = parse_two_bulk_array(&gen); | ||||
|     assert!( | ||||
|         recipient.starts_with("age1") && identity.starts_with("AGE-SECRET-KEY-1"), | ||||
|         "Unexpected AGE key formats.\nrecipient: {}\nidentity: {}", | ||||
|         recipient, | ||||
|         identity | ||||
|     ); | ||||
|  | ||||
|     // ENCRYPT / DECRYPT | ||||
|     let ct = send_cmd(&mut s, &["AGE", "ENCRYPT", &recipient, "hello world"]).await; | ||||
|     let ct_b64 = extract_bulk_payload(&ct).expect("Failed to parse bulk payload from ENCRYPT"); | ||||
|     let pt = send_cmd(&mut s, &["AGE", "DECRYPT", &identity, &ct_b64]).await; | ||||
|     assert_contains(&pt, "hello world", "AGE DECRYPT round-trip"); | ||||
|  | ||||
|     // GENSIGN -> [verify_pub_b64, sign_secret_b64] | ||||
|     let gensign = send_cmd(&mut s, &["AGE", "GENSIGN"]).await; | ||||
|     let (verify_pub, sign_secret) = parse_two_bulk_array(&gensign); | ||||
|     assert!( | ||||
|         !verify_pub.is_empty() && !sign_secret.is_empty(), | ||||
|         "GENSIGN returned empty keys" | ||||
|     ); | ||||
|  | ||||
|     // SIGN / VERIFY | ||||
|     let sig = send_cmd(&mut s, &["AGE", "SIGN", &sign_secret, "msg"]).await; | ||||
|     let sig_b64 = extract_bulk_payload(&sig).expect("Failed to parse bulk payload from SIGN"); | ||||
|     let v_ok = send_cmd(&mut s, &["AGE", "VERIFY", &verify_pub, "msg", &sig_b64]).await; | ||||
|     assert_contains(&v_ok, "1", "VERIFY should be 1 for valid signature"); | ||||
|  | ||||
|     let v_bad = send_cmd(&mut s, &["AGE", "VERIFY", &verify_pub, "tampered", &sig_b64]).await; | ||||
|     assert_contains(&v_bad, "0", "VERIFY should be 0 for invalid message/signature"); | ||||
| } | ||||
|  | ||||
| #[tokio::test] | ||||
| async fn test_08_age_persistent_named_suite() { | ||||
|     let (server, port) = start_test_server("age_persistent").await; | ||||
|     spawn_listener(server, port).await; | ||||
|     sleep(Duration::from_millis(150)).await; | ||||
|  | ||||
|     let mut s = connect(port).await; | ||||
|  | ||||
|     // KEYGEN + ENCRYPTNAME/DECRYPTNAME | ||||
|     let kg = send_cmd(&mut s, &["AGE", "KEYGEN", "app1"]).await; | ||||
|     assert!( | ||||
|         kg.starts_with("*2\r\n"), | ||||
|         "AGE KEYGEN should return [recipient, identity], got:\n{}", | ||||
|         kg | ||||
|     ); | ||||
|  | ||||
|     let ct = send_cmd(&mut s, &["AGE", "ENCRYPTNAME", "app1", "hello"]).await; | ||||
|     let ct_b64 = extract_bulk_payload(&ct).expect("Failed to parse bulk payload from ENCRYPTNAME"); | ||||
|     let pt = send_cmd(&mut s, &["AGE", "DECRYPTNAME", "app1", &ct_b64]).await; | ||||
|     assert_contains(&pt, "hello", "DECRYPTNAME round-trip"); | ||||
|  | ||||
|     // SIGNKEYGEN + SIGNNAME/VERIFYNAME | ||||
|     let skg = send_cmd(&mut s, &["AGE", "SIGNKEYGEN", "app1"]).await; | ||||
|     assert!( | ||||
|         skg.starts_with("*2\r\n"), | ||||
|         "AGE SIGNKEYGEN should return [verify_pub, sign_secret], got:\n{}", | ||||
|         skg | ||||
|     ); | ||||
|  | ||||
|     let sig = send_cmd(&mut s, &["AGE", "SIGNNAME", "app1", "m"] ).await; | ||||
|     let sig_b64 = extract_bulk_payload(&sig).expect("Failed to parse bulk payload from SIGNNAME"); | ||||
|     let v1 = send_cmd(&mut s, &["AGE", "VERIFYNAME", "app1", "m", &sig_b64]).await; | ||||
|     assert_contains(&v1, "1", "VERIFYNAME valid => 1"); | ||||
|  | ||||
|     let v0 = send_cmd(&mut s, &["AGE", "VERIFYNAME", "app1", "bad", &sig_b64]).await; | ||||
|     assert_contains(&v0, "0", "VERIFYNAME invalid => 0"); | ||||
|  | ||||
|     // AGE LIST | ||||
|     let lst = send_cmd(&mut s, &["AGE", "LIST"]).await; | ||||
|     assert_contains(&lst, "encpub", "AGE LIST label encpub"); | ||||
|     assert_contains(&lst, "app1", "AGE LIST includes app1"); | ||||
| } | ||||
|  | ||||
| #[tokio::test] | ||||
| async fn test_10_expire_pexpire_persist() { | ||||
|    let (server, port) = start_test_server("expire_suite").await; | ||||
|    spawn_listener(server, port).await; | ||||
|    sleep(Duration::from_millis(150)).await; | ||||
|  | ||||
|    let mut s = connect(port).await; | ||||
|  | ||||
|    // EXPIRE: seconds | ||||
|    let _ = send_cmd(&mut s, &["SET", "exp:s", "v"]).await; | ||||
|    let ex = send_cmd(&mut s, &["EXPIRE", "exp:s", "1"]).await; | ||||
|    assert_contains(&ex, "1", "EXPIRE exp:s 1 -> 1 (applied)"); | ||||
|    let ttl1 = send_cmd(&mut s, &["TTL", "exp:s"]).await; | ||||
|    assert!( | ||||
|        ttl1.contains("1") || ttl1.contains("0"), | ||||
|        "TTL exp:s should be 1 or 0, got: {}", | ||||
|        ttl1 | ||||
|    ); | ||||
|    sleep(Duration::from_millis(1100)).await; | ||||
|    let get_after = send_cmd(&mut s, &["GET", "exp:s"]).await; | ||||
|    assert_contains(&get_after, "$-1", "GET after expiry should be Null"); | ||||
|    let ttl_after = send_cmd(&mut s, &["TTL", "exp:s"]).await; | ||||
|    assert_contains(&ttl_after, "-2", "TTL after expiry -> -2"); | ||||
|    let exists_after = send_cmd(&mut s, &["EXISTS", "exp:s"]).await; | ||||
|    assert_contains(&exists_after, "0", "EXISTS after expiry -> 0"); | ||||
|  | ||||
|    // PEXPIRE: milliseconds | ||||
|    let _ = send_cmd(&mut s, &["SET", "exp:ms", "v"]).await; | ||||
|    let pex = send_cmd(&mut s, &["PEXPIRE", "exp:ms", "1500"]).await; | ||||
|    assert_contains(&pex, "1", "PEXPIRE exp:ms 1500 -> 1 (applied)"); | ||||
|    let ttl_ms1 = send_cmd(&mut s, &["TTL", "exp:ms"]).await; | ||||
|    assert!( | ||||
|        ttl_ms1.contains("1") || ttl_ms1.contains("0"), | ||||
|        "TTL exp:ms should be 1 or 0 soon after PEXPIRE, got: {}", | ||||
|        ttl_ms1 | ||||
|    ); | ||||
|    sleep(Duration::from_millis(1600)).await; | ||||
|    let exists_ms_after = send_cmd(&mut s, &["EXISTS", "exp:ms"]).await; | ||||
|    assert_contains(&exists_ms_after, "0", "EXISTS exp:ms after ms expiry -> 0"); | ||||
|  | ||||
|    // PERSIST: remove expiration | ||||
|    let _ = send_cmd(&mut s, &["SET", "exp:persist", "v"]).await; | ||||
|    let _ = send_cmd(&mut s, &["EXPIRE", "exp:persist", "5"]).await; | ||||
|    let ttl_pre = send_cmd(&mut s, &["TTL", "exp:persist"]).await; | ||||
|    assert!( | ||||
|        ttl_pre.contains("5") || ttl_pre.contains("4") || ttl_pre.contains("3") || ttl_pre.contains("2") || ttl_pre.contains("1") || ttl_pre.contains("0"), | ||||
|        "TTL exp:persist should be >=0 before persist, got: {}", | ||||
|        ttl_pre | ||||
|    ); | ||||
|    let persist1 = send_cmd(&mut s, &["PERSIST", "exp:persist"]).await; | ||||
|    assert_contains(&persist1, "1", "PERSIST should remove expiration"); | ||||
|    let ttl_post = send_cmd(&mut s, &["TTL", "exp:persist"]).await; | ||||
|    assert_contains(&ttl_post, "-1", "TTL after PERSIST -> -1 (no expiration)"); | ||||
|    // Second persist should return 0 (nothing to remove) | ||||
|    let persist2 = send_cmd(&mut s, &["PERSIST", "exp:persist"]).await; | ||||
|    assert_contains(&persist2, "0", "PERSIST again -> 0 (no expiration to remove)"); | ||||
| } | ||||
|  | ||||
| #[tokio::test] | ||||
| async fn test_11_set_with_options() { | ||||
|     let (server, port) = start_test_server("set_opts").await; | ||||
|     spawn_listener(server, port).await; | ||||
|     sleep(Duration::from_millis(150)).await; | ||||
|  | ||||
|     let mut s = connect(port).await; | ||||
|  | ||||
|     // SET with GET on non-existing key -> returns Null, sets value | ||||
|     let set_get1 = send_cmd(&mut s, &["SET", "s1", "v1", "GET"]).await; | ||||
|     assert_contains(&set_get1, "$-1", "SET s1 v1 GET returns Null when key didn't exist"); | ||||
|     let g1 = send_cmd(&mut s, &["GET", "s1"]).await; | ||||
|     assert_contains(&g1, "v1", "GET s1 after first SET"); | ||||
|  | ||||
|     // SET with GET should return old value, then set to new | ||||
|     let set_get2 = send_cmd(&mut s, &["SET", "s1", "v2", "GET"]).await; | ||||
|     assert_contains(&set_get2, "v1", "SET s1 v2 GET returns previous value v1"); | ||||
|     let g2 = send_cmd(&mut s, &["GET", "s1"]).await; | ||||
|     assert_contains(&g2, "v2", "GET s1 now v2"); | ||||
|  | ||||
|     // NX prevents update when key exists; with GET should return Null and not change | ||||
|     let set_nx = send_cmd(&mut s, &["SET", "s1", "v3", "NX", "GET"]).await; | ||||
|     assert_contains(&set_nx, "$-1", "SET s1 v3 NX GET returns Null when not set"); | ||||
|     let g3 = send_cmd(&mut s, &["GET", "s1"]).await; | ||||
|     assert_contains(&g3, "v2", "GET s1 remains v2 after NX prevented write"); | ||||
|  | ||||
|     // NX allows set when key does not exist | ||||
|     let set_nx2 = send_cmd(&mut s, &["SET", "s2", "v10", "NX"]).await; | ||||
|     assert_contains(&set_nx2, "OK", "SET s2 v10 NX -> OK for new key"); | ||||
|     let g4 = send_cmd(&mut s, &["GET", "s2"]).await; | ||||
|     assert_contains(&g4, "v10", "GET s2 is v10"); | ||||
|  | ||||
|     // XX requires existing key; with GET returns old value and sets new | ||||
|     let set_xx = send_cmd(&mut s, &["SET", "s2", "v11", "XX", "GET"]).await; | ||||
|     assert_contains(&set_xx, "v10", "SET s2 v11 XX GET returns previous v10"); | ||||
|     let g5 = send_cmd(&mut s, &["GET", "s2"]).await; | ||||
|     assert_contains(&g5, "v11", "GET s2 is now v11"); | ||||
|  | ||||
|     // PX expiration path via SET options | ||||
|     let set_px = send_cmd(&mut s, &["SET", "s3", "vpx", "PX", "500"]).await; | ||||
|     assert_contains(&set_px, "OK", "SET s3 vpx PX 500 -> OK"); | ||||
|     let ttl_px1 = send_cmd(&mut s, &["TTL", "s3"]).await; | ||||
|     assert!( | ||||
|         ttl_px1.contains("0") || ttl_px1.contains("1"), | ||||
|         "TTL s3 immediately after PX should be 1 or 0, got: {}", | ||||
|         ttl_px1 | ||||
|     ); | ||||
|     sleep(Duration::from_millis(650)).await; | ||||
|     let g6 = send_cmd(&mut s, &["GET", "s3"]).await; | ||||
|     assert_contains(&g6, "$-1", "GET s3 after PX expiry -> Null"); | ||||
| } | ||||
|  | ||||
| #[tokio::test] | ||||
| async fn test_09_mget_mset_and_variadic_exists_del() { | ||||
|    let (server, port) = start_test_server("mget_mset_variadic").await; | ||||
|    spawn_listener(server, port).await; | ||||
|    sleep(Duration::from_millis(150)).await; | ||||
|  | ||||
|    let mut s = connect(port).await; | ||||
|  | ||||
|    // MSET multiple keys | ||||
|    let mset = send_cmd(&mut s, &["MSET", "k1", "v1", "k2", "v2", "k3", "v3"]).await; | ||||
|    assert_contains(&mset, "OK", "MSET k1 v1 k2 v2 k3 v3 -> OK"); | ||||
|  | ||||
|    // MGET should return values and Null for missing | ||||
|    let mget = send_cmd(&mut s, &["MGET", "k1", "k2", "nope", "k3"]).await; | ||||
|    // Expect an array with 4 entries; verify payloads | ||||
|    assert_contains(&mget, "v1", "MGET k1"); | ||||
|    assert_contains(&mget, "v2", "MGET k2"); | ||||
|    assert_contains(&mget, "v3", "MGET k3"); | ||||
|    assert_contains(&mget, "$-1", "MGET missing returns Null"); | ||||
|  | ||||
|    // EXISTS variadic: count how many exist | ||||
|    let exists_multi = send_cmd(&mut s, &["EXISTS", "k1", "nope", "k3"]).await; | ||||
|    // Server returns SimpleString numeric, e.g. +2 | ||||
|    assert_contains(&exists_multi, "2", "EXISTS k1 nope k3 -> 2"); | ||||
|  | ||||
|    // DEL variadic: delete multiple keys, return count deleted | ||||
|    let del_multi = send_cmd(&mut s, &["DEL", "k1", "k3", "nope"]).await; | ||||
|    assert_contains(&del_multi, "2", "DEL k1 k3 nope -> 2"); | ||||
|  | ||||
|    // Verify deletion | ||||
|    let exists_after = send_cmd(&mut s, &["EXISTS", "k1", "k3"]).await; | ||||
|    assert_contains(&exists_after, "0", "EXISTS k1 k3 after DEL -> 0"); | ||||
|  | ||||
|    // MGET after deletion should include Nulls for deleted keys | ||||
|    let mget_after = send_cmd(&mut s, &["MGET", "k1", "k2", "k3"]).await; | ||||
|    assert_contains(&mget_after, "$-1", "MGET k1 after DEL -> Null"); | ||||
|    assert_contains(&mget_after, "v2", "MGET k2 remains"); | ||||
|    assert_contains(&mget_after, "$-1", "MGET k3 after DEL -> Null"); | ||||
| } | ||||
| #[tokio::test] | ||||
| async fn test_12_hash_incr() { | ||||
|     let (server, port) = start_test_server("hash_incr").await; | ||||
|     spawn_listener(server, port).await; | ||||
|     sleep(Duration::from_millis(150)).await; | ||||
|  | ||||
|     let mut s = connect(port).await; | ||||
|  | ||||
|     // Integer increments | ||||
|     let _ = send_cmd(&mut s, &["HSET", "hinc", "a", "1"]).await; | ||||
|     let r1 = send_cmd(&mut s, &["HINCRBY", "hinc", "a", "2"]).await; | ||||
|     assert_contains(&r1, "3", "HINCRBY hinc a 2 -> 3"); | ||||
|  | ||||
|     let r2 = send_cmd(&mut s, &["HINCRBY", "hinc", "a", "-1"]).await; | ||||
|     assert_contains(&r2, "2", "HINCRBY hinc a -1 -> 2"); | ||||
|  | ||||
|     let r3 = send_cmd(&mut s, &["HINCRBY", "hinc", "b", "5"]).await; | ||||
|     assert_contains(&r3, "5", "HINCRBY hinc b 5 -> 5"); | ||||
|  | ||||
|     // HINCRBY error on non-integer field | ||||
|     let _ = send_cmd(&mut s, &["HSET", "hinc", "s", "x"]).await; | ||||
|     let r_err = send_cmd(&mut s, &["HINCRBY", "hinc", "s", "1"]).await; | ||||
|     assert_contains(&r_err, "ERR", "HINCRBY on non-integer field should ERR"); | ||||
|  | ||||
|     // Float increments | ||||
|     let r4 = send_cmd(&mut s, &["HINCRBYFLOAT", "hinc", "f", "1.5"]).await; | ||||
|     assert_contains(&r4, "1.5", "HINCRBYFLOAT hinc f 1.5 -> 1.5"); | ||||
|  | ||||
|     let r5 = send_cmd(&mut s, &["HINCRBYFLOAT", "hinc", "f", "2.5"]).await; | ||||
|     // Could be "4", "4.0", or "4.000000", accept "4" substring | ||||
|     assert_contains(&r5, "4", "HINCRBYFLOAT hinc f 2.5 -> 4"); | ||||
|  | ||||
|     // HINCRBYFLOAT error on non-float field | ||||
|     let _ = send_cmd(&mut s, &["HSET", "hinc", "notf", "abc"]).await; | ||||
|     let r6 = send_cmd(&mut s, &["HINCRBYFLOAT", "hinc", "notf", "1"]).await; | ||||
|     assert_contains(&r6, "ERR", "HINCRBYFLOAT on non-float field should ERR"); | ||||
| } | ||||
| #[tokio::test] | ||||
| async fn test_05b_brpop_suite() { | ||||
|     let (server, port) = start_test_server("lists_brpop").await; | ||||
|     spawn_listener(server, port).await; | ||||
|     sleep(Duration::from_millis(150)).await; | ||||
|  | ||||
|     let mut a = connect(port).await; | ||||
|  | ||||
|     // RPUSH some initial data, BRPOP should take from the right | ||||
|     let _ = send_cmd(&mut a, &["RPUSH", "q:rjobs", "1", "2"]).await; | ||||
|     let br_nonblock = send_cmd(&mut a, &["BRPOP", "q:rjobs", "0"]).await; | ||||
|     // Should pop the rightmost element "2" | ||||
|     assert_contains(&br_nonblock, "q:rjobs", "BRPOP returns key"); | ||||
|     assert_contains(&br_nonblock, "2", "BRPOP returns rightmost element"); | ||||
|  | ||||
|     // Now test blocking BRPOP: start blocked client, then RPUSH from another client | ||||
|     let c1 = connect(port).await; | ||||
|     let mut c2 = connect(port).await; | ||||
|  | ||||
|     // Start BRPOP on c1 | ||||
|     let brpop_task = tokio::spawn(async move { | ||||
|         let mut c1_local = c1; | ||||
|         send_cmd(&mut c1_local, &["BRPOP", "q:blockr", "5"]).await | ||||
|     }); | ||||
|  | ||||
|     // Give it time to register waiter | ||||
|     sleep(Duration::from_millis(150)).await; | ||||
|  | ||||
|     // Push from right to wake BRPOP | ||||
|     let _ = send_cmd(&mut c2, &["RPUSH", "q:blockr", "X"]).await; | ||||
|  | ||||
|     // Await BRPOP result | ||||
|     let brpop_res = brpop_task.await.expect("BRPOP task join"); | ||||
|     assert_contains(&brpop_res, "q:blockr", "BRPOP returned key"); | ||||
|     assert_contains(&brpop_res, "X", "BRPOP returned element"); | ||||
| } | ||||
| #[tokio::test] | ||||
| async fn test_13_dbsize() { | ||||
|     let (server, port) = start_test_server("dbsize").await; | ||||
|     spawn_listener(server, port).await; | ||||
|     sleep(Duration::from_millis(150)).await; | ||||
|  | ||||
|     let mut s = connect(port).await; | ||||
|  | ||||
|     // Initially empty | ||||
|     let n0 = send_cmd(&mut s, &["DBSIZE"]).await; | ||||
|     assert_contains(&n0, "0", "DBSIZE initial should be 0"); | ||||
|  | ||||
|     // Add a string, a hash, and a list -> dbsize = 3 | ||||
|     let _ = send_cmd(&mut s, &["SET", "s", "v"]).await; | ||||
|     let _ = send_cmd(&mut s, &["HSET", "h", "f", "v"]).await; | ||||
|     let _ = send_cmd(&mut s, &["LPUSH", "l", "a", "b"]).await; | ||||
|  | ||||
|     let n3 = send_cmd(&mut s, &["DBSIZE"]).await; | ||||
|     assert_contains(&n3, "3", "DBSIZE after adding s,h,l should be 3"); | ||||
|  | ||||
|     // Expire the string and wait, dbsize should drop to 2 | ||||
|     let _ = send_cmd(&mut s, &["PEXPIRE", "s", "400"]).await; | ||||
|     sleep(Duration::from_millis(500)).await; | ||||
|  | ||||
|     let n2 = send_cmd(&mut s, &["DBSIZE"]).await; | ||||
|     assert_contains(&n2, "2", "DBSIZE after string expiry should be 2"); | ||||
|  | ||||
|     // Delete remaining keys and confirm 0 | ||||
|     let _ = send_cmd(&mut s, &["DEL", "h"]).await; | ||||
|     let _ = send_cmd(&mut s, &["DEL", "l"]).await; | ||||
|  | ||||
|     let n_final = send_cmd(&mut s, &["DBSIZE"]).await; | ||||
|     assert_contains(&n_final, "0", "DBSIZE after deleting all keys should be 0"); | ||||
| } | ||||
| #[tokio::test] | ||||
| async fn test_14_expireat_pexpireat() { | ||||
|     use std::time::{SystemTime, UNIX_EPOCH}; | ||||
|  | ||||
|     let (server, port) = start_test_server("expireat_suite").await; | ||||
|     spawn_listener(server, port).await; | ||||
|     sleep(Duration::from_millis(150)).await; | ||||
|  | ||||
|     let mut s = connect(port).await; | ||||
|  | ||||
|     // EXPIREAT: seconds since epoch | ||||
|     let now_secs = SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_secs() as i64; | ||||
|     let _ = send_cmd(&mut s, &["SET", "exp:at:s", "v"]).await; | ||||
|     let exat = send_cmd(&mut s, &["EXPIREAT", "exp:at:s", &format!("{}", now_secs + 1)]).await; | ||||
|     assert_contains(&exat, "1", "EXPIREAT exp:at:s now+1s -> 1 (applied)"); | ||||
|     let ttl1 = send_cmd(&mut s, &["TTL", "exp:at:s"]).await; | ||||
|     assert!( | ||||
|         ttl1.contains("1") || ttl1.contains("0"), | ||||
|         "TTL exp:at:s should be 1 or 0 shortly after EXPIREAT, got: {}", | ||||
|         ttl1 | ||||
|     ); | ||||
|     sleep(Duration::from_millis(1200)).await; | ||||
|     let exists_after_exat = send_cmd(&mut s, &["EXISTS", "exp:at:s"]).await; | ||||
|     assert_contains(&exists_after_exat, "0", "EXISTS exp:at:s after EXPIREAT expiry -> 0"); | ||||
|  | ||||
|     // PEXPIREAT: milliseconds since epoch | ||||
|     let now_ms = SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_millis() as i64; | ||||
|     let _ = send_cmd(&mut s, &["SET", "exp:at:ms", "v"]).await; | ||||
|     let pexat = send_cmd(&mut s, &["PEXPIREAT", "exp:at:ms", &format!("{}", now_ms + 450)]).await; | ||||
|     assert_contains(&pexat, "1", "PEXPIREAT exp:at:ms now+450ms -> 1 (applied)"); | ||||
|     let ttl2 = send_cmd(&mut s, &["TTL", "exp:at:ms"]).await; | ||||
|     assert!( | ||||
|         ttl2.contains("0") || ttl2.contains("1"), | ||||
|         "TTL exp:at:ms should be 0..1 soon after PEXPIREAT, got: {}", | ||||
|         ttl2 | ||||
|     ); | ||||
|     sleep(Duration::from_millis(600)).await; | ||||
|     let exists_after_pexat = send_cmd(&mut s, &["EXISTS", "exp:at:ms"]).await; | ||||
|     assert_contains(&exists_after_pexat, "0", "EXISTS exp:at:ms after PEXPIREAT expiry -> 0"); | ||||
| } | ||||
		Reference in New Issue
	
	Block a user