Files
herodb/crates/herodb/cmd.rs
2025-08-16 14:22:56 +02:00

971 lines
45 KiB
Rust

use crate::{error::DBError, protocol::Protocol, server::Server};
use serde::Serialize;
/// A client command decoded from a RESP array, ready to be executed.
#[derive(Debug, Clone)]
pub enum Cmd {
    /// `PING`
    Ping,
    /// `ECHO message`
    Echo(String),
    /// `SELECT index` — database index as a `u64`.
    Select(u64),
    /// `GET key`
    Get(String),
    /// `SET key value`
    Set(String, String),
    /// `SET key value PX milliseconds`
    SetPx(String, String, u128),
    /// `SET key value EX seconds` / `SETEX key seconds value`
    SetEx(String, String, u128),
    /// `KEYS *`
    Keys,
    /// `CONFIG GET name`
    ConfigGet(String),
    /// `INFO [section]`
    Info(Option<String>),
    /// `DEL key`
    Del(String),
    /// `TYPE key`
    Type(String),
    /// `INCR key`
    Incr(String),
    /// `MULTI` — opens a transaction.
    Multi,
    /// `EXEC` — runs the queued transaction.
    Exec,
    /// `DISCARD` — drops the queued transaction.
    Discard,

    // ---- hash commands ----
    /// `HSET key field value [field value ...]`
    HSet(String, Vec<(String, String)>),
    /// `HGET key field`
    HGet(String, String),
    /// `HGETALL key`
    HGetAll(String),
    /// `HDEL key field [field ...]`
    HDel(String, Vec<String>),
    /// `HEXISTS key field`
    HExists(String, String),
    /// `HKEYS key`
    HKeys(String),
    /// `HVALS key`
    HVals(String),
    /// `HLEN key`
    HLen(String),
    /// `HMGET key field [field ...]`
    HMGet(String, Vec<String>),
    /// `HSETNX key field value`
    HSetNx(String, String, String),
    /// `HSCAN key cursor [MATCH pattern] [COUNT count]` — (key, cursor, pattern, count).
    HScan(String, u64, Option<String>, Option<u64>),
    /// `SCAN cursor [MATCH pattern] [COUNT count]` — (cursor, pattern, count).
    Scan(u64, Option<String>, Option<u64>),
    /// `TTL key`
    Ttl(String),
    /// `EXISTS key`
    Exists(String),
    /// `QUIT`
    Quit,
    /// `CLIENT <anything else>` — raw arguments after `CLIENT`.
    Client(Vec<String>),
    /// `CLIENT SETNAME name`
    ClientSetName(String),
    /// `CLIENT GETNAME`
    ClientGetName,

    // ---- list commands ----
    /// `LPUSH key element [element ...]`
    LPush(String, Vec<String>),
    /// `RPUSH key element [element ...]`
    RPush(String, Vec<String>),
    /// `LPOP key [count]`
    LPop(String, Option<u64>),
    /// `RPOP key [count]`
    RPop(String, Option<u64>),
    /// `LLEN key`
    LLen(String),
    /// `LREM key count element`
    LRem(String, i64, String),
    /// `LTRIM key start stop`
    LTrim(String, i64, i64),
    /// `LINDEX key index`
    LIndex(String, i64),
    /// `LRANGE key start stop`
    LRange(String, i64, i64),
    /// `FLUSHDB`
    FlushDb,
    /// Any command name the parser does not recognise (holds that name).
    Unknow(String),

    // ---- AGE (rage) commands, stateless ----
    /// `AGE GENENC`
    AgeGenEnc,
    /// `AGE GENSIGN`
    AgeGenSign,
    /// `AGE ENCRYPT recipient message`
    AgeEncrypt(String, String),
    /// `AGE DECRYPT identity ciphertext_b64`
    AgeDecrypt(String, String),
    /// `AGE SIGN signing_secret message`
    AgeSign(String, String),
    /// `AGE VERIFY verify_pub message signature_b64`
    AgeVerify(String, String, String),

