Files
herodb/src/cmd.rs
2025-08-23 04:57:47 +02:00

1516 lines
68 KiB
Rust

use crate::{error::DBError, protocol::Protocol, server::Server};
use tokio::time::{timeout, Duration};
use futures::future::select_all;
/// A parsed client command together with its already-validated arguments.
///
/// Values are produced by `Cmd::from` and executed by `Cmd::run`.
#[derive(Debug, Clone)]
pub enum Cmd {
// --- Connection / string commands ---
Ping,
Echo(String),
Select(u64), // Changed from u16 to u64
Get(String),
Set(String, String),
// key, value, expiry — presumably milliseconds; not constructed by the parser in this file
SetPx(String, String, u128),
// key, value, expiry as given in the SETEX argument
SetEx(String, String, u128),
// Advanced SET with options: (key, value, ex_ms, nx, xx, get)
SetOpts(String, String, Option<u128>, bool, bool, bool),
MGet(Vec<String>),
MSet(Vec<(String, String)>),
// Only produced for `KEYS *`
Keys,
DbSize,
ConfigGet(String),
// Optional section name
Info(Option<String>),
// Single-key DEL; multi-key form is `DelMulti`
Del(String),
Type(String),
Incr(String),
// --- Transaction control ---
Multi,
Exec,
Discard,
// Hash commands
HSet(String, Vec<(String, String)>),
HGet(String, String),
HGetAll(String),
HDel(String, Vec<String>),
HExists(String, String),
HKeys(String),
HVals(String),
HLen(String),
HMGet(String, Vec<String>),
HSetNx(String, String, String),
HIncrBy(String, String, i64),
HIncrByFloat(String, String, f64),
HScan(String, u64, Option<String>, Option<u64>), // key, cursor, pattern, count
Scan(u64, Option<String>, Option<u64>), // cursor, pattern, count
// --- Key expiry ---
Ttl(String),
Expire(String, i64),
PExpire(String, i64),
ExpireAt(String, i64),
PExpireAt(String, i64),
Persist(String),
// Single-key EXISTS; multi-key form is `ExistsMulti`
Exists(String),
ExistsMulti(Vec<String>),
DelMulti(Vec<String>),
Quit,
// Any CLIENT subcommand other than SETNAME / GETNAME (arguments after "client")
Client(Vec<String>),
ClientSetName(String),
ClientGetName,
// Arguments after "command"
Command(Vec<String>),
// List commands
LPush(String, Vec<String>),
RPush(String, Vec<String>),
// key, optional count
LPop(String, Option<u64>),
RPop(String, Option<u64>),
// keys, timeout (last argument, parsed as a float)
BLPop(Vec<String>, f64),
BRPop(Vec<String>, f64),
LLen(String),
// key, count, element
LRem(String, i64, String),
LTrim(String, i64, i64),
LIndex(String, i64),
LRange(String, i64, i64),
FlushDb,
// Fallback for unrecognised command names; carries the name as received
Unknow(String),
// AGE (rage) commands — stateless
AgeGenEnc,
AgeGenSign,
AgeEncrypt(String, String), // recipient, message
AgeDecrypt(String, String), // identity, ciphertext_b64
AgeSign(String, String), // signing_secret, message
AgeVerify(String, String, String), // verify_pub, message, signature_b64
// NEW: persistent named-key commands
AgeKeygen(String), // name
AgeSignKeygen(String), // name
AgeEncryptName(String, String), // name, message
AgeDecryptName(String, String), // name, ciphertext_b64
AgeSignName(String, String), // name, message
AgeVerifyName(String, String, String), // name, message, signature_b64
AgeList,
}
impl Cmd {
/// Parses one command from the front of `s`.
///
/// `s` is first decoded with `Protocol::from`; the resulting frame must be an
/// array whose first element is the (case-insensitive) command name.
///
/// Returns the parsed command, the decoded protocol frame, and the part of `s`
/// that `Protocol::from` did not consume.
///
/// # Errors
///
/// Returns `DBError` when the frame is not an array, the array is empty, the
/// argument count is wrong for the command, or a numeric argument cannot be
/// parsed. Malformed input never panics: `ECHO`/`GET` without an argument, a
/// non-integer `SETEX` TTL, and an overflowing `SET ... EX` value are all
/// reported as errors. Unrecognised command names are NOT an error; they are
/// returned as `Cmd::Unknow`.
pub fn from(s: &str) -> Result<(Self, Protocol, &str), DBError> {
    // Standard arity error; `name` is spliced into the message verbatim.
    fn wrong_args(name: &str) -> DBError {
        DBError(format!("wrong number of arguments for {} command", name))
    }

    // Error used for recognised commands invoked in an unsupported shape.
    fn unsupported(cmd: &[String]) -> DBError {
        DBError(format!("unsupported cmd {:?}", cmd))
    }

    fn syntax_error() -> DBError {
        DBError("ERR syntax error".to_string())
    }

    fn not_an_integer() -> DBError {
        DBError("ERR value is not an integer or out of range".to_string())
    }

    // Parses an integer argument, mapping any failure to the standard message.
    fn parse_int<T: std::str::FromStr>(raw: &str) -> Result<T, DBError> {
        raw.parse::<T>().map_err(|_| not_an_integer())
    }

    // Groups `args` into consecutive (field, value) pairs. Callers validate
    // that `args` has even length beforehand.
    fn pairs_of(args: &[String]) -> Vec<(String, String)> {
        args.chunks_exact(2)
            .map(|kv| (kv[0].clone(), kv[1].clone()))
            .collect()
    }

    // Parses the trailing `[MATCH pattern] [COUNT n]` options of SCAN / HSCAN.
    // Every option takes exactly one value, so the slice is walked in pairs.
    fn parse_scan_opts(opts: &[String]) -> Result<(Option<String>, Option<u64>), DBError> {
        let mut pattern = None;
        let mut count = None;
        for pair in opts.chunks(2) {
            let value = pair.get(1).ok_or_else(syntax_error);
            match pair[0].to_lowercase().as_str() {
                "match" => pattern = Some(value?.clone()),
                "count" => count = Some(parse_int::<u64>(value?)?),
                _ => return Err(syntax_error()),
            }
        }
        Ok((pattern, count))
    }

    // Splits `BLPOP/BRPOP key [key ...] timeout` into keys and timeout.
    // Callers guarantee `cmd.len() >= 3`.
    fn parse_blocking_pop(cmd: &[String]) -> Result<(Vec<String>, f64), DBError> {
        let last = cmd.len() - 1;
        // keys are all but the last argument
        let keys = cmd[1..last].to_vec();
        let timeout_f = cmd[last]
            .parse::<f64>()
            .map_err(|_| DBError("ERR timeout is not a number".to_string()))?;
        Ok((keys, timeout_f))
    }

    // Parses `<CMD> key [count]` for LPOP / RPOP.
    fn parse_pop(cmd: &[String], name: &str) -> Result<(String, Option<u64>), DBError> {
        if cmd.len() < 2 || cmd.len() > 3 {
            return Err(wrong_args(name));
        }
        let count = match cmd.get(2) {
            Some(raw) => Some(parse_int::<u64>(raw)?),
            None => None,
        };
        Ok((cmd[1].clone(), count))
    }

    // Parses `<CMD> key <integer>` for the EXPIRE family.
    fn parse_key_and_i64(cmd: &[String], name: &str) -> Result<(String, i64), DBError> {
        if cmd.len() != 3 {
            return Err(wrong_args(name));
        }
        Ok((cmd[1].clone(), parse_int::<i64>(&cmd[2])?))
    }

    // Checks the total token count of an AGE subcommand; `usage` is the error text.
    fn age_arity(cmd: &[String], expected: usize, usage: &str) -> Result<(), DBError> {
        if cmd.len() == expected {
            Ok(())
        } else {
            Err(DBError(usage.to_string()))
        }
    }

    let (protocol, remaining) = Protocol::from(s)?;
    let cmd: Vec<String> = match protocol.clone() {
        Protocol::Array(p) => p.into_iter().map(|x| x.decode()).collect(),
        _ => {
            return Err(DBError(format!(
                "fail to parse as cmd for {:?}",
                protocol
            )));
        }
    };
    if cmd.is_empty() {
        return Err(DBError("cmd length is 0".to_string()));
    }

    let parsed = match cmd[0].to_lowercase().as_str() {
        "select" => {
            if cmd.len() != 2 {
                return Err(DBError("wrong number of arguments for SELECT".to_string()));
            }
            let idx = cmd[1]
                .parse::<u64>()
                .map_err(|_| DBError("ERR DB index is not an integer".to_string()))?;
            Cmd::Select(idx)
        }
        "echo" => {
            // Previously indexed cmd[1] unconditionally and panicked on a bare ECHO.
            if cmd.len() != 2 {
                return Err(wrong_args("ECHO"));
            }
            Cmd::Echo(cmd[1].clone())
        }
        "ping" => Cmd::Ping,
        "get" => {
            // Previously indexed cmd[1] unconditionally and panicked on a bare GET.
            if cmd.len() != 2 {
                return Err(wrong_args("GET"));
            }
            Cmd::Get(cmd[1].clone())
        }
        "set" => {
            if cmd.len() < 3 {
                return Err(DBError("wrong number of arguments for SET".to_string()));
            }
            let key = cmd[1].clone();
            let val = cmd[2].clone();
            // Parse optional flags: EX sec | PX ms | NX | XX | GET
            let mut ex_ms: Option<u128> = None;
            let mut nx = false;
            let mut xx = false;
            let mut getflag = false;
            let mut i = 3;
            while i < cmd.len() {
                match cmd[i].to_lowercase().as_str() {
                    "ex" => {
                        let raw = cmd.get(i + 1).ok_or_else(syntax_error)?;
                        let secs: u128 = parse_int(raw)?;
                        // Seconds -> milliseconds; reject values that would overflow.
                        ex_ms = Some(secs.checked_mul(1000).ok_or_else(not_an_integer)?);
                        i += 2;
                    }
                    "px" => {
                        let raw = cmd.get(i + 1).ok_or_else(syntax_error)?;
                        ex_ms = Some(parse_int::<u128>(raw)?);
                        i += 2;
                    }
                    "nx" => {
                        nx = true;
                        i += 1;
                    }
                    "xx" => {
                        xx = true;
                        i += 1;
                    }
                    "get" => {
                        getflag = true;
                        i += 1;
                    }
                    _ => return Err(unsupported(&cmd)),
                }
            }
            // If no options, keep legacy behavior
            if ex_ms.is_none() && !nx && !xx && !getflag {
                Cmd::Set(key, val)
            } else {
                Cmd::SetOpts(key, val, ex_ms, nx, xx, getflag)
            }
        }
        "setex" => {
            if cmd.len() != 4 {
                return Err(wrong_args("SETEX"));
            }
            // SETEX key seconds value — a bad TTL is an error, not a panic.
            let ttl: u128 = parse_int(&cmd[2])?;
            Cmd::SetEx(cmd[1].clone(), cmd[3].clone(), ttl)
        }
        "mget" => {
            if cmd.len() < 2 {
                return Err(wrong_args("MGET"));
            }
            Cmd::MGet(cmd[1..].to_vec())
        }
        "mset" => {
            if cmd.len() < 3 || ((cmd.len() - 1) % 2 != 0) {
                return Err(wrong_args("MSET"));
            }
            Cmd::MSet(pairs_of(&cmd[1..]))
        }
        "config" => {
            if cmd.len() != 3 || cmd[1].to_lowercase() != "get" {
                return Err(unsupported(&cmd));
            }
            Cmd::ConfigGet(cmd[2].clone())
        }
        "keys" => {
            if cmd.len() != 2 || cmd[1] != "*" {
                return Err(unsupported(&cmd));
            }
            Cmd::Keys
        }
        "dbsize" => {
            if cmd.len() != 1 {
                return Err(wrong_args("DBSIZE"));
            }
            Cmd::DbSize
        }
        "info" => {
            let section = if cmd.len() == 2 {
                Some(cmd[1].clone())
            } else {
                None
            };
            Cmd::Info(section)
        }
        "del" => {
            if cmd.len() < 2 {
                return Err(wrong_args("DEL"));
            }
            if cmd.len() == 2 {
                Cmd::Del(cmd[1].clone())
            } else {
                Cmd::DelMulti(cmd[1..].to_vec())
            }
        }
        "type" => {
            if cmd.len() != 2 {
                return Err(unsupported(&cmd));
            }
            Cmd::Type(cmd[1].clone())
        }
        "incr" => {
            if cmd.len() != 2 {
                return Err(unsupported(&cmd));
            }
            Cmd::Incr(cmd[1].clone())
        }
        "multi" => {
            if cmd.len() != 1 {
                return Err(unsupported(&cmd));
            }
            Cmd::Multi
        }
        "exec" => {
            if cmd.len() != 1 {
                return Err(unsupported(&cmd));
            }
            Cmd::Exec
        }
        "discard" => Cmd::Discard,
        // Hash commands
        "hset" => {
            if cmd.len() < 4 || (cmd.len() - 2) % 2 != 0 {
                return Err(wrong_args("HSET"));
            }
            Cmd::HSet(cmd[1].clone(), pairs_of(&cmd[2..]))
        }
        "hget" => {
            if cmd.len() != 3 {
                return Err(wrong_args("HGET"));
            }
            Cmd::HGet(cmd[1].clone(), cmd[2].clone())
        }
        "hgetall" => {
            if cmd.len() != 2 {
                return Err(wrong_args("HGETALL"));
            }
            Cmd::HGetAll(cmd[1].clone())
        }
        "hdel" => {
            if cmd.len() < 3 {
                return Err(wrong_args("HDEL"));
            }
            Cmd::HDel(cmd[1].clone(), cmd[2..].to_vec())
        }
        "hexists" => {
            if cmd.len() != 3 {
                return Err(wrong_args("HEXISTS"));
            }
            Cmd::HExists(cmd[1].clone(), cmd[2].clone())
        }
        "hkeys" => {
            if cmd.len() != 2 {
                return Err(wrong_args("HKEYS"));
            }
            Cmd::HKeys(cmd[1].clone())
        }
        "hvals" => {
            if cmd.len() != 2 {
                return Err(wrong_args("HVALS"));
            }
            Cmd::HVals(cmd[1].clone())
        }
        "hlen" => {
            if cmd.len() != 2 {
                return Err(wrong_args("HLEN"));
            }
            Cmd::HLen(cmd[1].clone())
        }
        "hmget" => {
            if cmd.len() < 3 {
                return Err(wrong_args("HMGET"));
            }
            Cmd::HMGet(cmd[1].clone(), cmd[2..].to_vec())
        }
        "hsetnx" => {
            if cmd.len() != 4 {
                return Err(wrong_args("HSETNX"));
            }
            Cmd::HSetNx(cmd[1].clone(), cmd[2].clone(), cmd[3].clone())
        }
        "hincrby" => {
            if cmd.len() != 4 {
                return Err(wrong_args("HINCRBY"));
            }
            let delta: i64 = parse_int(&cmd[3])?;
            Cmd::HIncrBy(cmd[1].clone(), cmd[2].clone(), delta)
        }
        "hincrbyfloat" => {
            if cmd.len() != 4 {
                return Err(wrong_args("HINCRBYFLOAT"));
            }
            let delta = cmd[3]
                .parse::<f64>()
                .map_err(|_| DBError("ERR value is not a valid float".to_string()))?;
            Cmd::HIncrByFloat(cmd[1].clone(), cmd[2].clone(), delta)
        }
        "hscan" => {
            if cmd.len() < 3 {
                return Err(wrong_args("HSCAN"));
            }
            let key = cmd[1].clone();
            let cursor = cmd[2]
                .parse::<u64>()
                .map_err(|_| DBError("ERR invalid cursor".to_string()))?;
            let (pattern, count) = parse_scan_opts(&cmd[3..])?;
            Cmd::HScan(key, cursor, pattern, count)
        }
        "scan" => {
            if cmd.len() < 2 {
                return Err(wrong_args("SCAN"));
            }
            let cursor = cmd[1]
                .parse::<u64>()
                .map_err(|_| DBError("ERR invalid cursor".to_string()))?;
            let (pattern, count) = parse_scan_opts(&cmd[2..])?;
            Cmd::Scan(cursor, pattern, count)
        }
        "ttl" => {
            if cmd.len() != 2 {
                return Err(wrong_args("TTL"));
            }
            Cmd::Ttl(cmd[1].clone())
        }
        "expire" => {
            let (key, secs) = parse_key_and_i64(&cmd, "EXPIRE")?;
            Cmd::Expire(key, secs)
        }
        "pexpire" => {
            let (key, ms) = parse_key_and_i64(&cmd, "PEXPIRE")?;
            Cmd::PExpire(key, ms)
        }
        "expireat" => {
            let (key, ts) = parse_key_and_i64(&cmd, "EXPIREAT")?;
            Cmd::ExpireAt(key, ts)
        }
        "pexpireat" => {
            let (key, ts_ms) = parse_key_and_i64(&cmd, "PEXPIREAT")?;
            Cmd::PExpireAt(key, ts_ms)
        }
        "persist" => {
            if cmd.len() != 2 {
                return Err(wrong_args("PERSIST"));
            }
            Cmd::Persist(cmd[1].clone())
        }
        "exists" => {
            if cmd.len() < 2 {
                return Err(wrong_args("EXISTS"));
            }
            if cmd.len() == 2 {
                Cmd::Exists(cmd[1].clone())
            } else {
                Cmd::ExistsMulti(cmd[1..].to_vec())
            }
        }
        "quit" => {
            if cmd.len() != 1 {
                return Err(wrong_args("QUIT"));
            }
            Cmd::Quit
        }
        "client" => {
            if cmd.len() > 1 {
                match cmd[1].to_lowercase().as_str() {
                    "setname" => {
                        if cmd.len() != 3 {
                            return Err(wrong_args("'client setname'"));
                        }
                        Cmd::ClientSetName(cmd[2].clone())
                    }
                    "getname" => {
                        if cmd.len() != 2 {
                            return Err(wrong_args("'client getname'"));
                        }
                        Cmd::ClientGetName
                    }
                    _ => Cmd::Client(cmd[1..].to_vec()),
                }
            } else {
                Cmd::Client(vec![])
            }
        }
        "command" => Cmd::Command(cmd[1..].to_vec()),
        // List commands
        "lpush" => {
            if cmd.len() < 3 {
                return Err(wrong_args("LPUSH"));
            }
            Cmd::LPush(cmd[1].clone(), cmd[2..].to_vec())
        }
        "rpush" => {
            if cmd.len() < 3 {
                return Err(wrong_args("RPUSH"));
            }
            Cmd::RPush(cmd[1].clone(), cmd[2..].to_vec())
        }
        "lpop" => {
            let (key, count) = parse_pop(&cmd, "LPOP")?;
            Cmd::LPop(key, count)
        }
        "rpop" => {
            let (key, count) = parse_pop(&cmd, "RPOP")?;
            Cmd::RPop(key, count)
        }
        "blpop" => {
            if cmd.len() < 3 {
                return Err(wrong_args("BLPOP"));
            }
            let (keys, timeout_f) = parse_blocking_pop(&cmd)?;
            Cmd::BLPop(keys, timeout_f)
        }
        "brpop" => {
            if cmd.len() < 3 {
                return Err(wrong_args("BRPOP"));
            }
            let (keys, timeout_f) = parse_blocking_pop(&cmd)?;
            Cmd::BRPop(keys, timeout_f)
        }
        "llen" => {
            if cmd.len() != 2 {
                return Err(wrong_args("LLEN"));
            }
            Cmd::LLen(cmd[1].clone())
        }
        "lrem" => {
            if cmd.len() != 4 {
                return Err(wrong_args("LREM"));
            }
            let count: i64 = parse_int(&cmd[2])?;
            Cmd::LRem(cmd[1].clone(), count, cmd[3].clone())
        }
        "ltrim" => {
            if cmd.len() != 4 {
                return Err(wrong_args("LTRIM"));
            }
            let start: i64 = parse_int(&cmd[2])?;
            let stop: i64 = parse_int(&cmd[3])?;
            Cmd::LTrim(cmd[1].clone(), start, stop)
        }
        "lindex" => {
            if cmd.len() != 3 {
                return Err(wrong_args("LINDEX"));
            }
            let index: i64 = parse_int(&cmd[2])?;
            Cmd::LIndex(cmd[1].clone(), index)
        }
        "lrange" => {
            if cmd.len() != 4 {
                return Err(wrong_args("LRANGE"));
            }
            let start: i64 = parse_int(&cmd[2])?;
            let stop: i64 = parse_int(&cmd[3])?;
            Cmd::LRange(cmd[1].clone(), start, stop)
        }
        "flushdb" => {
            if cmd.len() != 1 {
                return Err(wrong_args("FLUSHDB"));
            }
            Cmd::FlushDb
        }
        "age" => {
            if cmd.len() < 2 {
                return Err(DBError("wrong number of arguments for AGE".to_string()));
            }
            match cmd[1].to_lowercase().as_str() {
                // stateless
                "genenc" => {
                    age_arity(&cmd, 2, "AGE GENENC takes no args")?;
                    Cmd::AgeGenEnc
                }
                "gensign" => {
                    age_arity(&cmd, 2, "AGE GENSIGN takes no args")?;
                    Cmd::AgeGenSign
                }
                "encrypt" => {
                    age_arity(&cmd, 4, "AGE ENCRYPT <recipient> <message>")?;
                    Cmd::AgeEncrypt(cmd[2].clone(), cmd[3].clone())
                }
                "decrypt" => {
                    age_arity(&cmd, 4, "AGE DECRYPT <identity> <ciphertext_b64>")?;
                    Cmd::AgeDecrypt(cmd[2].clone(), cmd[3].clone())
                }
                "sign" => {
                    age_arity(&cmd, 4, "AGE SIGN <signing_secret> <message>")?;
                    Cmd::AgeSign(cmd[2].clone(), cmd[3].clone())
                }
                "verify" => {
                    age_arity(&cmd, 5, "AGE VERIFY <verify_pub> <message> <signature_b64>")?;
                    Cmd::AgeVerify(cmd[2].clone(), cmd[3].clone(), cmd[4].clone())
                }
                // persistent names
                "keygen" => {
                    age_arity(&cmd, 3, "AGE KEYGEN <name>")?;
                    Cmd::AgeKeygen(cmd[2].clone())
                }
                "signkeygen" => {
                    age_arity(&cmd, 3, "AGE SIGNKEYGEN <name>")?;
                    Cmd::AgeSignKeygen(cmd[2].clone())
                }
                "encryptname" => {
                    age_arity(&cmd, 4, "AGE ENCRYPTNAME <name> <message>")?;
                    Cmd::AgeEncryptName(cmd[2].clone(), cmd[3].clone())
                }
                "decryptname" => {
                    age_arity(&cmd, 4, "AGE DECRYPTNAME <name> <ciphertext_b64>")?;
                    Cmd::AgeDecryptName(cmd[2].clone(), cmd[3].clone())
                }
                "signname" => {
                    age_arity(&cmd, 4, "AGE SIGNNAME <name> <message>")?;
                    Cmd::AgeSignName(cmd[2].clone(), cmd[3].clone())
                }
                "verifyname" => {
                    age_arity(&cmd, 5, "AGE VERIFYNAME <name> <message> <signature_b64>")?;
                    Cmd::AgeVerifyName(cmd[2].clone(), cmd[3].clone(), cmd[4].clone())
                }
                "list" => {
                    age_arity(&cmd, 2, "AGE LIST")?;
                    Cmd::AgeList
                }
                _ => return Err(DBError(format!("unsupported AGE subcommand {:?}", cmd))),
            }
        }
        _ => Cmd::Unknow(cmd[0].clone()),
    };

    Ok((parsed, protocol, remaining))
}
pub async fn run(self, server: &mut Server) -> Result<Protocol, DBError> {
// Handle queued commands for transactions
if server.queued_cmd.is_some()
&& !matches!(self, Cmd::Exec)
&& !matches!(self, Cmd::Multi)
&& !matches!(self, Cmd::Discard)
{
let protocol = self.clone().to_protocol();
server.queued_cmd.as_mut().unwrap().push((self, protocol));
return Ok(Protocol::SimpleString("QUEUED".to_string()));
}
match self {
Cmd::Select(db) => select_cmd(server, db).await,
Cmd::Ping => Ok(Protocol::SimpleString("PONG".to_string())),
Cmd::Echo(s) => Ok(Protocol::BulkString(s)),
Cmd::Get(k) => get_cmd(server, &k).await,
Cmd::Set(k, v) => set_cmd(server, &k, &v).await,
Cmd::SetPx(k, v, x) => set_px_cmd(server, &k, &v, &x).await,
Cmd::SetEx(k, v, x) => set_ex_cmd(server, &k, &v, &x).await,
Cmd::SetOpts(k, v, ex_ms, nx, xx, getflag) => set_with_opts_cmd(server, &k, &v, ex_ms, nx, xx, getflag).await,
Cmd::MGet(keys) => mget_cmd(server, &keys).await,
Cmd::MSet(pairs) => mset_cmd(server, &pairs).await,
Cmd::Del(k) => del_cmd(server, &k).await,
Cmd::DelMulti(keys) => del_multi_cmd(server, &keys).await,
Cmd::ConfigGet(name) => config_get_cmd(&name, server),
Cmd::Keys => keys_cmd(server).await,
Cmd::DbSize => dbsize_cmd(server).await,
Cmd::Info(section) => info_cmd(server, &section).await,
Cmd::Type(k) => type_cmd(server, &k).await,
Cmd::Incr(key) => incr_cmd(server, &key).await,
Cmd::Multi => {
server.queued_cmd = Some(Vec::<(Cmd, Protocol)>::new());
Ok(Protocol::SimpleString("OK".to_string()))
}
Cmd::Exec => exec_cmd(server).await,
Cmd::Discard => {
if server.queued_cmd.is_some() {
server.queued_cmd = None;
Ok(Protocol::SimpleString("OK".to_string()))
} else {
Ok(Protocol::err("ERR DISCARD without MULTI"))
}
}
// Hash commands
Cmd::HSet(key, pairs) => hset_cmd(server, &key, &pairs).await,
Cmd::HGet(key, field) => hget_cmd(server, &key, &field).await,
Cmd::HGetAll(key) => hgetall_cmd(server, &key).await,
Cmd::HDel(key, fields) => hdel_cmd(server, &key, &fields).await,
Cmd::HExists(key, field) => hexists_cmd(server, &key, &field).await,
Cmd::HKeys(key) => hkeys_cmd(server, &key).await,
Cmd::HVals(key) => hvals_cmd(server, &key).await,
Cmd::HLen(key) => hlen_cmd(server, &key).await,
Cmd::HMGet(key, fields) => hmget_cmd(server, &key, &fields).await,
Cmd::HSetNx(key, field, value) => hsetnx_cmd(server, &key, &field, &value).await,
Cmd::HIncrBy(key, field, delta) => hincrby_cmd(server, &key, &field, delta).await,
Cmd::HIncrByFloat(key, field, delta) => hincrbyfloat_cmd(server, &key, &field, delta).await,
Cmd::HScan(key, cursor, pattern, count) => hscan_cmd(server, &key, &cursor, pattern.as_deref(), &count).await,
Cmd::Scan(cursor, pattern, count) => scan_cmd(server, &cursor, pattern.as_deref(), &count).await,
Cmd::Ttl(key) => ttl_cmd(server, &key).await,
Cmd::Expire(key, secs) => expire_cmd(server, &key, secs).await,
Cmd::PExpire(key, ms) => pexpire_cmd(server, &key, ms).await,
Cmd::ExpireAt(key, ts_secs) => expireat_cmd(server, &key, ts_secs).await,
Cmd::PExpireAt(key, ts_ms) => pexpireat_cmd(server, &key, ts_ms).await,
Cmd::Persist(key) => persist_cmd(server, &key).await,
Cmd::Exists(key) => exists_cmd(server, &key).await,
Cmd::ExistsMulti(keys) => exists_multi_cmd(server, &keys).await,
Cmd::Quit => Ok(Protocol::SimpleString("OK".to_string())),
Cmd::Client(_) => Ok(Protocol::SimpleString("OK".to_string())),
Cmd::ClientSetName(name) => client_setname_cmd(server, &name).await,
Cmd::ClientGetName => client_getname_cmd(server).await,
Cmd::Command(args) => command_cmd(&args),
// List commands
Cmd::LPush(key, elements) => lpush_cmd(server, &key, &elements).await,
Cmd::RPush(key, elements) => rpush_cmd(server, &key, &elements).await,
Cmd::LPop(key, count) => lpop_cmd(server, &key, &count).await,
Cmd::RPop(key, count) => rpop_cmd(server, &key, &count).await,
Cmd::BLPop(keys, timeout) => blpop_cmd(server, &keys, timeout).await,
Cmd::BRPop(keys, timeout) => brpop_cmd(server, &keys, timeout).await,
Cmd::LLen(key) => llen_cmd(server, &key).await,
Cmd::LRem(key, count, element) => lrem_cmd(server, &key, count, &element).await,
Cmd::LTrim(key, start, stop) => ltrim_cmd(server, &key, start, stop).await,
Cmd::LIndex(key, index) => lindex_cmd(server, &key, index).await,
Cmd::LRange(key, start, stop) => lrange_cmd(server, &key, start, stop).await,
Cmd::FlushDb => flushdb_cmd(server).await,
// AGE (rage): stateless
Cmd::AgeGenEnc => Ok(crate::age::cmd_age_genenc().await),
Cmd::AgeGenSign => Ok(crate::age::cmd_age_gensign().await),
Cmd::AgeEncrypt(recipient, message) => Ok(crate::age::cmd_age_encrypt(&recipient, &message).await),
Cmd::AgeDecrypt(identity, ct_b64) => Ok(crate::age::cmd_age_decrypt(&identity, &ct_b64).await),
Cmd::AgeSign(secret, message) => Ok(crate::age::cmd_age_sign(&secret, &message).await),
Cmd::AgeVerify(vpub, msg, sig_b64) => Ok(crate::age::cmd_age_verify(&vpub, &msg, &sig_b64).await),
// AGE (rage): persistent named keys
Cmd::AgeKeygen(name) => Ok(crate::age::cmd_age_keygen(server, &name).await),
Cmd::AgeSignKeygen(name) => Ok(crate::age::cmd_age_signkeygen(server, &name).await),
Cmd::AgeEncryptName(name, message) => Ok(crate::age::cmd_age_encrypt_name(server, &name, &message).await),
Cmd::AgeDecryptName(name, ct_b64) => Ok(crate::age::cmd_age_decrypt_name(server, &name, &ct_b64).await),
Cmd::AgeSignName(name, message) => Ok(crate::age::cmd_age_sign_name(server, &name, &message).await),
Cmd::AgeVerifyName(name, message, sig_b64) => Ok(crate::age::cmd_age_verify_name(server, &name, &message, &sig_b64).await),
Cmd::AgeList => Ok(crate::age::cmd_age_list(server).await),
Cmd::Unknow(s) => Ok(Protocol::err(&format!("ERR unknown command `{}`", s))),
}
}
/// Re-encodes this command as a protocol frame.
///
/// Only SELECT, PING, ECHO, GET and SET are rendered as an array of bulk
/// strings; every other command yields the placeholder simple string `"..."`.
pub fn to_protocol(self) -> Protocol {
    // Wraps each part in a bulk string and the whole list in an array.
    fn bulk_array(parts: Vec<String>) -> Protocol {
        Protocol::Array(parts.into_iter().map(Protocol::BulkString).collect())
    }

    match self {
        Cmd::Select(db) => bulk_array(vec!["select".to_string(), db.to_string()]),
        Cmd::Ping => bulk_array(vec!["ping".to_string()]),
        Cmd::Echo(message) => bulk_array(vec!["echo".to_string(), message]),
        Cmd::Get(key) => bulk_array(vec!["get".to_string(), key]),
        Cmd::Set(key, value) => bulk_array(vec!["set".to_string(), key, value]),
        _ => Protocol::SimpleString("...".to_string()),
    }
}
}
/// FLUSHDB: calls `flushdb` on the currently selected storage.
///
/// Replies `OK` on success; a storage failure is reported as an error reply.
async fn flushdb_cmd(server: &mut Server) -> Result<Protocol, DBError> {
    let reply = match server.current_storage()?.flushdb() {
        Err(e) => Protocol::err(&e.0),
        Ok(_) => Protocol::SimpleString("OK".to_string()),
    };
    Ok(reply)
}
/// SELECT: switches the connection to database `db`.
///
/// The new index is applied tentatively and validated by opening its storage
/// via `current_storage`. If that fails, the previously selected database is
/// restored so the connection is not left pointing at an inaccessible
/// database, and the failure is returned as an error reply.
async fn select_cmd(server: &mut Server, db: u64) -> Result<Protocol, DBError> {
    let previous_db = server.selected_db;
    server.selected_db = db;
    // Only the success/failure of opening the storage matters here; dropping
    // the storage value immediately also releases any borrow of `server`.
    let probe = server.current_storage().map(|_| ());
    match probe {
        Ok(()) => Ok(Protocol::SimpleString("OK".to_string())),
        Err(e) => {
            server.selected_db = previous_db;
            Ok(Protocol::err(&e.0))
        }
    }
}
/// LINDEX: fetches the element at `index` of the list stored at `key`.
///
/// Replies with a bulk string, `Null` when the storage reports no element,
/// or an error reply when the storage call fails.
async fn lindex_cmd(server: &Server, key: &str, index: i64) -> Result<Protocol, DBError> {
    let reply = match server.current_storage()?.lindex(key, index) {
        Err(e) => Protocol::err(&e.0),
        Ok(None) => Protocol::Null,
        Ok(Some(element)) => Protocol::BulkString(element),
    };
    Ok(reply)
}
/// LRANGE: returns the elements of `key` selected by `start`/`stop` as an
/// array of bulk strings, or an error reply when the storage call fails.
async fn lrange_cmd(server: &Server, key: &str, start: i64, stop: i64) -> Result<Protocol, DBError> {
    let reply = match server.current_storage()?.lrange(key, start, stop) {
        Err(e) => Protocol::err(&e.0),
        Ok(elements) => {
            let items = elements.into_iter().map(Protocol::BulkString).collect();
            Protocol::Array(items)
        }
    };
    Ok(reply)
}
/// LTRIM: asks the storage to trim the list at `key` to `start..=stop`.
///
/// Replies `OK` on success, or an error reply when the storage call fails.
async fn ltrim_cmd(server: &Server, key: &str, start: i64, stop: i64) -> Result<Protocol, DBError> {
    let reply = match server.current_storage()?.ltrim(key, start, stop) {
        Err(e) => Protocol::err(&e.0),
        Ok(_) => Protocol::SimpleString("OK".to_string()),
    };
    Ok(reply)
}
/// LREM: removes occurrences of `element` from the list at `key` according to
/// `count`, replying with the number removed rendered as a simple string.
async fn lrem_cmd(server: &Server, key: &str, count: i64, element: &str) -> Result<Protocol, DBError> {
    let reply = match server.current_storage()?.lrem(key, count, element) {
        Err(e) => Protocol::err(&e.0),
        Ok(removed) => Protocol::SimpleString(removed.to_string()),
    };
    Ok(reply)
}
/// LLEN: replies with the length of the list at `key`, rendered as a simple
/// string, or an error reply when the storage call fails.
async fn llen_cmd(server: &Server, key: &str) -> Result<Protocol, DBError> {
    let reply = match server.current_storage()?.llen(key) {
        Err(e) => Protocol::err(&e.0),
        Ok(length) => Protocol::SimpleString(length.to_string()),
    };
    Ok(reply)
}
/// LPOP: pops from the head of the list at `key`.
///
/// With an explicit `count` the reply is always an array (possibly empty).
/// Without one, a single element is popped and returned as a bulk string, or
/// `Null` when nothing was popped. Storage failures become an error reply.
async fn lpop_cmd(server: &Server, key: &str, count: &Option<u64>) -> Result<Protocol, DBError> {
    let popped = match server.current_storage()?.lpop(key, count.unwrap_or(1)) {
        Ok(items) => items,
        Err(e) => return Ok(Protocol::err(&e.0)),
    };
    let reply = if count.is_some() {
        Protocol::Array(popped.into_iter().map(Protocol::BulkString).collect())
    } else {
        match popped.into_iter().next() {
            Some(first) => Protocol::BulkString(first),
            None => Protocol::Null,
        }
    };
    Ok(reply)
}
/// RPOP: pops from the tail of the list at `key`.
///
/// With an explicit `count` the reply is always an array (possibly empty).
/// Without one, a single element is popped and returned as a bulk string, or
/// `Null` when nothing was popped. Storage failures become an error reply.
async fn rpop_cmd(server: &Server, key: &str, count: &Option<u64>) -> Result<Protocol, DBError> {
    let popped = match server.current_storage()?.rpop(key, count.unwrap_or(1)) {
        Ok(items) => items,
        Err(e) => return Ok(Protocol::err(&e.0)),
    };
    let reply = if count.is_some() {
        Protocol::Array(popped.into_iter().map(Protocol::BulkString).collect())
    } else {
        match popped.into_iter().next() {
            Some(first) => Protocol::BulkString(first),
            None => Protocol::Null,
        }
    };
    Ok(reply)
}
// BLPOP implementation
async fn blpop_cmd(server: &Server, keys: &[String], timeout_secs: f64) -> Result<Protocol, DBError> {
// Immediate, non-blocking attempt in key order
for k in keys {
let elems = server.current_storage()?.lpop(k, 1)?;
if !elems.is_empty() {
return Ok(Protocol::Array(vec![
Protocol::BulkString(k.clone()),
Protocol::BulkString(elems[0].clone()),
]));
}
}
// If timeout is zero, return immediately with Null
if timeout_secs <= 0.0 {
return Ok(Protocol::Null);
}
// Register waiters for each key
let db_index = server.selected_db;
let mut ids: Vec<u64> = Vec::with_capacity(keys.len());
let mut names: Vec<String> = Vec::with_capacity(keys.len());
let mut rxs: Vec<tokio::sync::oneshot::Receiver<(String, String)>> = Vec::with_capacity(keys.len());
for k in keys {
let (id, rx) = server.register_waiter(db_index, k, crate::server::PopSide::Left).await;
ids.push(id);
names.push(k.clone());
rxs.push(rx);
}
// Wait for the first delivery or timeout
let wait_fut = async move {
let mut futures_vec = rxs;
loop {
if futures_vec.is_empty() {
return None;
}
let (res, idx, remaining) = select_all(futures_vec).await;
match res {
Ok((k, elem)) => {
return Some((k, elem, idx, remaining));
}
Err(_canceled) => {
// That waiter was canceled; continue with the rest
futures_vec = remaining;
continue;
}
}
}
};
match timeout(Duration::from_secs_f64(timeout_secs), wait_fut).await {
Ok(Some((k, elem, idx, _remaining))) => {
// Unregister other waiters
for (i, key_name) in names.iter().enumerate() {
if i != idx {
server.unregister_waiter(db_index, key_name, ids[i]).await;
}
}
Ok(Protocol::Array(vec![
Protocol::BulkString(k),
Protocol::BulkString(elem),
]))
}
Ok(None) => {
// No futures left; unregister all waiters
for (i, key_name) in names.iter().enumerate() {
server.unregister_waiter(db_index, key_name, ids[i]).await;
}
Ok(Protocol::Null)
}
Err(_elapsed) => {
// Timeout: unregister all waiters
for (i, key_name) in names.iter().enumerate() {
server.unregister_waiter(db_index, key_name, ids[i]).await;
}
Ok(Protocol::Null)
}
}
}
// BRPOP implementation (mirror of BLPOP, popping from the right)
async fn brpop_cmd(server: &Server, keys: &[String], timeout_secs: f64) -> Result<Protocol, DBError> {
// Immediate, non-blocking attempt in key order using RPOP
for k in keys {
let elems = server.current_storage()?.rpop(k, 1)?;
if !elems.is_empty() {
return Ok(Protocol::Array(vec![
Protocol::BulkString(k.clone()),
Protocol::BulkString(elems[0].clone()),
]));
}
}
// If timeout is zero, return immediately with Null
if timeout_secs <= 0.0 {
return Ok(Protocol::Null);
}
// Register waiters for each key (Right side)
let db_index = server.selected_db;
let mut ids: Vec<u64> = Vec::with_capacity(keys.len());
let mut names: Vec<String> = Vec::with_capacity(keys.len());
let mut rxs: Vec<tokio::sync::oneshot::Receiver<(String, String)>> = Vec::with_capacity(keys.len());
for k in keys {
let (id, rx) = server.register_waiter(db_index, k, crate::server::PopSide::Right).await;
ids.push(id);
names.push(k.clone());
rxs.push(rx);
}
// Wait for the first delivery or timeout
let wait_fut = async move {
let mut futures_vec = rxs;
loop {
if futures_vec.is_empty() {
return None;
}
let (res, idx, remaining) = select_all(futures_vec).await;
match res {
Ok((k, elem)) => {
return Some((k, elem, idx, remaining));
}
Err(_canceled) => {
// That waiter was canceled; continue with the rest
futures_vec = remaining;
continue;
}
}
}
};
match timeout(Duration::from_secs_f64(timeout_secs), wait_fut).await {
Ok(Some((k, elem, idx, _remaining))) => {
// Unregister other waiters
for (i, key_name) in names.iter().enumerate() {
if i != idx {
server.unregister_waiter(db_index, key_name, ids[i]).await;
}
}
Ok(Protocol::Array(vec![
Protocol::BulkString(k),
Protocol::BulkString(elem),
]))
}
Ok(None) => {
// No futures left; unregister all waiters
for (i, key_name) in names.iter().enumerate() {
server.unregister_waiter(db_index, key_name, ids[i]).await;
}
Ok(Protocol::Null)
}
Err(_elapsed) => {
// Timeout: unregister all waiters
for (i, key_name) in names.iter().enumerate() {
server.unregister_waiter(db_index, key_name, ids[i]).await;
}
Ok(Protocol::Null)
}
}
}
/// LPUSH: prepends `elements` to the list at `key`.
///
/// After a successful push, blocked pop waiters for `key` are given a chance
/// to receive elements; the outcome of that step is deliberately ignored.
/// Replies with the length reported by the storage, as a simple string.
async fn lpush_cmd(server: &Server, key: &str, elements: &[String]) -> Result<Protocol, DBError> {
    let new_len = match server.current_storage()?.lpush(key, elements.to_vec()) {
        Ok(len) => len,
        Err(e) => return Ok(Protocol::err(&e.0)),
    };
    // Attempt to deliver to any blocked BLPOP waiters
    let _ = server.drain_waiters_after_push(key).await;
    Ok(Protocol::SimpleString(new_len.to_string()))
}
/// RPUSH: appends `elements` to the list at `key`.
///
/// After a successful push, blocked pop waiters for `key` are given a chance
/// to receive elements; the outcome of that step is deliberately ignored.
/// Replies with the length reported by the storage, as a simple string.
async fn rpush_cmd(server: &Server, key: &str, elements: &[String]) -> Result<Protocol, DBError> {
    let new_len = match server.current_storage()?.rpush(key, elements.to_vec()) {
        Ok(len) => len,
        Err(e) => return Ok(Protocol::err(&e.0)),
    };
    // Attempt to deliver to any blocked BLPOP waiters
    let _ = server.drain_waiters_after_push(key).await;
    Ok(Protocol::SimpleString(new_len.to_string()))
}
/// EXEC: runs every command queued since MULTI and replies with an array of
/// their individual replies, in queue order.
///
/// Replies with an error when no transaction is open. The queue is cleared
/// before execution starts. If any queued command returns `Err`, execution
/// stops and that error is propagated to the caller.
async fn exec_cmd(server: &mut Server) -> Result<Protocol, DBError> {
    // Taking the queue ends the borrow of `server` and closes the transaction.
    let queued = match server.queued_cmd.take() {
        Some(queued) => queued,
        None => return Ok(Protocol::err("ERR EXEC without MULTI")),
    };

    let mut replies = Vec::new();
    for (queued_cmd, _) in queued {
        // `run` can reach `exec_cmd` again, so the future is boxed to keep
        // the recursive async type finite.
        let reply = Box::pin(queued_cmd.run(server)).await?;
        replies.push(reply);
    }
    Ok(Protocol::Array(replies))
}
/// INCR: increment the integer stored at `key` by one.
///
/// A missing key is treated as 0, so the first increment yields 1. The new
/// value is written back and returned as a simple string.
///
/// Replies with a protocol error (not `Err`) when the stored value is not a
/// valid `i64`, or when the increment would overflow `i64`.
async fn incr_cmd(server: &Server, key: &String) -> Result<Protocol, DBError> {
    let storage = server.current_storage()?;

    let current: i64 = match storage.get(key)? {
        Some(raw) => match raw.parse::<i64>() {
            Ok(num) => num,
            Err(_) => return Ok(Protocol::err("ERR value is not an integer or out of range")),
        },
        None => 0,
    };

    // A plain `+ 1` would panic in debug builds and silently wrap in release
    // builds at `i64::MAX`; report the overflow to the client instead.
    let new_value = match current.checked_add(1) {
        Some(next) => next,
        None => return Ok(Protocol::err("ERR increment or decrement would overflow")),
    };

    storage.set(key.clone(), new_value.to_string())?;
    Ok(Protocol::SimpleString(new_value.to_string()))
}
/// CONFIG GET: look up a single configuration option by exact name.
///
/// Recognised names are `dir`, `dbfilename` and `databases`. The reply is a
/// two-element array `[name, value]`; any other name yields an empty array.
fn config_get_cmd(name: &String, server: &Server) -> Result<Protocol, DBError> {
    let value = match name.as_str() {
        "dir" => server.option.dir.clone(),
        // Derived from the index of the currently selected database.
        "dbfilename" => format!("{}.db", server.selected_db),
        // Fixed value; it is not read from any server setting.
        "databases" => "16".to_string(),
        // Unknown options produce an empty array rather than an error.
        _ => return Ok(Protocol::Array(vec![])),
    };

    let reply = vec![
        Protocol::BulkString(name.clone()),
        Protocol::BulkString(value),
    ];
    Ok(Protocol::Array(reply))
}
/// KEYS: list the keys of the current database.
///
/// Always queries storage with the `"*"` pattern and replies with an array of
/// bulk strings, one per key.
async fn keys_cmd(server: &Server) -> Result<Protocol, DBError> {
    let storage = server.current_storage()?;

    let mut reply = Vec::new();
    for key in storage.keys("*")? {
        reply.push(Protocol::BulkString(key));
    }

    Ok(Protocol::Array(reply))
}
/// DBSIZE: reply with the size reported by the current storage, as a simple
/// string. A storage failure becomes a protocol error reply.
async fn dbsize_cmd(server: &Server) -> Result<Protocol, DBError> {
    let storage = server.current_storage()?;
    let reply = storage.dbsize().map_or_else(
        |e| Protocol::err(&e.0),
        |n| Protocol::SimpleString(n.to_string()),
    );
    Ok(reply)
}
async fn info_cmd(server: &Server, section: &Option<String>) -> Result<Protocol, DBError> {
let storage_info = server.current_storage()?.info()?;
let mut info_map: std::collections::HashMap<String, String> = storage_info.into_iter().collect();
info_map.insert("redis_version".to_string(), "7.0.0".to_string());
info_map.insert("selected_db".to_string(), server.selected_db.to_string());
info_map.insert("backend".to_string(), format!("{:?}", server.option.backend));
let mut info_string = String::new();
info_string.push_str("# Server\n");
info_string.push_str(&format!("redis_version:{}\n", info_map.get("redis_version").unwrap()));
info_string.push_str(&format!("backend:{}\n", info_map.get("backend").unwrap()));
info_string.push_str(&format!("encrypted:{}\n", info_map.get("is_encrypted").unwrap()));
info_string.push_str("# Keyspace\n");
info_string.push_str(&format!("db{}:keys={},expires=0,avg_ttl=0\n", info_map.get("selected_db").unwrap(), info_map.get("db_size").unwrap()));
match section {
Some(s) => {
let sl = s.to_lowercase();
if sl == "replication" {
Ok(Protocol::BulkString(
"role:master\nmaster_replid:8371b4fb1155b71f4a04d3e1bc3e18c4a990aeea\nmaster_repl_offset:0\n".to_string()
))
} else {
// Return general info for unknown sections (e.g., SERVER)
Ok(Protocol::BulkString(info_string))
}
}
None => Ok(Protocol::BulkString(info_string)),
}
}
/// TYPE: reply with the type name storage reports for key `k`, or `"none"`
/// when storage reports no type for it.
async fn type_cmd(server: &Server, k: &String) -> Result<Protocol, DBError> {
    let storage = server.current_storage()?;
    let type_name = storage
        .get_key_type(k)?
        .unwrap_or_else(|| "none".to_string());
    Ok(Protocol::SimpleString(type_name))
}
/// DEL (single key): delete `k` and reply with the number of keys removed.
///
/// Replies `"1"` when the key existed before the delete and `"0"` otherwise,
/// matching what `del_multi_cmd` reports for the same key. The delete itself
/// is always issued, so storage sees the same calls whether or not the key
/// was present.
async fn del_cmd(server: &Server, k: &str) -> Result<Protocol, DBError> {
    let storage = server.current_storage()?;

    // Check existence first so the reply reflects whether anything was removed.
    let existed = storage.exists(k)?;
    storage.del(k.to_string())?;

    let removed = if existed { "1" } else { "0" };
    Ok(Protocol::SimpleString(removed.to_string()))
}
/// SET with EX: store `v` under `k` with an expiry given in seconds.
///
/// `x` is multiplied by 1000 before being handed to storage's `setx`.
/// Replies `OK`.
async fn set_ex_cmd(
    server: &Server,
    k: &str,
    v: &str,
    x: &u128,
) -> Result<Protocol, DBError> {
    let storage = server.current_storage()?;
    let ttl_ms = *x * 1000;
    storage.setx(k.to_string(), v.to_string(), ttl_ms)?;
    Ok(Protocol::SimpleString("OK".to_string()))
}
/// SET with PX: store `v` under `k` with an expiry; `x` is passed to
/// storage's `setx` unchanged. Replies `OK`.
async fn set_px_cmd(
    server: &Server,
    k: &str,
    v: &str,
    x: &u128,
) -> Result<Protocol, DBError> {
    let storage = server.current_storage()?;
    let ttl_ms = *x;
    storage.setx(k.to_string(), v.to_string(), ttl_ms)?;
    Ok(Protocol::SimpleString("OK".to_string()))
}
/// SET: store `v` under `k` in the current database and reply `OK`.
async fn set_cmd(server: &Server, k: &str, v: &str) -> Result<Protocol, DBError> {
    let storage = server.current_storage()?;
    storage.set(k.to_string(), v.to_string())?;
    Ok(Protocol::SimpleString("OK".to_string()))
}
/// Advanced SET with options: EX/PX/NX/XX/GET.
///
/// * `ex_ms` – optional expiry passed to storage's `setx`; `None` does a plain `set`.
/// * `nx` – only write when the key does not exist.
/// * `xx` – only write when the key already exists.
/// * `get_old` – reply with the previous value instead of `OK`.
///
/// Reply:
/// * with `get_old`: the previous value (or Null if there was none), whether
///   or not the NX/XX condition allowed the write. This is the Redis `SET ...
///   GET` behaviour.
/// * without `get_old`: `OK` if the write happened, Null if NX/XX prevented it.
async fn set_with_opts_cmd(
    server: &Server,
    key: &str,
    value: &str,
    ex_ms: Option<u128>,
    nx: bool,
    xx: bool,
    get_old: bool,
) -> Result<Protocol, DBError> {
    let storage = server.current_storage()?;

    // NX blocks the write when the key exists; XX blocks it when it does not.
    // With both flags set the condition can never hold, so nothing is written.
    let exists = storage.exists(key)?;
    let should_set = !(nx && exists) && !(xx && !exists);

    // Read the previous value before overwriting it.
    let old_val = if get_old { storage.get(key)? } else { None };

    if should_set {
        match ex_ms {
            Some(ms) => storage.setx(key.to_string(), value.to_string(), ms)?,
            None => storage.set(key.to_string(), value.to_string())?,
        }
    }

    if get_old {
        // The old value is returned even when NX/XX aborted the write.
        Ok(old_val.map_or(Protocol::Null, Protocol::BulkString))
    } else if should_set {
        Ok(Protocol::SimpleString("OK".to_string()))
    } else {
        Ok(Protocol::Null)
    }
}
/// MGET: look up every key in `keys`, in order.
///
/// Replies with an array holding a bulk string for each key that has a value
/// and Null for each key that does not. The first storage error stops the
/// lookup and is propagated.
async fn mget_cmd(server: &Server, keys: &[String]) -> Result<Protocol, DBError> {
    let storage = server.current_storage()?;

    let replies = keys
        .iter()
        .map(|k| {
            storage
                .get(k)
                .map(|found| found.map_or(Protocol::Null, Protocol::BulkString))
        })
        .collect::<Result<Vec<Protocol>, DBError>>()?;

    Ok(Protocol::Array(replies))
}
/// MSET: write each `(key, value)` pair in order, then reply `OK`.
///
/// Pairs are written one at a time; a storage error stops at the failing pair
/// and is propagated, leaving earlier pairs written.
async fn mset_cmd(server: &Server, pairs: &[(String, String)]) -> Result<Protocol, DBError> {
    let storage = server.current_storage()?;

    for (name, value) in pairs.iter().cloned() {
        storage.set(name, value)?;
    }

    Ok(Protocol::SimpleString("OK".to_string()))
}
/// DEL with multiple keys: delete each key that exists and reply with the
/// number of keys actually deleted, as a simple string.
async fn del_multi_cmd(server: &Server, keys: &[String]) -> Result<Protocol, DBError> {
    let storage = server.current_storage()?;

    let mut removed = 0i64;
    for name in keys {
        // Keys that are absent are skipped and do not count.
        if !storage.exists(name)? {
            continue;
        }
        storage.del(name.clone())?;
        removed += 1;
    }

    Ok(Protocol::SimpleString(removed.to_string()))
}
/// EXISTS with multiple keys: reply with how many of `keys` exist, as a
/// simple string. A key listed more than once is counted each time.
async fn exists_multi_cmd(server: &Server, keys: &[String]) -> Result<Protocol, DBError> {
    let storage = server.current_storage()?;

    let mut present = 0i64;
    for name in keys {
        let found = storage.exists(name)?;
        present += i64::from(found);
    }

    Ok(Protocol::SimpleString(present.to_string()))
}
/// GET: reply with the value stored at `k` as a bulk string, or Null when
/// storage has no value for it.
async fn get_cmd(server: &Server, k: &str) -> Result<Protocol, DBError> {
    let storage = server.current_storage()?;
    let reply = match storage.get(k)? {
        Some(value) => Protocol::BulkString(value),
        None => Protocol::Null,
    };
    Ok(reply)
}
// Hash command implementations
/// HSET: write the given field/value `pairs` into the hash at `key`.
///
/// Replies with the count returned by storage's `hset`, as a simple string.
/// Storage errors are propagated as `Err`.
async fn hset_cmd(server: &Server, key: &str, pairs: &[(String, String)]) -> Result<Protocol, DBError> {
    let storage = server.current_storage()?;
    let added = storage.hset(key, pairs.to_vec())?;
    Ok(Protocol::SimpleString(added.to_string()))
}
/// HGET: reply with the value of `field` in the hash at `key`, or Null when
/// storage reports none. A storage failure becomes a protocol error reply.
async fn hget_cmd(server: &Server, key: &str, field: &str) -> Result<Protocol, DBError> {
    let lookup = server.current_storage()?.hget(key, field);
    let reply = match lookup {
        Ok(found) => found.map_or(Protocol::Null, Protocol::BulkString),
        Err(e) => Protocol::err(&e.0),
    };
    Ok(reply)
}
/// HGETALL: reply with every field and value of the hash at `key`.
///
/// The reply is a flat array `[field1, value1, field2, value2, ...]` in the
/// order storage yields the pairs. A storage failure becomes a protocol error
/// reply.
async fn hgetall_cmd(server: &Server, key: &str) -> Result<Protocol, DBError> {
    let fetched = server.current_storage()?.hgetall(key);
    let reply = match fetched {
        Err(e) => Protocol::err(&e.0),
        Ok(pairs) => Protocol::Array(
            pairs
                .into_iter()
                .flat_map(|(field, value)| {
                    [Protocol::BulkString(field), Protocol::BulkString(value)]
                })
                .collect(),
        ),
    };
    Ok(reply)
}
/// HDEL: remove `fields` from the hash at `key`.
///
/// Replies with the count returned by storage's `hdel`, as a simple string.
/// A storage failure becomes a protocol error reply.
async fn hdel_cmd(server: &Server, key: &str, fields: &[String]) -> Result<Protocol, DBError> {
    let storage = server.current_storage()?;
    let reply = storage.hdel(key, fields.to_vec()).map_or_else(
        |e| Protocol::err(&e.0),
        |removed| Protocol::SimpleString(removed.to_string()),
    );
    Ok(reply)
}
/// HEXISTS: reply `"1"` if `field` is present in the hash at `key`, `"0"`
/// otherwise. A storage failure becomes a protocol error reply.
async fn hexists_cmd(server: &Server, key: &str, field: &str) -> Result<Protocol, DBError> {
    let flag = match server.current_storage()?.hexists(key, field) {
        Ok(true) => "1",
        Ok(false) => "0",
        Err(e) => return Ok(Protocol::err(&e.0)),
    };
    Ok(Protocol::SimpleString(flag.to_string()))
}
/// HKEYS: reply with an array of the field names in the hash at `key`.
/// A storage failure becomes a protocol error reply.
async fn hkeys_cmd(server: &Server, key: &str) -> Result<Protocol, DBError> {
    let listing = server.current_storage()?.hkeys(key);
    Ok(match listing {
        Err(e) => Protocol::err(&e.0),
        Ok(names) => {
            let items = names.into_iter().map(Protocol::BulkString).collect();
            Protocol::Array(items)
        }
    })
}
/// HVALS: reply with an array of the values in the hash at `key`.
/// A storage failure becomes a protocol error reply.
async fn hvals_cmd(server: &Server, key: &str) -> Result<Protocol, DBError> {
    let listing = server.current_storage()?.hvals(key);
    Ok(match listing {
        Err(e) => Protocol::err(&e.0),
        Ok(vals) => {
            let items = vals.into_iter().map(Protocol::BulkString).collect();
            Protocol::Array(items)
        }
    })
}
/// HLEN: reply with the length storage reports for the hash at `key`, as a
/// simple string. A storage failure becomes a protocol error reply.
async fn hlen_cmd(server: &Server, key: &str) -> Result<Protocol, DBError> {
    let storage = server.current_storage()?;
    let reply = storage.hlen(key).map_or_else(
        |e| Protocol::err(&e.0),
        |count| Protocol::SimpleString(count.to_string()),
    );
    Ok(reply)
}
/// HMGET: look up several `fields` of the hash at `key`.
///
/// Replies with an array holding a bulk string for each value storage
/// returns and Null for each missing one, in the order storage yields them.
/// A storage failure becomes a protocol error reply.
async fn hmget_cmd(server: &Server, key: &str, fields: &[String]) -> Result<Protocol, DBError> {
    let fetched = server.current_storage()?.hmget(key, fields.to_vec());
    let values = match fetched {
        Ok(values) => values,
        Err(e) => return Ok(Protocol::err(&e.0)),
    };

    let mut reply: Vec<Protocol> = Vec::new();
    for slot in values {
        reply.push(match slot {
            Some(text) => Protocol::BulkString(text),
            None => Protocol::Null,
        });
    }
    Ok(Protocol::Array(reply))
}
/// HSETNX: call storage's `hsetnx` for `field` in the hash at `key`.
///
/// Replies `"1"` when storage reports the field was set and `"0"` when it
/// was not. A storage failure becomes a protocol error reply.
async fn hsetnx_cmd(server: &Server, key: &str, field: &str, value: &str) -> Result<Protocol, DBError> {
    let flag = match server.current_storage()?.hsetnx(key, field, value) {
        Ok(true) => "1",
        Ok(false) => "0",
        Err(e) => return Ok(Protocol::err(&e.0)),
    };
    Ok(Protocol::SimpleString(flag.to_string()))
}
/// HINCRBY: add `delta` to the integer stored in `field` of the hash at `key`.
///
/// A missing field counts as 0. The updated value is written back and
/// returned as a simple string.
///
/// Returns `Err` when the current value is not a valid `i64` or when the
/// addition would overflow `i64`.
async fn hincrby_cmd(server: &Server, key: &str, field: &str, delta: i64) -> Result<Protocol, DBError> {
    let storage = server.current_storage()?;

    let base: i64 = match storage.hget(key, field)? {
        None => 0,
        Some(raw) => match raw.parse::<i64>() {
            Ok(parsed) => parsed,
            Err(_) => {
                return Err(DBError("ERR value is not an integer or out of range".to_string()));
            }
        },
    };

    let updated = match base.checked_add(delta) {
        Some(sum) => sum,
        None => {
            return Err(DBError("ERR increment or decrement would overflow".to_string()));
        }
    };

    // Persist the new value under the same field.
    let rendered = updated.to_string();
    storage.hset(key, vec![(field.to_string(), rendered.clone())])?;
    Ok(Protocol::SimpleString(rendered))
}
/// HINCRBYFLOAT: add `delta` to the float stored in `field` of the hash at `key`.
///
/// A missing field counts as 0.0. The updated value is written back and
/// returned as a simple string, formatted with `f64`'s `to_string`.
///
/// Returns `Err` when the current value does not parse as `f64`, or when the
/// result is NaN or infinite. Such a result is rejected before anything is
/// written.
async fn hincrbyfloat_cmd(server: &Server, key: &str, field: &str, delta: f64) -> Result<Protocol, DBError> {
    let storage = server.current_storage()?;

    let base: f64 = match storage.hget(key, field)? {
        Some(raw) => raw
            .parse::<f64>()
            .map_err(|_| DBError("ERR value is not a valid float".to_string()))?,
        None => 0.0,
    };

    let new_val = base + delta;
    // Rust parses "inf" and "nan" and float addition can overflow to infinity;
    // do not persist a value that later arithmetic cannot recover from.
    if !new_val.is_finite() {
        return Err(DBError("ERR increment would produce NaN or Infinity".to_string()));
    }

    // Update the field
    let rendered = new_val.to_string();
    storage.hset(key, vec![(field.to_string(), rendered.clone())])?;
    Ok(Protocol::SimpleString(rendered))
}
/// SCAN: fetch one page of keys from the current database.
///
/// Passes `cursor`, `pattern` and `count` straight to storage's `scan`.
/// The reply is a two-element array: the next cursor as a bulk string,
/// followed by an array of key names. Storage also returns a second element
/// per entry, which is dropped here. A storage failure becomes a protocol
/// error reply prefixed with `ERR `.
async fn scan_cmd(
    server: &Server,
    cursor: &u64,
    pattern: Option<&str>,
    count: &Option<u64>
) -> Result<Protocol, DBError> {
    let scanned = server.current_storage()?.scan(*cursor, pattern, *count);
    let (next_cursor, entries) = match scanned {
        Ok(page) => page,
        Err(e) => return Ok(Protocol::err(&format!("ERR {}", e.0))),
    };

    // Only the key of each entry is reported to the client.
    let keys: Vec<Protocol> = entries
        .into_iter()
        .map(|(name, _)| Protocol::BulkString(name))
        .collect();

    Ok(Protocol::Array(vec![
        Protocol::BulkString(next_cursor.to_string()),
        Protocol::Array(keys),
    ]))
}
/// HSCAN: fetch one page of field/value pairs from the hash at `key`.
///
/// Passes `cursor`, `pattern` and `count` straight to storage's `hscan`.
/// The reply is a two-element array: the next cursor as a bulk string,
/// followed by a flat array `[field1, value1, field2, value2, ...]`.
/// A storage failure becomes a protocol error reply prefixed with `ERR `.
async fn hscan_cmd(
    server: &Server,
    key: &str,
    cursor: &u64,
    pattern: Option<&str>,
    count: &Option<u64>
) -> Result<Protocol, DBError> {
    let scanned = server.current_storage()?.hscan(key, *cursor, pattern, *count);
    let (next_cursor, pairs) = match scanned {
        Ok(page) => page,
        Err(e) => return Ok(Protocol::err(&format!("ERR {}", e.0))),
    };

    // Flatten each (field, value) pair into two consecutive bulk strings.
    let flattened: Vec<Protocol> = pairs
        .into_iter()
        .flat_map(|(field, value)| [Protocol::BulkString(field), Protocol::BulkString(value)])
        .collect();

    Ok(Protocol::Array(vec![
        Protocol::BulkString(next_cursor.to_string()),
        Protocol::Array(flattened),
    ]))
}
/// TTL: reply with the value storage's `ttl` returns for `key`, as a simple
/// string. A storage failure becomes a protocol error reply.
async fn ttl_cmd(server: &Server, key: &str) -> Result<Protocol, DBError> {
    let storage = server.current_storage()?;
    let reply = storage.ttl(key).map_or_else(
        |e| Protocol::err(&e.0),
        |remaining| Protocol::SimpleString(remaining.to_string()),
    );
    Ok(reply)
}
/// EXISTS (single key): reply `"1"` if `key` exists, `"0"` otherwise.
/// A storage failure becomes a protocol error reply.
async fn exists_cmd(server: &Server, key: &str) -> Result<Protocol, DBError> {
    let flag = match server.current_storage()?.exists(key) {
        Ok(true) => "1",
        Ok(false) => "0",
        Err(e) => return Ok(Protocol::err(&e.0)),
    };
    Ok(Protocol::SimpleString(flag.to_string()))
}
/// EXPIRE key seconds -> `"1"` if the timeout was set, `"0"` otherwise.
///
/// A negative `secs` replies `"0"` without touching storage. A storage
/// failure becomes a protocol error reply.
async fn expire_cmd(server: &Server, key: &str, secs: i64) -> Result<Protocol, DBError> {
    if secs.is_negative() {
        return Ok(Protocol::SimpleString("0".to_string()));
    }

    // `secs` is non-negative here, so the cast to u64 is lossless.
    let flag = match server.current_storage()?.expire_seconds(key, secs as u64) {
        Ok(true) => "1",
        Ok(false) => "0",
        Err(e) => return Ok(Protocol::err(&e.0)),
    };
    Ok(Protocol::SimpleString(flag.to_string()))
}
/// PEXPIRE key milliseconds -> `"1"` if the timeout was set, `"0"` otherwise.
///
/// A negative `ms` replies `"0"` without touching storage. A storage failure
/// becomes a protocol error reply.
async fn pexpire_cmd(server: &Server, key: &str, ms: i64) -> Result<Protocol, DBError> {
    if ms.is_negative() {
        return Ok(Protocol::SimpleString("0".to_string()));
    }

    // `ms` is non-negative here, so the cast to u128 is lossless.
    let flag = match server.current_storage()?.pexpire_millis(key, ms as u128) {
        Ok(true) => "1",
        Ok(false) => "0",
        Err(e) => return Ok(Protocol::err(&e.0)),
    };
    Ok(Protocol::SimpleString(flag.to_string()))
}
/// PERSIST key -> `"1"` if the timeout was removed, `"0"` otherwise.
/// A storage failure becomes a protocol error reply.
async fn persist_cmd(server: &Server, key: &str) -> Result<Protocol, DBError> {
    let flag = match server.current_storage()?.persist(key) {
        Ok(true) => "1",
        Ok(false) => "0",
        Err(e) => return Ok(Protocol::err(&e.0)),
    };
    Ok(Protocol::SimpleString(flag.to_string()))
}
/// EXPIREAT key timestamp-seconds -> `"1"` if the timeout was set, `"0"`
/// otherwise. `ts_secs` is passed to storage unchanged. A storage failure
/// becomes a protocol error reply.
async fn expireat_cmd(server: &Server, key: &str, ts_secs: i64) -> Result<Protocol, DBError> {
    let flag = match server.current_storage()?.expire_at_seconds(key, ts_secs) {
        Ok(true) => "1",
        Ok(false) => "0",
        Err(e) => return Ok(Protocol::err(&e.0)),
    };
    Ok(Protocol::SimpleString(flag.to_string()))
}
/// PEXPIREAT key timestamp-milliseconds -> `"1"` if the timeout was set,
/// `"0"` otherwise. `ts_ms` is passed to storage unchanged. A storage failure
/// becomes a protocol error reply.
async fn pexpireat_cmd(server: &Server, key: &str, ts_ms: i64) -> Result<Protocol, DBError> {
    let flag = match server.current_storage()?.pexpire_at_millis(key, ts_ms) {
        Ok(true) => "1",
        Ok(false) => "0",
        Err(e) => return Ok(Protocol::err(&e.0)),
    };
    Ok(Protocol::SimpleString(flag.to_string()))
}
/// CLIENT SETNAME: store `name` as this connection's client name, replacing
/// any previous one, and reply `OK`.
async fn client_setname_cmd(server: &mut Server, name: &str) -> Result<Protocol, DBError> {
    let owned_name = String::from(name);
    server.client_name = Some(owned_name);
    Ok(Protocol::SimpleString("OK".to_string()))
}
/// CLIENT GETNAME: reply with the stored client name as a bulk string, or
/// Null when no name has been set.
async fn client_getname_cmd(server: &Server) -> Result<Protocol, DBError> {
    let reply = server
        .client_name
        .as_ref()
        .map_or(Protocol::Null, |stored| Protocol::BulkString(stored.clone()));
    Ok(reply)
}
/// Minimal COMMAND stub to satisfy redis-cli probes.
///
/// Every form replies with an empty array:
/// - `COMMAND` with no arguments
/// - `COMMAND DOCS ...`
/// - `COMMAND INFO ...`
/// - any other subcommand
///
/// The subcommand is matched case-insensitively; the separate arms are kept
/// so real replies can be added per subcommand later.
fn command_cmd(args: &[String]) -> Result<Protocol, DBError> {
    let subcommand = args.first().map(|raw| raw.to_lowercase());

    match subcommand.as_deref() {
        None => Ok(Protocol::Array(vec![])),
        Some("docs") => Ok(Protocol::Array(vec![])),
        Some("info") => Ok(Protocol::Array(vec![])),
        Some(_) => Ok(Protocol::Array(vec![])),
    }
}