    // ---- AGE (rage) commands, persistent named keys ----
    /// `AGE KEYGEN name`
    AgeKeygen(String),
    /// `AGE SIGNKEYGEN name`
    AgeSignKeygen(String),
    /// `AGE ENCRYPTNAME name message`
    AgeEncryptName(String, String),
    /// `AGE DECRYPTNAME name ciphertext_b64`
    AgeDecryptName(String, String),
    /// `AGE SIGNNAME name message`
    AgeSignName(String, String),
    /// `AGE VERIFYNAME name message signature_b64`
    AgeVerifyName(String, String, String),
    /// `AGE LIST`
    AgeList,
}
impl Cmd {
    /// Parses the first RESP frame in `s` into a command.
    ///
    /// On success returns the command, the raw frame it was decoded from, and the
    /// unconsumed remainder of `s`.
    ///
    /// # Errors
    /// Returns `DBError` when the frame cannot be decoded, is not an array, is
    /// empty, or carries arguments that are missing, malformed or unsupported for
    /// the named command. The input comes from clients, so missing arguments and
    /// non-numeric integer arguments are reported as errors instead of panicking.
    pub fn from(s: &str) -> Result<(Self, Protocol, &str), DBError> {
        /// Parses an integer argument, mapping failure to the standard Redis error.
        fn int_arg<T: std::str::FromStr>(arg: &str) -> Result<T, DBError> {
            arg.parse::<T>()
                .map_err(|_| DBError("ERR value is not an integer or out of range".to_string()))
        }

        /// Parses the trailing `[MATCH pattern] [COUNT count]` options shared by
        /// SCAN and HSCAN; an unknown option or an option without a value is a
        /// syntax error.
        fn scan_options(args: &[String]) -> Result<(Option<String>, Option<u64>), DBError> {
            let mut pattern = None;
            let mut count = None;
            let mut i = 0;
            while i < args.len() {
                let value = args
                    .get(i + 1)
                    .ok_or_else(|| DBError("ERR syntax error".to_string()))?;
                match args[i].to_lowercase().as_str() {
                    "match" => pattern = Some(value.clone()),
                    "count" => count = Some(int_arg::<u64>(value)?),
                    _ => return Err(DBError("ERR syntax error".to_string())),
                }
                i += 2;
            }
            Ok((pattern, count))
        }

        let (protocol, remaining) = Protocol::from(s)?;
        let frames = match protocol.clone() {
            Protocol::Array(frames) => frames,
            _ => {
                return Err(DBError(format!(
                    "fail to parse as cmd for {:?}",
                    protocol
                )));
            }
        };
        let cmd = frames.into_iter().map(|x| x.decode()).collect::<Vec<_>>();
        if cmd.is_empty() {
            return Err(DBError("cmd length is 0".to_string()));
        }

        let parsed = match cmd[0].to_lowercase().as_str() {
            "select" => {
                if cmd.len() != 2 {
                    return Err(DBError("wrong number of arguments for SELECT".to_string()));
                }
                let idx = cmd[1]
                    .parse::<u64>()
                    .map_err(|_| DBError("ERR DB index is not an integer".to_string()))?;
                Cmd::Select(idx)
            }
            "echo" => {
                // Checked explicitly: indexing cmd[1] without it panicked on a bare ECHO.
                if cmd.len() != 2 {
                    return Err(DBError("wrong number of arguments for ECHO command".to_string()));
                }
                Cmd::Echo(cmd[1].clone())
            }
            "ping" => Cmd::Ping,
            "get" => {
                if cmd.len() != 2 {
                    return Err(DBError("wrong number of arguments for GET command".to_string()));
                }
                Cmd::Get(cmd[1].clone())
            }
            "set" => {
                if cmd.len() == 5 && cmd[3].to_lowercase() == "px" {
                    Cmd::SetPx(cmd[1].clone(), cmd[2].clone(), int_arg(&cmd[4])?)
                } else if cmd.len() == 5 && cmd[3].to_lowercase() == "ex" {
                    Cmd::SetEx(cmd[1].clone(), cmd[2].clone(), int_arg(&cmd[4])?)
                } else if cmd.len() == 3 {
                    Cmd::Set(cmd[1].clone(), cmd[2].clone())
                } else {
                    return Err(DBError(format!("unsupported cmd {:?}", cmd)));
                }
            }
            "setex" => {
                if cmd.len() != 4 {
                    return Err(DBError("wrong number of arguments for SETEX command".to_string()));
                }
                // SETEX key seconds value
                Cmd::SetEx(cmd[1].clone(), cmd[3].clone(), int_arg(&cmd[2])?)
            }
            "config" => {
                if cmd.len() != 3 || cmd[1].to_lowercase() != "get" {
                    return Err(DBError(format!("unsupported cmd {:?}", cmd)));
                }
                Cmd::ConfigGet(cmd[2].clone())
            }
            "keys" => {
                if cmd.len() != 2 || cmd[1] != "*" {
                    return Err(DBError(format!("unsupported cmd {:?}", cmd)));
                }
                Cmd::Keys
            }
            "info" => {
                let section = if cmd.len() == 2 {
                    Some(cmd[1].clone())
                } else {
                    None
                };
                Cmd::Info(section)
            }
            "del" => {
                if cmd.len() != 2 {
                    return Err(DBError(format!("unsupported cmd {:?}", cmd)));
                }
                Cmd::Del(cmd[1].clone())
            }
            "type" => {
                if cmd.len() != 2 {
                    return Err(DBError(format!("unsupported cmd {:?}", cmd)));
                }
                Cmd::Type(cmd[1].clone())
            }
            "incr" => {
                if cmd.len() != 2 {
                    return Err(DBError(format!("unsupported cmd {:?}", cmd)));
                }
                Cmd::Incr(cmd[1].clone())
            }
            "multi" => {
                if cmd.len() != 1 {
                    return Err(DBError(format!("unsupported cmd {:?}", cmd)));
                }
                Cmd::Multi
            }
            "exec" => {
                if cmd.len() != 1 {
                    return Err(DBError(format!("unsupported cmd {:?}", cmd)));
                }
                Cmd::Exec
            }
            "discard" => Cmd::Discard,

            // Hash commands
            "hset" => {
                if cmd.len() < 4 || (cmd.len() - 2) % 2 != 0 {
                    return Err(DBError("wrong number of arguments for HSET command".to_string()));
                }
                let pairs: Vec<(String, String)> = cmd[2..]
                    .chunks_exact(2)
                    .map(|pair| (pair[0].clone(), pair[1].clone()))
                    .collect();
                Cmd::HSet(cmd[1].clone(), pairs)
            }
            "hget" => {
                if cmd.len() != 3 {
                    return Err(DBError("wrong number of arguments for HGET command".to_string()));
                }
                Cmd::HGet(cmd[1].clone(), cmd[2].clone())
            }
            "hgetall" => {
                if cmd.len() != 2 {
                    return Err(DBError("wrong number of arguments for HGETALL command".to_string()));
                }
                Cmd::HGetAll(cmd[1].clone())
            }
            "hdel" => {
                if cmd.len() < 3 {
                    return Err(DBError("wrong number of arguments for HDEL command".to_string()));
                }
                Cmd::HDel(cmd[1].clone(), cmd[2..].to_vec())
            }
            "hexists" => {
                if cmd.len() != 3 {
                    return Err(DBError("wrong number of arguments for HEXISTS command".to_string()));
                }
                Cmd::HExists(cmd[1].clone(), cmd[2].clone())
            }
            "hkeys" => {
                if cmd.len() != 2 {
                    return Err(DBError("wrong number of arguments for HKEYS command".to_string()));
                }
                Cmd::HKeys(cmd[1].clone())
            }
            "hvals" => {
                if cmd.len() != 2 {
                    return Err(DBError("wrong number of arguments for HVALS command".to_string()));
                }
                Cmd::HVals(cmd[1].clone())
            }
            "hlen" => {
                if cmd.len() != 2 {
                    return Err(DBError("wrong number of arguments for HLEN command".to_string()));
                }
                Cmd::HLen(cmd[1].clone())
            }
            "hmget" => {
                if cmd.len() < 3 {
                    return Err(DBError("wrong number of arguments for HMGET command".to_string()));
                }
                Cmd::HMGet(cmd[1].clone(), cmd[2..].to_vec())
            }
            "hsetnx" => {
                if cmd.len() != 4 {
                    return Err(DBError("wrong number of arguments for HSETNX command".to_string()));
                }
                Cmd::HSetNx(cmd[1].clone(), cmd[2].clone(), cmd[3].clone())
            }
            "hscan" => {
                if cmd.len() < 3 {
                    return Err(DBError("wrong number of arguments for HSCAN command".to_string()));
                }
                let cursor = cmd[2]
                    .parse::<u64>()
                    .map_err(|_| DBError("ERR invalid cursor".to_string()))?;
                let (pattern, count) = scan_options(&cmd[3..])?;
                Cmd::HScan(cmd[1].clone(), cursor, pattern, count)
            }
            "scan" => {
                if cmd.len() < 2 {
                    return Err(DBError("wrong number of arguments for SCAN command".to_string()));
                }
                let cursor = cmd[1]
                    .parse::<u64>()
                    .map_err(|_| DBError("ERR invalid cursor".to_string()))?;
                let (pattern, count) = scan_options(&cmd[2..])?;
                Cmd::Scan(cursor, pattern, count)
            }
            "ttl" => {
                if cmd.len() != 2 {
                    return Err(DBError("wrong number of arguments for TTL command".to_string()));
                }
                Cmd::Ttl(cmd[1].clone())
            }
            "exists" => {
                if cmd.len() != 2 {
                    return Err(DBError("wrong number of arguments for EXISTS command".to_string()));
                }
                Cmd::Exists(cmd[1].clone())
            }
            "quit" => {
                if cmd.len() != 1 {
                    return Err(DBError("wrong number of arguments for QUIT command".to_string()));
                }
                Cmd::Quit
            }
            "client" => {
                if cmd.len() > 1 {
                    match cmd[1].to_lowercase().as_str() {
                        "setname" => {
                            if cmd.len() == 3 {
                                Cmd::ClientSetName(cmd[2].clone())
                            } else {
                                return Err(DBError(
                                    "wrong number of arguments for 'client setname' command"
                                        .to_string(),
                                ));
                            }
                        }
                        "getname" => {
                            if cmd.len() == 2 {
                                Cmd::ClientGetName
                            } else {
                                return Err(DBError(
                                    "wrong number of arguments for 'client getname' command"
                                        .to_string(),
                                ));
                            }
                        }
                        _ => Cmd::Client(cmd[1..].to_vec()),
                    }
                } else {
                    Cmd::Client(vec![])
                }
            }

            // List commands
            "lpush" => {
                if cmd.len() < 3 {
                    return Err(DBError("wrong number of arguments for LPUSH command".to_string()));
                }
                Cmd::LPush(cmd[1].clone(), cmd[2..].to_vec())
            }
            "rpush" => {
                if cmd.len() < 3 {
                    return Err(DBError("wrong number of arguments for RPUSH command".to_string()));
                }
                Cmd::RPush(cmd[1].clone(), cmd[2..].to_vec())
            }
            "lpop" => {
                if cmd.len() < 2 || cmd.len() > 3 {
                    return Err(DBError("wrong number of arguments for LPOP command".to_string()));
                }
                let count: Option<u64> = cmd.get(2).map(|c| int_arg(c)).transpose()?;
                Cmd::LPop(cmd[1].clone(), count)
            }
            "rpop" => {
                if cmd.len() < 2 || cmd.len() > 3 {
                    return Err(DBError("wrong number of arguments for RPOP command".to_string()));
                }
                let count: Option<u64> = cmd.get(2).map(|c| int_arg(c)).transpose()?;
                Cmd::RPop(cmd[1].clone(), count)
            }
            "llen" => {
                if cmd.len() != 2 {
                    return Err(DBError("wrong number of arguments for LLEN command".to_string()));
                }
                Cmd::LLen(cmd[1].clone())
            }
            "lrem" => {
                if cmd.len() != 4 {
                    return Err(DBError("wrong number of arguments for LREM command".to_string()));
                }
                let count = int_arg::<i64>(&cmd[2])?;
                Cmd::LRem(cmd[1].clone(), count, cmd[3].clone())
            }
            "ltrim" => {
                if cmd.len() != 4 {
                    return Err(DBError("wrong number of arguments for LTRIM command".to_string()));
                }
                let start = int_arg::<i64>(&cmd[2])?;
                let stop = int_arg::<i64>(&cmd[3])?;
                Cmd::LTrim(cmd[1].clone(), start, stop)
            }
            "lindex" => {
                if cmd.len() != 3 {
                    return Err(DBError("wrong number of arguments for LINDEX command".to_string()));
                }
                let index = int_arg::<i64>(&cmd[2])?;
                Cmd::LIndex(cmd[1].clone(), index)
            }
            "lrange" => {
                if cmd.len() != 4 {
                    return Err(DBError("wrong number of arguments for LRANGE command".to_string()));
                }
                let start = int_arg::<i64>(&cmd[2])?;
                let stop = int_arg::<i64>(&cmd[3])?;
                Cmd::LRange(cmd[1].clone(), start, stop)
            }
            "flushdb" => {
                if cmd.len() != 1 {
                    return Err(DBError("wrong number of arguments for FLUSHDB command".to_string()));
                }
                Cmd::FlushDb
            }

            // AGE (rage) commands
            "age" => {
                if cmd.len() < 2 {
                    return Err(DBError("wrong number of arguments for AGE".to_string()));
                }
                match cmd[1].to_lowercase().as_str() {
                    // stateless
                    "genenc" => {
                        if cmd.len() != 2 {
                            return Err(DBError("AGE GENENC takes no args".to_string()));
                        }
                        Cmd::AgeGenEnc
                    }
                    "gensign" => {
                        if cmd.len() != 2 {
                            return Err(DBError("AGE GENSIGN takes no args".to_string()));
                        }
                        Cmd::AgeGenSign
                    }
                    "encrypt" => {
                        if cmd.len() != 4 {
                            return Err(DBError("AGE ENCRYPT <recipient> <message>".to_string()));
                        }
                        Cmd::AgeEncrypt(cmd[2].clone(), cmd[3].clone())
                    }
                    "decrypt" => {
                        if cmd.len() != 4 {
                            return Err(DBError(
                                "AGE DECRYPT <identity> <ciphertext_b64>".to_string(),
                            ));
                        }
                        Cmd::AgeDecrypt(cmd[2].clone(), cmd[3].clone())
                    }
                    "sign" => {
                        if cmd.len() != 4 {
                            return Err(DBError("AGE SIGN <signing_secret> <message>".to_string()));
                        }
                        Cmd::AgeSign(cmd[2].clone(), cmd[3].clone())
                    }
                    "verify" => {
                        if cmd.len() != 5 {
                            return Err(DBError(
                                "AGE VERIFY <verify_pub> <message> <signature_b64>".to_string(),
                            ));
                        }
                        Cmd::AgeVerify(cmd[2].clone(), cmd[3].clone(), cmd[4].clone())
                    }
                    // persistent names
                    "keygen" => {
                        if cmd.len() != 3 {
                            return Err(DBError("AGE KEYGEN <name>".to_string()));
                        }
                        Cmd::AgeKeygen(cmd[2].clone())
                    }
                    "signkeygen" => {
                        if cmd.len() != 3 {
                            return Err(DBError("AGE SIGNKEYGEN <name>".to_string()));
                        }
                        Cmd::AgeSignKeygen(cmd[2].clone())
                    }
                    "encryptname" => {
                        if cmd.len() != 4 {
                            return Err(DBError("AGE ENCRYPTNAME <name> <message>".to_string()));
                        }
                        Cmd::AgeEncryptName(cmd[2].clone(), cmd[3].clone())
                    }
                    "decryptname" => {
                        if cmd.len() != 4 {
                            return Err(DBError(
                                "AGE DECRYPTNAME <name> <ciphertext_b64>".to_string(),
                            ));
                        }
                        Cmd::AgeDecryptName(cmd[2].clone(), cmd[3].clone())
                    }
                    "signname" => {
                        if cmd.len() != 4 {
                            return Err(DBError("AGE SIGNNAME <name> <message>".to_string()));
                        }
                        Cmd::AgeSignName(cmd[2].clone(), cmd[3].clone())
                    }
                    "verifyname" => {
                        if cmd.len() != 5 {
                            return Err(DBError(
                                "AGE VERIFYNAME <name> <message> <signature_b64>".to_string(),
                            ));
                        }
                        Cmd::AgeVerifyName(cmd[2].clone(), cmd[3].clone(), cmd[4].clone())
                    }
                    "list" => {
                        if cmd.len() != 2 {
                            return Err(DBError("AGE LIST".to_string()));
                        }
                        Cmd::AgeList
                    }
                    _ => return Err(DBError(format!("unsupported AGE subcommand {:?}", cmd))),
                }
            }
            _ => Cmd::Unknow(cmd[0].clone()),
        };

        Ok((parsed, protocol, remaining))
    }

    /// Executes the command against `server` and returns the reply to send.
    ///
    /// While a transaction is open (`server.queued_cmd` is `Some`), every command
    /// other than MULTI, EXEC and DISCARD is appended to the queue and answered
    /// with `QUEUED` instead of being run.
    pub async fn run(self, server: &mut Server) -> Result<Protocol, DBError> {
        // Handle queued commands for transactions
        if !matches!(self, Cmd::Exec | Cmd::Multi | Cmd::Discard) {
            if let Some(queue) = server.queued_cmd.as_mut() {
                let protocol = self.clone().to_protocol();
                queue.push((self, protocol));
                return Ok(Protocol::SimpleString("QUEUED".to_string()));
            }
        }
        match self {
            Cmd::Select(db) => select_cmd(server, db).await,
            Cmd::Ping => Ok(Protocol::SimpleString("PONG".to_string())),
            Cmd::Echo(s) => Ok(Protocol::BulkString(s)),
            Cmd::Get(k) => get_cmd(server, &k).await,
            Cmd::Set(k, v) => set_cmd(server, &k, &v).await,
            Cmd::SetPx(k, v, x) => set_px_cmd(server, &k, &v, &x).await,
            Cmd::SetEx(k, v, x) => set_ex_cmd(server, &k, &v, &x).await,
            Cmd::Del(k) => del_cmd(server, &k).await,
            Cmd::ConfigGet(name) => config_get_cmd(&name, server),
            Cmd::Keys => keys_cmd(server).await,
            Cmd::Info(section) => info_cmd(server, &section).await,
            Cmd::Type(k) => type_cmd(server, &k).await,
            Cmd::Incr(key) => incr_cmd(server, &key).await,
            Cmd::Multi => {
                // Re-opening a transaction would silently drop everything queued so far.
                if server.queued_cmd.is_some() {
                    return Ok(Protocol::err("ERR MULTI calls can not be nested"));
                }
                server.queued_cmd = Some(Vec::<(Cmd, Protocol)>::new());
                Ok(Protocol::SimpleString("OK".to_string()))
            }
            Cmd::Exec => exec_cmd(server).await,
            Cmd::Discard => {
                if server.queued_cmd.is_some() {
                    server.queued_cmd = None;
                    Ok(Protocol::SimpleString("OK".to_string()))
                } else {
                    Ok(Protocol::err("ERR DISCARD without MULTI"))
                }
            }
            // Hash commands
            Cmd::HSet(key, pairs) => hset_cmd(server, &key, &pairs).await,
            Cmd::HGet(key, field) => hget_cmd(server, &key, &field).await,
            Cmd::HGetAll(key) => hgetall_cmd(server, &key).await,
            Cmd::HDel(key, fields) => hdel_cmd(server, &key, &fields).await,
            Cmd::HExists(key, field) => hexists_cmd(server, &key, &field).await,
            Cmd::HKeys(key) => hkeys_cmd(server, &key).await,
            Cmd::HVals(key) => hvals_cmd(server, &key).await,
            Cmd::HLen(key) => hlen_cmd(server, &key).await,
            Cmd::HMGet(key, fields) => hmget_cmd(server, &key, &fields).await,
            Cmd::HSetNx(key, field, value) => hsetnx_cmd(server, &key, &field, &value).await,
            Cmd::HScan(key, cursor, pattern, count) => {
                hscan_cmd(server, &key, &cursor, pattern.as_deref(), &count).await
            }
            Cmd::Scan(cursor, pattern, count) => {
                scan_cmd(server, &cursor, pattern.as_deref(), &count).await
            }
            Cmd::Ttl(key) => ttl_cmd(server, &key).await,
            Cmd::Exists(key) => exists_cmd(server, &key).await,
            Cmd::Quit => Ok(Protocol::SimpleString("OK".to_string())),
            Cmd::Client(_) => Ok(Protocol::SimpleString("OK".to_string())),
            Cmd::ClientSetName(name) => client_setname_cmd(server, &name).await,
            Cmd::ClientGetName => client_getname_cmd(server).await,
            // List commands
            Cmd::LPush(key, elements) => lpush_cmd(server, &key, &elements).await,
            Cmd::RPush(key, elements) => rpush_cmd(server, &key, &elements).await,
            Cmd::LPop(key, count) => lpop_cmd(server, &key, &count).await,
            Cmd::RPop(key, count) => rpop_cmd(server, &key, &count).await,
            Cmd::LLen(key) => llen_cmd(server, &key).await,
            Cmd::LRem(key, count, element) => lrem_cmd(server, &key, count, &element).await,
            Cmd::LTrim(key, start, stop) => ltrim_cmd(server, &key, start, stop).await,
            Cmd::LIndex(key, index) => lindex_cmd(server, &key, index).await,
            Cmd::LRange(key, start, stop) => lrange_cmd(server, &key, start, stop).await,
            Cmd::FlushDb => flushdb_cmd(server).await,
            // AGE (rage): stateless
            Cmd::AgeGenEnc => Ok(crate::age::cmd_age_genenc().await),
            Cmd::AgeGenSign => Ok(crate::age::cmd_age_gensign().await),
            Cmd::AgeEncrypt(recipient, message) => {
                Ok(crate::age::cmd_age_encrypt(&recipient, &message).await)
            }
            Cmd::AgeDecrypt(identity, ct_b64) => {
                Ok(crate::age::cmd_age_decrypt(&identity, &ct_b64).await)
            }
            Cmd::AgeSign(secret, message) => Ok(crate::age::cmd_age_sign(&secret, &message).await),
            Cmd::AgeVerify(vpub, msg, sig_b64) => {
                Ok(crate::age::cmd_age_verify(&vpub, &msg, &sig_b64).await)
            }
            // AGE (rage): persistent named keys
            Cmd::AgeKeygen(name) => Ok(crate::age::cmd_age_keygen(server, &name).await),
            Cmd::AgeSignKeygen(name) => Ok(crate::age::cmd_age_signkeygen(server, &name).await),
            Cmd::AgeEncryptName(name, message) => {
                Ok(crate::age::cmd_age_encrypt_name(server, &name, &message).await)
            }
            Cmd::AgeDecryptName(name, ct_b64) => {
                Ok(crate::age::cmd_age_decrypt_name(server, &name, &ct_b64).await)
            }
            Cmd::AgeSignName(name, message) => {
                Ok(crate::age::cmd_age_sign_name(server, &name, &message).await)
            }
            Cmd::AgeVerifyName(name, message, sig_b64) => {
                Ok(crate::age::cmd_age_verify_name(server, &name, &message, &sig_b64).await)
            }
            Cmd::AgeList => Ok(crate::age::cmd_age_list(server).await),
            Cmd::Unknow(s) => Ok(Protocol::err(&format!("ERR unknown command `{}`", s))),
        }
    }

    /// Re-encodes the command as a RESP array.
    ///
    /// Only SELECT, PING, ECHO, GET and SET are encoded faithfully; every other
    /// variant yields the placeholder simple string `...`.
    pub fn to_protocol(self) -> Protocol {
        match self {
            Cmd::Select(db) => Protocol::Array(vec![
                Protocol::BulkString("select".to_string()),
                Protocol::BulkString(db.to_string()),
            ]),
            Cmd::Ping => Protocol::Array(vec![Protocol::BulkString("ping".to_string())]),
            Cmd::Echo(s) => Protocol::Array(vec![
                Protocol::BulkString("echo".to_string()),
                Protocol::BulkString(s),
            ]),
            Cmd::Get(k) => Protocol::Array(vec![
                Protocol::BulkString("get".to_string()),
                Protocol::BulkString(k),
            ]),
            Cmd::Set(k, v) => Protocol::Array(vec![
                Protocol::BulkString("set".to_string()),
                Protocol::BulkString(k),
                Protocol::BulkString(v),
            ]),
            _ => Protocol::SimpleString("...".to_string()),
        }
    }
}
/// FLUSHDB: clears the currently selected database.
/// Storage failures are returned to the client as an error reply.
async fn flushdb_cmd(server: &mut Server) -> Result<Protocol, DBError> {
    let storage = server.current_storage()?;
    Ok(storage
        .flushdb()
        .map(|_| Protocol::SimpleString("OK".to_string()))
        .unwrap_or_else(|e| Protocol::err(&e.0)))
}
/// SELECT: switches the connection to database `db`.
///
/// The switch is probed by asking for the storage of the new database, which
/// creates it if needed. If that fails, the previously selected database is
/// restored. This means later commands keep working against a valid database
/// instead of all failing against the unreachable one.
async fn select_cmd(server: &mut Server, db: u64) -> Result<Protocol, DBError> {
    let previous_db = server.selected_db;
    server.selected_db = db;
    // Drop the storage handle right away so `server` is free to be mutated below.
    let access = server.current_storage().map(|_| ());
    match access {
        Ok(()) => Ok(Protocol::SimpleString("OK".to_string())),
        Err(e) => {
            server.selected_db = previous_db;
            Ok(Protocol::err(&e.0))
        }
    }
}
/// LINDEX: replies with the element at `index`, or null when storage reports none.
async fn lindex_cmd(server: &Server, key: &str, index: i64) -> Result<Protocol, DBError> {
    let storage = server.current_storage()?;
    Ok(match storage.lindex(key, index) {
        Ok(found) => found.map_or(Protocol::Null, Protocol::BulkString),
        Err(e) => Protocol::err(&e.0),
    })
}
/// LRANGE: replies with the elements between `start` and `stop` as an array.
async fn lrange_cmd(server: &Server, key: &str, start: i64, stop: i64) -> Result<Protocol, DBError> {
    let storage = server.current_storage()?;
    let elements = match storage.lrange(key, start, stop) {
        Ok(elements) => elements,
        Err(e) => return Ok(Protocol::err(&e.0)),
    };
    let items: Vec<Protocol> = elements.into_iter().map(Protocol::BulkString).collect();
    Ok(Protocol::Array(items))
}
/// LTRIM: trims the list to `start..=stop` and replies `OK`.
async fn ltrim_cmd(server: &Server, key: &str, start: i64, stop: i64) -> Result<Protocol, DBError> {
    let storage = server.current_storage()?;
    let trimmed = storage.ltrim(key, start, stop);
    Ok(trimmed.map_or_else(
        |e| Protocol::err(&e.0),
        |_| Protocol::SimpleString("OK".to_string()),
    ))
}
/// LREM: removes occurrences of `element` and replies with how many were removed.
async fn lrem_cmd(server: &Server, key: &str, count: i64, element: &str) -> Result<Protocol, DBError> {
    let storage = server.current_storage()?;
    let reply = match storage.lrem(key, count, element) {
        Ok(removed) => Protocol::SimpleString(removed.to_string()),
        Err(err) => Protocol::err(&err.0),
    };
    Ok(reply)
}
/// LLEN: replies with the length of the list stored at `key`.
async fn llen_cmd(server: &Server, key: &str) -> Result<Protocol, DBError> {
    let storage = server.current_storage()?;
    Ok(storage
        .llen(key)
        .map(|length| Protocol::SimpleString(length.to_string()))
        .unwrap_or_else(|e| Protocol::err(&e.0)))
}
/// LPOP: pops up to `count` (default 1) elements from the head of the list.
///
/// With an explicit count the reply is always an array, which may be empty.
/// Without one it is the single popped element, or null when nothing was popped.
async fn lpop_cmd(server: &Server, key: &str, count: &Option<u64>) -> Result<Protocol, DBError> {
    let how_many = count.unwrap_or(1);
    let storage = server.current_storage()?;
    let popped = match storage.lpop(key, how_many) {
        Ok(popped) => popped,
        Err(e) => return Ok(Protocol::err(&e.0)),
    };
    let reply = if count.is_some() {
        Protocol::Array(popped.into_iter().map(Protocol::BulkString).collect())
    } else {
        popped
            .into_iter()
            .next()
            .map_or(Protocol::Null, Protocol::BulkString)
    };
    Ok(reply)
}
/// RPOP: pops up to `count` (default 1) elements from the tail of the list.
///
/// With an explicit count the reply is always an array, which may be empty.
/// Without one it is the single popped element, or null when nothing was popped.
async fn rpop_cmd(server: &Server, key: &str, count: &Option<u64>) -> Result<Protocol, DBError> {
    let how_many = count.unwrap_or(1);
    let storage = server.current_storage()?;
    let popped = match storage.rpop(key, how_many) {
        Ok(popped) => popped,
        Err(e) => return Ok(Protocol::err(&e.0)),
    };
    let reply = match count {
        Some(_) => Protocol::Array(popped.into_iter().map(Protocol::BulkString).collect()),
        None => match popped.into_iter().next() {
            Some(first) => Protocol::BulkString(first),
            None => Protocol::Null,
        },
    };
    Ok(reply)
}
/// LPUSH: prepends `elements` and replies with the resulting list length.
async fn lpush_cmd(server: &Server, key: &str, elements: &[String]) -> Result<Protocol, DBError> {
    let storage = server.current_storage()?;
    let outcome = storage.lpush(key, elements.to_vec());
    Ok(outcome.map_or_else(
        |e| Protocol::err(&e.0),
        |len| Protocol::SimpleString(len.to_string()),
    ))
}
/// RPUSH: appends `elements` and replies with the resulting list length.
async fn rpush_cmd(server: &Server, key: &str, elements: &[String]) -> Result<Protocol, DBError> {
    let storage = server.current_storage()?;
    let reply = match storage.rpush(key, elements.to_vec()) {
        Err(e) => Protocol::err(&e.0),
        Ok(new_len) => Protocol::SimpleString(new_len.to_string()),
    };
    Ok(reply)
}
/// EXEC: runs every command queued since MULTI and replies with an array of
/// their replies, in order.
///
/// The queue is taken out of `server` first, which both ends the transaction and
/// stops the queued commands from being re-queued while they run. A command that
/// fails with `Err` no longer aborts EXEC. Aborting threw away the replies of
/// commands that had already been applied. The failure is instead reported as an
/// error entry in its own slot, matching Redis transaction semantics.
async fn exec_cmd(server: &mut Server) -> Result<Protocol, DBError> {
    let cmds = match server.queued_cmd.take() {
        Some(cmds) => cmds,
        None => return Ok(Protocol::err("ERR EXEC without MULTI")),
    };
    let mut replies = Vec::with_capacity(cmds.len());
    for (cmd, _) in cmds {
        // Box::pin breaks the async recursion run -> exec_cmd -> run.
        let reply = match Box::pin(cmd.run(server)).await {
            Ok(reply) => reply,
            Err(e) => Protocol::err(&e.0),
        };
        replies.push(reply);
    }
    Ok(Protocol::Array(replies))
}
async fn incr_cmd(server: &Server, key: &String) -> Result<Protocol, DBError> {
let storage = server.current_storage()?;
let current_value = storage.get(key)?;
let new_value = match current_value {
Some(v) => {
match v.parse::<i64>() {
Ok(num) => num + 1,
Err(_) => return Ok(Protocol::err("ERR value is not an integer or out of range")),
}
}
None => 1,
};
storage.set(key.clone(), new_value.to_string())?;
Ok(Protocol::SimpleString(new_value.to_string()))
}
/// CONFIG GET: replies with `[name, value]` for the few supported parameters,
/// or an empty array for anything else (standard Redis behavior).
fn config_get_cmd(name: &String, server: &Server) -> Result<Protocol, DBError> {
    let val = match name.as_str() {
        "dir" => server.option.dir.clone(),
        "dbfilename" => format!("{}.db", server.selected_db),
        // Hardcoded as per original logic
        "databases" => "16".to_string(),
        _ => return Ok(Protocol::Array(vec![])),
    };
    Ok(Protocol::Array(vec![
        Protocol::BulkString(name.clone()),
        Protocol::BulkString(val),
    ]))
}
/// KEYS *: replies with every key in the selected database.
async fn keys_cmd(server: &Server) -> Result<Protocol, DBError> {
    let storage = server.current_storage()?;
    let mut replies = Vec::new();
    for key in storage.keys("*")? {
        replies.push(Protocol::BulkString(key));
    }
    Ok(Protocol::Array(replies))
}
/// Values reported by the `INFO` command.
#[derive(Serialize)]
struct ServerInfo {
    /// Version string advertised to clients.
    redis_version: String,
    /// Filled from the selected storage's `is_encrypted()`.
    encrypted: bool,
    /// Index of the database the connection currently has selected.
    selected_db: u64,
}
/// INFO [section]: replies with a textual server report.
///
/// Without a section, the reply contains the `# Server` and `# Keyspace` blocks.
/// The keyspace figures are hardcoded placeholders. The only supported section
/// is `replication`, matched case-insensitively as Redis does. Any other section
/// yields an `Err`.
///
/// The full report (and the storage lookup it needs) is built only when it is
/// actually returned.
async fn info_cmd(server: &Server, section: &Option<String>) -> Result<Protocol, DBError> {
    if let Some(s) = section {
        return if s.eq_ignore_ascii_case("replication") {
            Ok(Protocol::BulkString(
                "role:master\nmaster_replid:8371b4fb1155b71f4a04d3e1bc3e18c4a990aeea\nmaster_repl_offset:0\n".to_string()
            ))
        } else {
            Err(DBError(format!("unsupported section {:?}", s)))
        };
    }

    let info = ServerInfo {
        redis_version: "7.0.0".to_string(),
        encrypted: server.current_storage()?.is_encrypted(),
        selected_db: server.selected_db,
    };
    let mut info_string = String::new();
    info_string.push_str("# Server\n");
    info_string.push_str(&format!("redis_version:{}\n", info.redis_version));
    info_string.push_str(&format!("encrypted:{}\n", u8::from(info.encrypted)));
    info_string.push_str("# Keyspace\n");
    info_string.push_str(&format!(
        "db{}:keys=0,expires=0,avg_ttl=0\n",
        info.selected_db
    ));
    Ok(Protocol::BulkString(info_string))
}
/// TYPE: replies with the stored type name of `k`, or `none` if it has none.
async fn type_cmd(server: &Server, k: &String) -> Result<Protocol, DBError> {
    let type_name = server
        .current_storage()?
        .get_key_type(k)?
        .unwrap_or_else(|| "none".to_string());
    Ok(Protocol::SimpleString(type_name))
}
/// DEL: deletes `k` from the selected database.
///
/// NOTE(review): the reply is always `1`, even if the key did not exist; Redis
/// replies with the number of keys actually removed. Confirm what `del` returns.
async fn del_cmd(server: &Server, k: &str) -> Result<Protocol, DBError> {
    let storage = server.current_storage()?;
    storage.del(k.to_string())?;
    Ok(Protocol::SimpleString("1".to_string()))
}
/// SET key value EX seconds / SETEX: stores `v` under `k` with a TTL of `x` seconds.
///
/// The TTL is handed to storage in milliseconds. `x` comes straight from the
/// client, so the seconds-to-milliseconds conversion is overflow-checked. An
/// out-of-range value gets an error reply instead of a panic (debug builds) or a
/// wrapped, wrong expiry (release builds).
async fn set_ex_cmd(
    server: &Server,
    k: &str,
    v: &str,
    x: &u128,
) -> Result<Protocol, DBError> {
    let ttl_ms = match x.checked_mul(1000) {
        Some(ms) => ms,
        None => return Ok(Protocol::err("ERR invalid expire time")),
    };
    server.current_storage()?.setx(k.to_string(), v.to_string(), ttl_ms)?;
    Ok(Protocol::SimpleString("OK".to_string()))
}
/// SET key value PX milliseconds: stores `v` under `k` with a TTL of `x` ms.
async fn set_px_cmd(server: &Server, k: &str, v: &str, x: &u128) -> Result<Protocol, DBError> {
    let storage = server.current_storage()?;
    let (key, value, ttl_ms) = (k.to_string(), v.to_string(), *x);
    storage.setx(key, value, ttl_ms)?;
    Ok(Protocol::SimpleString("OK".to_string()))
}
/// SET key value: stores `v` under `k` without an expiry.
async fn set_cmd(server: &Server, k: &str, v: &str) -> Result<Protocol, DBError> {
    let storage = server.current_storage()?;
    let (key, value) = (k.to_owned(), v.to_owned());
    storage.set(key, value)?;
    Ok(Protocol::SimpleString("OK".to_string()))
}
/// GET: replies with the value stored at `k`, or null when absent.
async fn get_cmd(server: &Server, k: &str) -> Result<Protocol, DBError> {
    match server.current_storage()?.get(k)? {
        Some(value) => Ok(Protocol::BulkString(value)),
        None => Ok(Protocol::Null),
    }
}
// Hash command implementations

/// HSET: writes every field/value pair and replies with the count storage returns.
async fn hset_cmd(server: &Server, key: &str, pairs: &[(String, String)]) -> Result<Protocol, DBError> {
    let storage = server.current_storage()?;
    let added = storage.hset(key, pairs.to_vec())?;
    Ok(Protocol::SimpleString(added.to_string()))
}
/// HGET: replies with the value of `field` in hash `key`, or null if absent.
async fn hget_cmd(server: &Server, key: &str, field: &str) -> Result<Protocol, DBError> {
    let storage = server.current_storage()?;
    let reply = match storage.hget(key, field) {
        Err(e) => Protocol::err(&e.0),
        Ok(None) => Protocol::Null,
        Ok(Some(value)) => Protocol::BulkString(value),
    };
    Ok(reply)
}
/// HGETALL: replies with a flat `[field1, value1, field2, value2, ...]` array.
async fn hgetall_cmd(server: &Server, key: &str) -> Result<Protocol, DBError> {
    let storage = server.current_storage()?;
    let pairs = match storage.hgetall(key) {
        Ok(pairs) => pairs,
        Err(e) => return Ok(Protocol::err(&e.0)),
    };
    let flat = pairs.into_iter().fold(Vec::new(), |mut acc, (field, value)| {
        acc.push(Protocol::BulkString(field));
        acc.push(Protocol::BulkString(value));
        acc
    });
    Ok(Protocol::Array(flat))
}
/// HDEL: removes `fields` from hash `key` and replies with the count storage returns.
async fn hdel_cmd(server: &Server, key: &str, fields: &[String]) -> Result<Protocol, DBError> {
    let storage = server.current_storage()?;
    Ok(storage
        .hdel(key, fields.to_vec())
        .map(|removed| Protocol::SimpleString(removed.to_string()))
        .unwrap_or_else(|e| Protocol::err(&e.0)))
}
/// HEXISTS: replies `1` if `field` exists in hash `key`, otherwise `0`.
async fn hexists_cmd(server: &Server, key: &str, field: &str) -> Result<Protocol, DBError> {
    let storage = server.current_storage()?;
    let reply = match storage.hexists(key, field) {
        Ok(present) => Protocol::SimpleString(u8::from(present).to_string()),
        Err(e) => Protocol::err(&e.0),
    };
    Ok(reply)
}
/// HKEYS: replies with every field name of hash `key`.
async fn hkeys_cmd(server: &Server, key: &str) -> Result<Protocol, DBError> {
    let storage = server.current_storage()?;
    let fields = match storage.hkeys(key) {
        Ok(fields) => fields,
        Err(e) => return Ok(Protocol::err(&e.0)),
    };
    let items: Vec<Protocol> = fields.into_iter().map(Protocol::BulkString).collect();
    Ok(Protocol::Array(items))
}
/// HVALS: replies with every value of hash `key`.
async fn hvals_cmd(server: &Server, key: &str) -> Result<Protocol, DBError> {
    let storage = server.current_storage()?;
    Ok(storage.hvals(key).map_or_else(
        |e| Protocol::err(&e.0),
        |vals| Protocol::Array(vals.into_iter().map(Protocol::BulkString).collect()),
    ))
}
/// HLEN: replies with the number of fields in hash `key`.
async fn hlen_cmd(server: &Server, key: &str) -> Result<Protocol, DBError> {
    let storage = server.current_storage()?;
    let reply = match storage.hlen(key) {
        Err(e) => Protocol::err(&e.0),
        Ok(field_count) => Protocol::SimpleString(field_count.to_string()),
    };
    Ok(reply)
}
/// HMGET: replies with one entry per requested field, null where a field is missing.
async fn hmget_cmd(server: &Server, key: &str, fields: &[String]) -> Result<Protocol, DBError> {
    let storage = server.current_storage()?;
    let slots = match storage.hmget(key, fields.to_vec()) {
        Ok(slots) => slots,
        Err(e) => return Ok(Protocol::err(&e.0)),
    };
    let entries: Vec<Protocol> = slots
        .into_iter()
        .map(|slot| match slot {
            Some(value) => Protocol::BulkString(value),
            None => Protocol::Null,
        })
        .collect();
    Ok(Protocol::Array(entries))
}
/// HSETNX: sets `field` only if it is absent; replies `1` if it was set, else `0`.
async fn hsetnx_cmd(server: &Server, key: &str, field: &str, value: &str) -> Result<Protocol, DBError> {
    let storage = server.current_storage()?;
    Ok(storage
        .hsetnx(key, field, value)
        .map(|was_set| {
            let flag = if was_set { "1" } else { "0" };
            Protocol::SimpleString(flag.to_string())
        })
        .unwrap_or_else(|e| Protocol::err(&e.0)))
}
/// SCAN: replies `[next_cursor, [key, ...]]` for one page of the keyspace.
///
/// Storage hands back key/value pairs; only the keys are sent to the client.
async fn scan_cmd(
    server: &Server,
    cursor: &u64,
    pattern: Option<&str>,
    count: &Option<u64>,
) -> Result<Protocol, DBError> {
    let storage = server.current_storage()?;
    let (next_cursor, entries) = match storage.scan(*cursor, pattern, *count) {
        Ok(page) => page,
        Err(e) => return Ok(Protocol::err(&format!("ERR {}", e.0))),
    };
    let keys: Vec<Protocol> = entries
        .into_iter()
        .map(|(key, _)| Protocol::BulkString(key))
        .collect();
    Ok(Protocol::Array(vec![
        Protocol::BulkString(next_cursor.to_string()),
        Protocol::Array(keys),
    ]))
}
/// HSCAN: replies `[next_cursor, [field1, value1, ...]]` for one page of hash `key`.
async fn hscan_cmd(
    server: &Server,
    key: &str,
    cursor: &u64,
    pattern: Option<&str>,
    count: &Option<u64>,
) -> Result<Protocol, DBError> {
    let storage = server.current_storage()?;
    let (next_cursor, field_value_pairs) = match storage.hscan(key, *cursor, pattern, *count) {
        Ok(page) => page,
        Err(e) => return Ok(Protocol::err(&format!("ERR {}", e.0))),
    };
    let mut flattened = Vec::new();
    for (field, value) in field_value_pairs {
        flattened.push(Protocol::BulkString(field));
        flattened.push(Protocol::BulkString(value));
    }
    Ok(Protocol::Array(vec![
        Protocol::BulkString(next_cursor.to_string()),
        Protocol::Array(flattened),
    ]))
}
/// TTL: replies with the TTL storage reports for `key`.
async fn ttl_cmd(server: &Server, key: &str) -> Result<Protocol, DBError> {
    let storage = server.current_storage()?;
    let reply = match storage.ttl(key) {
        Err(e) => Protocol::err(&e.0),
        Ok(remaining) => Protocol::SimpleString(remaining.to_string()),
    };
    Ok(reply)
}
/// EXISTS: replies `1` if `key` exists, otherwise `0`.
async fn exists_cmd(server: &Server, key: &str) -> Result<Protocol, DBError> {
    let storage = server.current_storage()?;
    Ok(storage
        .exists(key)
        .map(|present| Protocol::SimpleString(u8::from(present).to_string()))
        .unwrap_or_else(|e| Protocol::err(&e.0)))
}
/// CLIENT SETNAME: remembers `name` for this connection and replies `OK`.
async fn client_setname_cmd(server: &mut Server, name: &str) -> Result<Protocol, DBError> {
    let owned = name.to_owned();
    server.client_name = Some(owned);
    Ok(Protocol::SimpleString("OK".to_string()))
}
/// CLIENT GETNAME: replies with the connection name, or null if none was set.
async fn client_getname_cmd(server: &Server) -> Result<Protocol, DBError> {
    Ok(server
        .client_name
        .clone()
        .map_or(Protocol::Null, Protocol::BulkString))
}