// Listing metadata (not source code): 971 lines, 45 KiB, Rust.
use crate::{error::DBError, protocol::Protocol, server::Server};
|
|
use serde::Serialize;
|
|
|
|
/// A fully parsed client command, ready to be executed by [`Cmd::run`].
#[derive(Debug, Clone)]
pub enum Cmd {
    /// `PING`
    Ping,
    /// `ECHO message`
    Echo(String),
    /// `SELECT index` — database index as an unsigned 64-bit integer.
    Select(u64),
    /// `GET key`
    Get(String),
    /// `SET key value`
    Set(String, String),
    /// `SET key value PX ms` — key, value, expiry in milliseconds.
    SetPx(String, String, u128),
    /// `SET key value EX s` / `SETEX key s value` — key, value, expiry in seconds.
    SetEx(String, String, u128),
    /// `KEYS *`
    Keys,
    /// `CONFIG GET name`
    ConfigGet(String),
    /// `INFO [section]`
    Info(Option<String>),
    /// `DEL key`
    Del(String),
    /// `TYPE key`
    Type(String),
    /// `INCR key`
    Incr(String),
    /// `MULTI` — opens a transaction.
    Multi,
    /// `EXEC` — runs the queued transaction.
    Exec,
    /// `DISCARD` — drops the queued transaction.
    Discard,

    // ---- Hash commands ----
    /// `HSET key field value [field value ...]`
    HSet(String, Vec<(String, String)>),
    /// `HGET key field`
    HGet(String, String),
    /// `HGETALL key`
    HGetAll(String),
    /// `HDEL key field [field ...]`
    HDel(String, Vec<String>),
    /// `HEXISTS key field`
    HExists(String, String),
    /// `HKEYS key`
    HKeys(String),
    /// `HVALS key`
    HVals(String),
    /// `HLEN key`
    HLen(String),
    /// `HMGET key field [field ...]`
    HMGet(String, Vec<String>),
    /// `HSETNX key field value`
    HSetNx(String, String, String),
    /// `HSCAN key cursor [MATCH pattern] [COUNT n]` — key, cursor, pattern, count.
    HScan(String, u64, Option<String>, Option<u64>),
    /// `SCAN cursor [MATCH pattern] [COUNT n]` — cursor, pattern, count.
    Scan(u64, Option<String>, Option<u64>),
    /// `TTL key`
    Ttl(String),
    /// `EXISTS key`
    Exists(String),
    /// `QUIT`
    Quit,
    /// Any `CLIENT` subcommand other than SETNAME/GETNAME (arguments after `CLIENT`).
    Client(Vec<String>),
    /// `CLIENT SETNAME name`
    ClientSetName(String),
    /// `CLIENT GETNAME`
    ClientGetName,

    // ---- List commands ----
    /// `LPUSH key element [element ...]`
    LPush(String, Vec<String>),
    /// `RPUSH key element [element ...]`
    RPush(String, Vec<String>),
    /// `LPOP key [count]`
    LPop(String, Option<u64>),
    /// `RPOP key [count]`
    RPop(String, Option<u64>),
    /// `LLEN key`
    LLen(String),
    /// `LREM key count element`
    LRem(String, i64, String),
    /// `LTRIM key start stop`
    LTrim(String, i64, i64),
    /// `LINDEX key index`
    LIndex(String, i64),
    /// `LRANGE key start stop`
    LRange(String, i64, i64),
    /// `FLUSHDB`
    FlushDb,
    /// Unrecognised command; carries the command name as sent by the client.
    Unknow(String),

    // ---- AGE (rage) commands: stateless, all material passed inline ----
    /// `AGE GENENC`
    AgeGenEnc,
    /// `AGE GENSIGN`
    AgeGenSign,
    /// `AGE ENCRYPT recipient message`
    AgeEncrypt(String, String),
    /// `AGE DECRYPT identity ciphertext_b64`
    AgeDecrypt(String, String),
    /// `AGE SIGN signing_secret message`
    AgeSign(String, String),
    /// `AGE VERIFY verify_pub message signature_b64`
    AgeVerify(String, String, String),

    // ---- AGE (rage) commands: keys persisted under a name ----
    /// `AGE KEYGEN name`
    AgeKeygen(String),
    /// `AGE SIGNKEYGEN name`
    AgeSignKeygen(String),
    /// `AGE ENCRYPTNAME name message`
    AgeEncryptName(String, String),
    /// `AGE DECRYPTNAME name ciphertext_b64`
    AgeDecryptName(String, String),
    /// `AGE SIGNNAME name message`
    AgeSignName(String, String),
    /// `AGE VERIFYNAME name message signature_b64`
    AgeVerifyName(String, String, String),
    /// `AGE LIST`
    AgeList,
}
|
|
|
|
impl Cmd {
|
|
pub fn from(s: &str) -> Result<(Self, Protocol, &str), DBError> {
|
|
let (protocol, remaining) = Protocol::from(s)?;
|
|
match protocol.clone() {
|
|
Protocol::Array(p) => {
|
|
let cmd = p.into_iter().map(|x| x.decode()).collect::<Vec<_>>();
|
|
if cmd.is_empty() {
|
|
return Err(DBError("cmd length is 0".to_string()));
|
|
}
|
|
Ok((
|
|
match cmd[0].to_lowercase().as_str() {
|
|
"select" => {
|
|
if cmd.len() != 2 {
|
|
return Err(DBError("wrong number of arguments for SELECT".to_string()));
|
|
}
|
|
let idx = cmd[1].parse::<u64>().map_err(|_| DBError("ERR DB index is not an integer".to_string()))?;
|
|
Cmd::Select(idx)
|
|
}
|
|
"echo" => Cmd::Echo(cmd[1].clone()),
|
|
"ping" => Cmd::Ping,
|
|
"get" => Cmd::Get(cmd[1].clone()),
|
|
"set" => {
|
|
if cmd.len() == 5 && cmd[3].to_lowercase() == "px" {
|
|
Cmd::SetPx(cmd[1].clone(), cmd[2].clone(), cmd[4].parse().unwrap())
|
|
} else if cmd.len() == 5 && cmd[3].to_lowercase() == "ex" {
|
|
Cmd::SetEx(cmd[1].clone(), cmd[2].clone(), cmd[4].parse().unwrap())
|
|
} else if cmd.len() == 3 {
|
|
Cmd::Set(cmd[1].clone(), cmd[2].clone())
|
|
} else {
|
|
return Err(DBError(format!("unsupported cmd {:?}", cmd)));
|
|
}
|
|
}
|
|
"setex" => {
|
|
if cmd.len() != 4 {
|
|
return Err(DBError(format!("wrong number of arguments for SETEX command")));
|
|
}
|
|
Cmd::SetEx(cmd[1].clone(), cmd[3].clone(), cmd[2].parse().unwrap())
|
|
}
|
|
"config" => {
|
|
if cmd.len() != 3 || cmd[1].to_lowercase() != "get" {
|
|
return Err(DBError(format!("unsupported cmd {:?}", cmd)));
|
|
} else {
|
|
Cmd::ConfigGet(cmd[2].clone())
|
|
}
|
|
}
|
|
"keys" => {
|
|
if cmd.len() != 2 || cmd[1] != "*" {
|
|
return Err(DBError(format!("unsupported cmd {:?}", cmd)));
|
|
} else {
|
|
Cmd::Keys
|
|
}
|
|
}
|
|
"info" => {
|
|
let section = if cmd.len() == 2 {
|
|
Some(cmd[1].clone())
|
|
} else {
|
|
None
|
|
};
|
|
Cmd::Info(section)
|
|
}
|
|
"del" => {
|
|
if cmd.len() != 2 {
|
|
return Err(DBError(format!("unsupported cmd {:?}", cmd)));
|
|
}
|
|
Cmd::Del(cmd[1].clone())
|
|
}
|
|
"type" => {
|
|
if cmd.len() != 2 {
|
|
return Err(DBError(format!("unsupported cmd {:?}", cmd)));
|
|
}
|
|
Cmd::Type(cmd[1].clone())
|
|
}
|
|
"incr" => {
|
|
if cmd.len() != 2 {
|
|
return Err(DBError(format!("unsupported cmd {:?}", cmd)));
|
|
}
|
|
Cmd::Incr(cmd[1].clone())
|
|
}
|
|
"multi" => {
|
|
if cmd.len() != 1 {
|
|
return Err(DBError(format!("unsupported cmd {:?}", cmd)));
|
|
}
|
|
Cmd::Multi
|
|
}
|
|
"exec" => {
|
|
if cmd.len() != 1 {
|
|
return Err(DBError(format!("unsupported cmd {:?}", cmd)));
|
|
}
|
|
Cmd::Exec
|
|
}
|
|
"discard" => Cmd::Discard,
|
|
// Hash commands
|
|
"hset" => {
|
|
if cmd.len() < 4 || (cmd.len() - 2) % 2 != 0 {
|
|
return Err(DBError(format!("wrong number of arguments for HSET command")));
|
|
}
|
|
let mut pairs = Vec::new();
|
|
let mut i = 2;
|
|
while i + 1 < cmd.len() {
|
|
pairs.push((cmd[i].clone(), cmd[i + 1].clone()));
|
|
i += 2;
|
|
}
|
|
Cmd::HSet(cmd[1].clone(), pairs)
|
|
}
|
|
"hget" => {
|
|
if cmd.len() != 3 {
|
|
return Err(DBError(format!("wrong number of arguments for HGET command")));
|
|
}
|
|
Cmd::HGet(cmd[1].clone(), cmd[2].clone())
|
|
}
|
|
"hgetall" => {
|
|
if cmd.len() != 2 {
|
|
return Err(DBError(format!("wrong number of arguments for HGETALL command")));
|
|
}
|
|
Cmd::HGetAll(cmd[1].clone())
|
|
}
|
|
"hdel" => {
|
|
if cmd.len() < 3 {
|
|
return Err(DBError(format!("wrong number of arguments for HDEL command")));
|
|
}
|
|
Cmd::HDel(cmd[1].clone(), cmd[2..].to_vec())
|
|
}
|
|
"hexists" => {
|
|
if cmd.len() != 3 {
|
|
return Err(DBError(format!("wrong number of arguments for HEXISTS command")));
|
|
}
|
|
Cmd::HExists(cmd[1].clone(), cmd[2].clone())
|
|
}
|
|
"hkeys" => {
|
|
if cmd.len() != 2 {
|
|
return Err(DBError(format!("wrong number of arguments for HKEYS command")));
|
|
}
|
|
Cmd::HKeys(cmd[1].clone())
|
|
}
|
|
"hvals" => {
|
|
if cmd.len() != 2 {
|
|
return Err(DBError(format!("wrong number of arguments for HVALS command")));
|
|
}
|
|
Cmd::HVals(cmd[1].clone())
|
|
}
|
|
"hlen" => {
|
|
if cmd.len() != 2 {
|
|
return Err(DBError(format!("wrong number of arguments for HLEN command")));
|
|
}
|
|
Cmd::HLen(cmd[1].clone())
|
|
}
|
|
"hmget" => {
|
|
if cmd.len() < 3 {
|
|
return Err(DBError(format!("wrong number of arguments for HMGET command")));
|
|
}
|
|
Cmd::HMGet(cmd[1].clone(), cmd[2..].to_vec())
|
|
}
|
|
"hsetnx" => {
|
|
if cmd.len() != 4 {
|
|
return Err(DBError(format!("wrong number of arguments for HSETNX command")));
|
|
}
|
|
Cmd::HSetNx(cmd[1].clone(), cmd[2].clone(), cmd[3].clone())
|
|
}
|
|
"hscan" => {
|
|
if cmd.len() < 3 {
|
|
return Err(DBError(format!("wrong number of arguments for HSCAN command")));
|
|
}
|
|
|
|
let key = cmd[1].clone();
|
|
let cursor = cmd[2].parse::<u64>().map_err(|_|
|
|
DBError("ERR invalid cursor".to_string()))?;
|
|
|
|
let mut pattern = None;
|
|
let mut count = None;
|
|
let mut i = 3;
|
|
|
|
while i < cmd.len() {
|
|
match cmd[i].to_lowercase().as_str() {
|
|
"match" => {
|
|
if i + 1 >= cmd.len() {
|
|
return Err(DBError("ERR syntax error".to_string()));
|
|
}
|
|
pattern = Some(cmd[i + 1].clone());
|
|
i += 2;
|
|
}
|
|
"count" => {
|
|
if i + 1 >= cmd.len() {
|
|
return Err(DBError("ERR syntax error".to_string()));
|
|
}
|
|
count = Some(cmd[i + 1].parse::<u64>().map_err(|_|
|
|
DBError("ERR value is not an integer or out of range".to_string()))?);
|
|
i += 2;
|
|
}
|
|
_ => {
|
|
return Err(DBError(format!("ERR syntax error")));
|
|
}
|
|
}
|
|
}
|
|
|
|
Cmd::HScan(key, cursor, pattern, count)
|
|
}
|
|
"scan" => {
|
|
if cmd.len() < 2 {
|
|
return Err(DBError(format!("wrong number of arguments for SCAN command")));
|
|
}
|
|
|
|
let cursor = cmd[1].parse::<u64>().map_err(|_|
|
|
DBError("ERR invalid cursor".to_string()))?;
|
|
|
|
let mut pattern = None;
|
|
let mut count = None;
|
|
let mut i = 2;
|
|
|
|
while i < cmd.len() {
|
|
match cmd[i].to_lowercase().as_str() {
|
|
"match" => {
|
|
if i + 1 >= cmd.len() {
|
|
return Err(DBError("ERR syntax error".to_string()));
|
|
}
|
|
pattern = Some(cmd[i + 1].clone());
|
|
i += 2;
|
|
}
|
|
"count" => {
|
|
if i + 1 >= cmd.len() {
|
|
return Err(DBError("ERR syntax error".to_string()));
|
|
}
|
|
count = Some(cmd[i + 1].parse::<u64>().map_err(|_|
|
|
DBError("ERR value is not an integer or out of range".to_string()))?);
|
|
i += 2;
|
|
}
|
|
_ => {
|
|
return Err(DBError(format!("ERR syntax error")));
|
|
}
|
|
}
|
|
}
|
|
|
|
Cmd::Scan(cursor, pattern, count)
|
|
}
|
|
"ttl" => {
|
|
if cmd.len() != 2 {
|
|
return Err(DBError(format!("wrong number of arguments for TTL command")));
|
|
}
|
|
Cmd::Ttl(cmd[1].clone())
|
|
}
|
|
"exists" => {
|
|
if cmd.len() != 2 {
|
|
return Err(DBError(format!("wrong number of arguments for EXISTS command")));
|
|
}
|
|
Cmd::Exists(cmd[1].clone())
|
|
}
|
|
"quit" => {
|
|
if cmd.len() != 1 {
|
|
return Err(DBError(format!("wrong number of arguments for QUIT command")));
|
|
}
|
|
Cmd::Quit
|
|
}
|
|
"client" => {
|
|
if cmd.len() > 1 {
|
|
match cmd[1].to_lowercase().as_str() {
|
|
"setname" => {
|
|
if cmd.len() == 3 {
|
|
Cmd::ClientSetName(cmd[2].clone())
|
|
} else {
|
|
return Err(DBError("wrong number of arguments for 'client setname' command".to_string()));
|
|
}
|
|
}
|
|
"getname" => {
|
|
if cmd.len() == 2 {
|
|
Cmd::ClientGetName
|
|
} else {
|
|
return Err(DBError("wrong number of arguments for 'client getname' command".to_string()));
|
|
}
|
|
}
|
|
_ => Cmd::Client(cmd[1..].to_vec()),
|
|
}
|
|
} else {
|
|
Cmd::Client(vec![])
|
|
}
|
|
}
|
|
"lpush" => {
|
|
if cmd.len() < 3 {
|
|
return Err(DBError(format!("wrong number of arguments for LPUSH command")));
|
|
}
|
|
Cmd::LPush(cmd[1].clone(), cmd[2..].to_vec())
|
|
}
|
|
"rpush" => {
|
|
if cmd.len() < 3 {
|
|
return Err(DBError(format!("wrong number of arguments for RPUSH command")));
|
|
}
|
|
Cmd::RPush(cmd[1].clone(), cmd[2..].to_vec())
|
|
}
|
|
"lpop" => {
|
|
if cmd.len() < 2 || cmd.len() > 3 {
|
|
return Err(DBError(format!("wrong number of arguments for LPOP command")));
|
|
}
|
|
let count = if cmd.len() == 3 {
|
|
Some(cmd[2].parse::<u64>().map_err(|_| DBError("ERR value is not an integer or out of range".to_string()))?)
|
|
} else {
|
|
None
|
|
};
|
|
Cmd::LPop(cmd[1].clone(), count)
|
|
}
|
|
"rpop" => {
|
|
if cmd.len() < 2 || cmd.len() > 3 {
|
|
return Err(DBError(format!("wrong number of arguments for RPOP command")));
|
|
}
|
|
let count = if cmd.len() == 3 {
|
|
Some(cmd[2].parse::<u64>().map_err(|_| DBError("ERR value is not an integer or out of range".to_string()))?)
|
|
} else {
|
|
None
|
|
};
|
|
Cmd::RPop(cmd[1].clone(), count)
|
|
}
|
|
"llen" => {
|
|
if cmd.len() != 2 {
|
|
return Err(DBError(format!("wrong number of arguments for LLEN command")));
|
|
}
|
|
Cmd::LLen(cmd[1].clone())
|
|
}
|
|
"lrem" => {
|
|
if cmd.len() != 4 {
|
|
return Err(DBError(format!("wrong number of arguments for LREM command")));
|
|
}
|
|
let count = cmd[2].parse::<i64>().map_err(|_| DBError("ERR value is not an integer or out of range".to_string()))?;
|
|
Cmd::LRem(cmd[1].clone(), count, cmd[3].clone())
|
|
}
|
|
"ltrim" => {
|
|
if cmd.len() != 4 {
|
|
return Err(DBError(format!("wrong number of arguments for LTRIM command")));
|
|
}
|
|
let start = cmd[2].parse::<i64>().map_err(|_| DBError("ERR value is not an integer or out of range".to_string()))?;
|
|
let stop = cmd[3].parse::<i64>().map_err(|_| DBError("ERR value is not an integer or out of range".to_string()))?;
|
|
Cmd::LTrim(cmd[1].clone(), start, stop)
|
|
}
|
|
"lindex" => {
|
|
if cmd.len() != 3 {
|
|
return Err(DBError(format!("wrong number of arguments for LINDEX command")));
|
|
}
|
|
let index = cmd[2].parse::<i64>().map_err(|_| DBError("ERR value is not an integer or out of range".to_string()))?;
|
|
Cmd::LIndex(cmd[1].clone(), index)
|
|
}
|
|
"lrange" => {
|
|
if cmd.len() != 4 {
|
|
return Err(DBError(format!("wrong number of arguments for LRANGE command")));
|
|
}
|
|
let start = cmd[2].parse::<i64>().map_err(|_| DBError("ERR value is not an integer or out of range".to_string()))?;
|
|
let stop = cmd[3].parse::<i64>().map_err(|_| DBError("ERR value is not an integer or out of range".to_string()))?;
|
|
Cmd::LRange(cmd[1].clone(), start, stop)
|
|
}
|
|
"flushdb" => {
|
|
if cmd.len() != 1 {
|
|
return Err(DBError("wrong number of arguments for FLUSHDB command".to_string()));
|
|
}
|
|
Cmd::FlushDb
|
|
}
|
|
"age" => {
|
|
if cmd.len() < 2 {
|
|
return Err(DBError("wrong number of arguments for AGE".to_string()));
|
|
}
|
|
match cmd[1].to_lowercase().as_str() {
|
|
// stateless
|
|
"genenc" => { if cmd.len() != 2 { return Err(DBError("AGE GENENC takes no args".to_string())); }
|
|
Cmd::AgeGenEnc }
|
|
"gensign" => { if cmd.len() != 2 { return Err(DBError("AGE GENSIGN takes no args".to_string())); }
|
|
Cmd::AgeGenSign }
|
|
"encrypt" => { if cmd.len() != 4 { return Err(DBError("AGE ENCRYPT <recipient> <message>".to_string())); }
|
|
Cmd::AgeEncrypt(cmd[2].clone(), cmd[3].clone()) }
|
|
"decrypt" => { if cmd.len() != 4 { return Err(DBError("AGE DECRYPT <identity> <ciphertext_b64>".to_string())); }
|
|
Cmd::AgeDecrypt(cmd[2].clone(), cmd[3].clone()) }
|
|
"sign" => { if cmd.len() != 4 { return Err(DBError("AGE SIGN <signing_secret> <message>".to_string())); }
|
|
Cmd::AgeSign(cmd[2].clone(), cmd[3].clone()) }
|
|
"verify" => { if cmd.len() != 5 { return Err(DBError("AGE VERIFY <verify_pub> <message> <signature_b64>".to_string())); }
|
|
Cmd::AgeVerify(cmd[2].clone(), cmd[3].clone(), cmd[4].clone()) }
|
|
|
|
// persistent names
|
|
"keygen" => { if cmd.len() != 3 { return Err(DBError("AGE KEYGEN <name>".to_string())); }
|
|
Cmd::AgeKeygen(cmd[2].clone()) }
|
|
"signkeygen" => { if cmd.len() != 3 { return Err(DBError("AGE SIGNKEYGEN <name>".to_string())); }
|
|
Cmd::AgeSignKeygen(cmd[2].clone()) }
|
|
"encryptname" => { if cmd.len() != 4 { return Err(DBError("AGE ENCRYPTNAME <name> <message>".to_string())); }
|
|
Cmd::AgeEncryptName(cmd[2].clone(), cmd[3].clone()) }
|
|
"decryptname" => { if cmd.len() != 4 { return Err(DBError("AGE DECRYPTNAME <name> <ciphertext_b64>".to_string())); }
|
|
Cmd::AgeDecryptName(cmd[2].clone(), cmd[3].clone()) }
|
|
"signname" => { if cmd.len() != 4 { return Err(DBError("AGE SIGNNAME <name> <message>".to_string())); }
|
|
Cmd::AgeSignName(cmd[2].clone(), cmd[3].clone()) }
|
|
"verifyname" => { if cmd.len() != 5 { return Err(DBError("AGE VERIFYNAME <name> <message> <signature_b64>".to_string())); }
|
|
Cmd::AgeVerifyName(cmd[2].clone(), cmd[3].clone(), cmd[4].clone()) }
|
|
"list" => { if cmd.len() != 2 { return Err(DBError("AGE LIST".to_string())); }
|
|
Cmd::AgeList }
|
|
_ => return Err(DBError(format!("unsupported AGE subcommand {:?}", cmd))),
|
|
}
|
|
}
|
|
_ => Cmd::Unknow(cmd[0].clone()),
|
|
},
|
|
protocol,
|
|
remaining
|
|
))
|
|
}
|
|
_ => Err(DBError(format!(
|
|
"fail to parse as cmd for {:?}",
|
|
protocol
|
|
))),
|
|
}
|
|
}
|
|
|
|
pub async fn run(self, server: &mut Server) -> Result<Protocol, DBError> {
|
|
// Handle queued commands for transactions
|
|
if server.queued_cmd.is_some()
|
|
&& !matches!(self, Cmd::Exec)
|
|
&& !matches!(self, Cmd::Multi)
|
|
&& !matches!(self, Cmd::Discard)
|
|
{
|
|
let protocol = self.clone().to_protocol();
|
|
server.queued_cmd.as_mut().unwrap().push((self, protocol));
|
|
return Ok(Protocol::SimpleString("QUEUED".to_string()));
|
|
}
|
|
|
|
match self {
|
|
Cmd::Select(db) => select_cmd(server, db).await,
|
|
Cmd::Ping => Ok(Protocol::SimpleString("PONG".to_string())),
|
|
Cmd::Echo(s) => Ok(Protocol::BulkString(s)),
|
|
Cmd::Get(k) => get_cmd(server, &k).await,
|
|
Cmd::Set(k, v) => set_cmd(server, &k, &v).await,
|
|
Cmd::SetPx(k, v, x) => set_px_cmd(server, &k, &v, &x).await,
|
|
Cmd::SetEx(k, v, x) => set_ex_cmd(server, &k, &v, &x).await,
|
|
Cmd::Del(k) => del_cmd(server, &k).await,
|
|
Cmd::ConfigGet(name) => config_get_cmd(&name, server),
|
|
Cmd::Keys => keys_cmd(server).await,
|
|
Cmd::Info(section) => info_cmd(server, §ion).await,
|
|
Cmd::Type(k) => type_cmd(server, &k).await,
|
|
Cmd::Incr(key) => incr_cmd(server, &key).await,
|
|
Cmd::Multi => {
|
|
server.queued_cmd = Some(Vec::<(Cmd, Protocol)>::new());
|
|
Ok(Protocol::SimpleString("OK".to_string()))
|
|
}
|
|
Cmd::Exec => exec_cmd(server).await,
|
|
Cmd::Discard => {
|
|
if server.queued_cmd.is_some() {
|
|
server.queued_cmd = None;
|
|
Ok(Protocol::SimpleString("OK".to_string()))
|
|
} else {
|
|
Ok(Protocol::err("ERR DISCARD without MULTI"))
|
|
}
|
|
}
|
|
// Hash commands
|
|
Cmd::HSet(key, pairs) => hset_cmd(server, &key, &pairs).await,
|
|
Cmd::HGet(key, field) => hget_cmd(server, &key, &field).await,
|
|
Cmd::HGetAll(key) => hgetall_cmd(server, &key).await,
|
|
Cmd::HDel(key, fields) => hdel_cmd(server, &key, &fields).await,
|
|
Cmd::HExists(key, field) => hexists_cmd(server, &key, &field).await,
|
|
Cmd::HKeys(key) => hkeys_cmd(server, &key).await,
|
|
Cmd::HVals(key) => hvals_cmd(server, &key).await,
|
|
Cmd::HLen(key) => hlen_cmd(server, &key).await,
|
|
Cmd::HMGet(key, fields) => hmget_cmd(server, &key, &fields).await,
|
|
Cmd::HSetNx(key, field, value) => hsetnx_cmd(server, &key, &field, &value).await,
|
|
Cmd::HScan(key, cursor, pattern, count) => hscan_cmd(server, &key, &cursor, pattern.as_deref(), &count).await,
|
|
Cmd::Scan(cursor, pattern, count) => scan_cmd(server, &cursor, pattern.as_deref(), &count).await,
|
|
Cmd::Ttl(key) => ttl_cmd(server, &key).await,
|
|
Cmd::Exists(key) => exists_cmd(server, &key).await,
|
|
Cmd::Quit => Ok(Protocol::SimpleString("OK".to_string())),
|
|
Cmd::Client(_) => Ok(Protocol::SimpleString("OK".to_string())),
|
|
Cmd::ClientSetName(name) => client_setname_cmd(server, &name).await,
|
|
Cmd::ClientGetName => client_getname_cmd(server).await,
|
|
// List commands
|
|
Cmd::LPush(key, elements) => lpush_cmd(server, &key, &elements).await,
|
|
Cmd::RPush(key, elements) => rpush_cmd(server, &key, &elements).await,
|
|
Cmd::LPop(key, count) => lpop_cmd(server, &key, &count).await,
|
|
Cmd::RPop(key, count) => rpop_cmd(server, &key, &count).await,
|
|
Cmd::LLen(key) => llen_cmd(server, &key).await,
|
|
Cmd::LRem(key, count, element) => lrem_cmd(server, &key, count, &element).await,
|
|
Cmd::LTrim(key, start, stop) => ltrim_cmd(server, &key, start, stop).await,
|
|
Cmd::LIndex(key, index) => lindex_cmd(server, &key, index).await,
|
|
Cmd::LRange(key, start, stop) => lrange_cmd(server, &key, start, stop).await,
|
|
Cmd::FlushDb => flushdb_cmd(server).await,
|
|
// AGE (rage): stateless
|
|
Cmd::AgeGenEnc => Ok(crate::age::cmd_age_genenc().await),
|
|
Cmd::AgeGenSign => Ok(crate::age::cmd_age_gensign().await),
|
|
Cmd::AgeEncrypt(recipient, message) => Ok(crate::age::cmd_age_encrypt(&recipient, &message).await),
|
|
Cmd::AgeDecrypt(identity, ct_b64) => Ok(crate::age::cmd_age_decrypt(&identity, &ct_b64).await),
|
|
Cmd::AgeSign(secret, message) => Ok(crate::age::cmd_age_sign(&secret, &message).await),
|
|
Cmd::AgeVerify(vpub, msg, sig_b64) => Ok(crate::age::cmd_age_verify(&vpub, &msg, &sig_b64).await),
|
|
|
|
// AGE (rage): persistent named keys
|
|
Cmd::AgeKeygen(name) => Ok(crate::age::cmd_age_keygen(server, &name).await),
|
|
Cmd::AgeSignKeygen(name) => Ok(crate::age::cmd_age_signkeygen(server, &name).await),
|
|
Cmd::AgeEncryptName(name, message) => Ok(crate::age::cmd_age_encrypt_name(server, &name, &message).await),
|
|
Cmd::AgeDecryptName(name, ct_b64) => Ok(crate::age::cmd_age_decrypt_name(server, &name, &ct_b64).await),
|
|
Cmd::AgeSignName(name, message) => Ok(crate::age::cmd_age_sign_name(server, &name, &message).await),
|
|
Cmd::AgeVerifyName(name, message, sig_b64) => Ok(crate::age::cmd_age_verify_name(server, &name, &message, &sig_b64).await),
|
|
Cmd::AgeList => Ok(crate::age::cmd_age_list(server).await),
|
|
Cmd::Unknow(s) => Ok(Protocol::err(&format!("ERR unknown command `{}`", s))),
|
|
}
|
|
}
|
|
|
|
pub fn to_protocol(self) -> Protocol {
|
|
match self {
|
|
Cmd::Select(db) => Protocol::Array(vec![Protocol::BulkString("select".to_string()), Protocol::BulkString(db.to_string())]),
|
|
Cmd::Ping => Protocol::Array(vec![Protocol::BulkString("ping".to_string())]),
|
|
Cmd::Echo(s) => Protocol::Array(vec![Protocol::BulkString("echo".to_string()), Protocol::BulkString(s)]),
|
|
Cmd::Get(k) => Protocol::Array(vec![Protocol::BulkString("get".to_string()), Protocol::BulkString(k)]),
|
|
Cmd::Set(k, v) => Protocol::Array(vec![Protocol::BulkString("set".to_string()), Protocol::BulkString(k), Protocol::BulkString(v)]),
|
|
_ => Protocol::SimpleString("...".to_string())
|
|
}
|
|
}
|
|
}
|
|
|
|
async fn flushdb_cmd(server: &mut Server) -> Result<Protocol, DBError> {
|
|
match server.current_storage()?.flushdb() {
|
|
Ok(_) => Ok(Protocol::SimpleString("OK".to_string())),
|
|
Err(e) => Ok(Protocol::err(&e.0)),
|
|
}
|
|
}
|
|
|
|
async fn select_cmd(server: &mut Server, db: u64) -> Result<Protocol, DBError> {
|
|
// Test if we can access the database (this will create it if needed)
|
|
server.selected_db = db;
|
|
match server.current_storage() {
|
|
Ok(_) => Ok(Protocol::SimpleString("OK".to_string())),
|
|
Err(e) => Ok(Protocol::err(&e.0)),
|
|
}
|
|
}
|
|
|
|
async fn lindex_cmd(server: &Server, key: &str, index: i64) -> Result<Protocol, DBError> {
|
|
match server.current_storage()?.lindex(key, index) {
|
|
Ok(Some(element)) => Ok(Protocol::BulkString(element)),
|
|
Ok(None) => Ok(Protocol::Null),
|
|
Err(e) => Ok(Protocol::err(&e.0)),
|
|
}
|
|
}
|
|
|
|
async fn lrange_cmd(server: &Server, key: &str, start: i64, stop: i64) -> Result<Protocol, DBError> {
|
|
match server.current_storage()?.lrange(key, start, stop) {
|
|
Ok(elements) => Ok(Protocol::Array(elements.into_iter().map(Protocol::BulkString).collect())),
|
|
Err(e) => Ok(Protocol::err(&e.0)),
|
|
}
|
|
}
|
|
|
|
async fn ltrim_cmd(server: &Server, key: &str, start: i64, stop: i64) -> Result<Protocol, DBError> {
|
|
match server.current_storage()?.ltrim(key, start, stop) {
|
|
Ok(_) => Ok(Protocol::SimpleString("OK".to_string())),
|
|
Err(e) => Ok(Protocol::err(&e.0)),
|
|
}
|
|
}
|
|
|
|
async fn lrem_cmd(server: &Server, key: &str, count: i64, element: &str) -> Result<Protocol, DBError> {
|
|
match server.current_storage()?.lrem(key, count, element) {
|
|
Ok(removed_count) => Ok(Protocol::SimpleString(removed_count.to_string())),
|
|
Err(e) => Ok(Protocol::err(&e.0)),
|
|
}
|
|
}
|
|
|
|
async fn llen_cmd(server: &Server, key: &str) -> Result<Protocol, DBError> {
|
|
match server.current_storage()?.llen(key) {
|
|
Ok(len) => Ok(Protocol::SimpleString(len.to_string())),
|
|
Err(e) => Ok(Protocol::err(&e.0)),
|
|
}
|
|
}
|
|
|
|
async fn lpop_cmd(server: &Server, key: &str, count: &Option<u64>) -> Result<Protocol, DBError> {
|
|
let count_val = count.unwrap_or(1);
|
|
match server.current_storage()?.lpop(key, count_val) {
|
|
Ok(elements) => {
|
|
if elements.is_empty() {
|
|
if count.is_some() {
|
|
Ok(Protocol::Array(vec![]))
|
|
} else {
|
|
Ok(Protocol::Null)
|
|
}
|
|
} else if count.is_some() {
|
|
Ok(Protocol::Array(elements.into_iter().map(Protocol::BulkString).collect()))
|
|
} else {
|
|
Ok(Protocol::BulkString(elements[0].clone()))
|
|
}
|
|
},
|
|
Err(e) => Ok(Protocol::err(&e.0)),
|
|
}
|
|
}
|
|
|
|
async fn rpop_cmd(server: &Server, key: &str, count: &Option<u64>) -> Result<Protocol, DBError> {
|
|
let count_val = count.unwrap_or(1);
|
|
match server.current_storage()?.rpop(key, count_val) {
|
|
Ok(elements) => {
|
|
if elements.is_empty() {
|
|
if count.is_some() {
|
|
Ok(Protocol::Array(vec![]))
|
|
} else {
|
|
Ok(Protocol::Null)
|
|
}
|
|
} else if count.is_some() {
|
|
Ok(Protocol::Array(elements.into_iter().map(Protocol::BulkString).collect()))
|
|
} else {
|
|
Ok(Protocol::BulkString(elements[0].clone()))
|
|
}
|
|
},
|
|
Err(e) => Ok(Protocol::err(&e.0)),
|
|
}
|
|
}
|
|
|
|
async fn lpush_cmd(server: &Server, key: &str, elements: &[String]) -> Result<Protocol, DBError> {
|
|
match server.current_storage()?.lpush(key, elements.to_vec()) {
|
|
Ok(len) => Ok(Protocol::SimpleString(len.to_string())),
|
|
Err(e) => Ok(Protocol::err(&e.0)),
|
|
}
|
|
}
|
|
|
|
async fn rpush_cmd(server: &Server, key: &str, elements: &[String]) -> Result<Protocol, DBError> {
|
|
match server.current_storage()?.rpush(key, elements.to_vec()) {
|
|
Ok(len) => Ok(Protocol::SimpleString(len.to_string())),
|
|
Err(e) => Ok(Protocol::err(&e.0)),
|
|
}
|
|
}
|
|
|
|
async fn exec_cmd(server: &mut Server) -> Result<Protocol, DBError> {
|
|
// Move the queued commands out of `server` so we drop the borrow immediately.
|
|
let cmds = if let Some(cmds) = server.queued_cmd.take() {
|
|
cmds
|
|
} else {
|
|
return Ok(Protocol::err("ERR EXEC without MULTI"));
|
|
};
|
|
|
|
let mut out = Vec::new();
|
|
for (cmd, _) in cmds {
|
|
// Use Box::pin to handle recursion in async function
|
|
let res = Box::pin(cmd.run(server)).await?;
|
|
out.push(res);
|
|
}
|
|
Ok(Protocol::Array(out))
|
|
}
|
|
|
|
async fn incr_cmd(server: &Server, key: &String) -> Result<Protocol, DBError> {
|
|
let storage = server.current_storage()?;
|
|
let current_value = storage.get(key)?;
|
|
|
|
let new_value = match current_value {
|
|
Some(v) => {
|
|
match v.parse::<i64>() {
|
|
Ok(num) => num + 1,
|
|
Err(_) => return Ok(Protocol::err("ERR value is not an integer or out of range")),
|
|
}
|
|
}
|
|
None => 1,
|
|
};
|
|
|
|
storage.set(key.clone(), new_value.to_string())?;
|
|
Ok(Protocol::SimpleString(new_value.to_string()))
|
|
}
|
|
|
|
fn config_get_cmd(name: &String, server: &Server) -> Result<Protocol, DBError> {
|
|
let value = match name.as_str() {
|
|
"dir" => Some(server.option.dir.clone()),
|
|
"dbfilename" => Some(format!("{}.db", server.selected_db)),
|
|
"databases" => Some("16".to_string()), // Hardcoded as per original logic
|
|
_ => None,
|
|
};
|
|
|
|
if let Some(val) = value {
|
|
Ok(Protocol::Array(vec![
|
|
Protocol::BulkString(name.clone()),
|
|
Protocol::BulkString(val),
|
|
]))
|
|
} else {
|
|
// Return an empty array for unknown config options, which is standard Redis behavior
|
|
Ok(Protocol::Array(vec![]))
|
|
}
|
|
}
|
|
|
|
async fn keys_cmd(server: &Server) -> Result<Protocol, DBError> {
|
|
let keys = server.current_storage()?.keys("*")?;
|
|
Ok(Protocol::Array(
|
|
keys.into_iter().map(Protocol::BulkString).collect(),
|
|
))
|
|
}
|
|
|
|
#[derive(Serialize)]
|
|
struct ServerInfo {
|
|
redis_version: String,
|
|
encrypted: bool,
|
|
selected_db: u64,
|
|
}
|
|
|
|
async fn info_cmd(server: &Server, section: &Option<String>) -> Result<Protocol, DBError> {
|
|
let info = ServerInfo {
|
|
redis_version: "7.0.0".to_string(),
|
|
encrypted: server.current_storage()?.is_encrypted(),
|
|
selected_db: server.selected_db,
|
|
};
|
|
|
|
let mut info_string = String::new();
|
|
info_string.push_str(&format!("# Server\n"));
|
|
info_string.push_str(&format!("redis_version:{}\n", info.redis_version));
|
|
info_string.push_str(&format!("encrypted:{}\n", if info.encrypted { 1 } else { 0 }));
|
|
info_string.push_str(&format!("# Keyspace\n"));
|
|
info_string.push_str(&format!("db{}:keys=0,expires=0,avg_ttl=0\n", info.selected_db));
|
|
|
|
|
|
match section {
|
|
Some(s) => match s.as_str() {
|
|
"replication" => Ok(Protocol::BulkString(
|
|
"role:master\nmaster_replid:8371b4fb1155b71f4a04d3e1bc3e18c4a990aeea\nmaster_repl_offset:0\n".to_string()
|
|
)),
|
|
_ => Err(DBError(format!("unsupported section {:?}", s))),
|
|
},
|
|
None => {
|
|
Ok(Protocol::BulkString(info_string))
|
|
}
|
|
}
|
|
}
|
|
|
|
async fn type_cmd(server: &Server, k: &String) -> Result<Protocol, DBError> {
|
|
match server.current_storage()?.get_key_type(k)? {
|
|
Some(type_str) => Ok(Protocol::SimpleString(type_str)),
|
|
None => Ok(Protocol::SimpleString("none".to_string())),
|
|
}
|
|
}
|
|
|
|
async fn del_cmd(server: &Server, k: &str) -> Result<Protocol, DBError> {
|
|
server.current_storage()?.del(k.to_string())?;
|
|
Ok(Protocol::SimpleString("1".to_string()))
|
|
}
|
|
|
|
async fn set_ex_cmd(
|
|
server: &Server,
|
|
k: &str,
|
|
v: &str,
|
|
x: &u128,
|
|
) -> Result<Protocol, DBError> {
|
|
server.current_storage()?.setx(k.to_string(), v.to_string(), *x * 1000)?;
|
|
Ok(Protocol::SimpleString("OK".to_string()))
|
|
}
|
|
|
|
async fn set_px_cmd(
|
|
server: &Server,
|
|
k: &str,
|
|
v: &str,
|
|
x: &u128,
|
|
) -> Result<Protocol, DBError> {
|
|
server.current_storage()?.setx(k.to_string(), v.to_string(), *x)?;
|
|
Ok(Protocol::SimpleString("OK".to_string()))
|
|
}
|
|
|
|
async fn set_cmd(server: &Server, k: &str, v: &str) -> Result<Protocol, DBError> {
|
|
server.current_storage()?.set(k.to_string(), v.to_string())?;
|
|
Ok(Protocol::SimpleString("OK".to_string()))
|
|
}
|
|
|
|
async fn get_cmd(server: &Server, k: &str) -> Result<Protocol, DBError> {
|
|
let v = server.current_storage()?.get(k)?;
|
|
Ok(v.map_or(Protocol::Null, Protocol::BulkString))
|
|
}
|
|
|
|
// Hash command implementations
|
|
async fn hset_cmd(server: &Server, key: &str, pairs: &[(String, String)]) -> Result<Protocol, DBError> {
|
|
let new_fields = server.current_storage()?.hset(key, pairs.to_vec())?;
|
|
Ok(Protocol::SimpleString(new_fields.to_string()))
|
|
}
|
|
|
|
async fn hget_cmd(server: &Server, key: &str, field: &str) -> Result<Protocol, DBError> {
|
|
match server.current_storage()?.hget(key, field) {
|
|
Ok(Some(value)) => Ok(Protocol::BulkString(value)),
|
|
Ok(None) => Ok(Protocol::Null),
|
|
Err(e) => Ok(Protocol::err(&e.0)),
|
|
}
|
|
}
|
|
|
|
async fn hgetall_cmd(server: &Server, key: &str) -> Result<Protocol, DBError> {
|
|
match server.current_storage()?.hgetall(key) {
|
|
Ok(pairs) => {
|
|
let mut result = Vec::new();
|
|
for (field, value) in pairs {
|
|
result.push(Protocol::BulkString(field));
|
|
result.push(Protocol::BulkString(value));
|
|
}
|
|
Ok(Protocol::Array(result))
|
|
}
|
|
Err(e) => Ok(Protocol::err(&e.0)),
|
|
}
|
|
}
|
|
|
|
async fn hdel_cmd(server: &Server, key: &str, fields: &[String]) -> Result<Protocol, DBError> {
|
|
match server.current_storage()?.hdel(key, fields.to_vec()) {
|
|
Ok(deleted) => Ok(Protocol::SimpleString(deleted.to_string())),
|
|
Err(e) => Ok(Protocol::err(&e.0)),
|
|
}
|
|
}
|
|
|
|
async fn hexists_cmd(server: &Server, key: &str, field: &str) -> Result<Protocol, DBError> {
|
|
match server.current_storage()?.hexists(key, field) {
|
|
Ok(exists) => Ok(Protocol::SimpleString(if exists { "1" } else { "0" }.to_string())),
|
|
Err(e) => Ok(Protocol::err(&e.0)),
|
|
}
|
|
}
|
|
|
|
async fn hkeys_cmd(server: &Server, key: &str) -> Result<Protocol, DBError> {
|
|
match server.current_storage()?.hkeys(key) {
|
|
Ok(keys) => Ok(Protocol::Array(
|
|
keys.into_iter().map(Protocol::BulkString).collect(),
|
|
)),
|
|
Err(e) => Ok(Protocol::err(&e.0)),
|
|
}
|
|
}
|
|
|
|
async fn hvals_cmd(server: &Server, key: &str) -> Result<Protocol, DBError> {
|
|
match server.current_storage()?.hvals(key) {
|
|
Ok(values) => Ok(Protocol::Array(
|
|
values.into_iter().map(Protocol::BulkString).collect(),
|
|
)),
|
|
Err(e) => Ok(Protocol::err(&e.0)),
|
|
}
|
|
}
|
|
|
|
async fn hlen_cmd(server: &Server, key: &str) -> Result<Protocol, DBError> {
|
|
match server.current_storage()?.hlen(key) {
|
|
Ok(len) => Ok(Protocol::SimpleString(len.to_string())),
|
|
Err(e) => Ok(Protocol::err(&e.0)),
|
|
}
|
|
}
|
|
|
|
async fn hmget_cmd(server: &Server, key: &str, fields: &[String]) -> Result<Protocol, DBError> {
|
|
match server.current_storage()?.hmget(key, fields.to_vec()) {
|
|
Ok(values) => {
|
|
let result: Vec<Protocol> = values
|
|
.into_iter()
|
|
.map(|v| v.map_or(Protocol::Null, Protocol::BulkString))
|
|
.collect();
|
|
Ok(Protocol::Array(result))
|
|
}
|
|
Err(e) => Ok(Protocol::err(&e.0)),
|
|
}
|
|
}
|
|
|
|
async fn hsetnx_cmd(server: &Server, key: &str, field: &str, value: &str) -> Result<Protocol, DBError> {
|
|
match server.current_storage()?.hsetnx(key, field, value) {
|
|
Ok(was_set) => Ok(Protocol::SimpleString(if was_set { "1" } else { "0" }.to_string())),
|
|
Err(e) => Ok(Protocol::err(&e.0)),
|
|
}
|
|
}
|
|
|
|
async fn scan_cmd(
|
|
server: &Server,
|
|
cursor: &u64,
|
|
pattern: Option<&str>,
|
|
count: &Option<u64>
|
|
) -> Result<Protocol, DBError> {
|
|
match server.current_storage()?.scan(*cursor, pattern, *count) {
|
|
Ok((next_cursor, key_value_pairs)) => {
|
|
let mut result = Vec::new();
|
|
result.push(Protocol::BulkString(next_cursor.to_string()));
|
|
// For SCAN, we only return the keys, not the values
|
|
let keys: Vec<Protocol> = key_value_pairs.into_iter().map(|(key, _)| Protocol::BulkString(key)).collect();
|
|
result.push(Protocol::Array(keys));
|
|
Ok(Protocol::Array(result))
|
|
}
|
|
Err(e) => Ok(Protocol::err(&format!("ERR {}", e.0))),
|
|
}
|
|
}
|
|
|
|
async fn hscan_cmd(
|
|
server: &Server,
|
|
key: &str,
|
|
cursor: &u64,
|
|
pattern: Option<&str>,
|
|
count: &Option<u64>
|
|
) -> Result<Protocol, DBError> {
|
|
match server.current_storage()?.hscan(key, *cursor, pattern, *count) {
|
|
Ok((next_cursor, field_value_pairs)) => {
|
|
let mut result = Vec::new();
|
|
result.push(Protocol::BulkString(next_cursor.to_string()));
|
|
// For HSCAN, we return field-value pairs flattened
|
|
let mut fields_and_values = Vec::new();
|
|
for (field, value) in field_value_pairs {
|
|
fields_and_values.push(Protocol::BulkString(field));
|
|
fields_and_values.push(Protocol::BulkString(value));
|
|
}
|
|
result.push(Protocol::Array(fields_and_values));
|
|
Ok(Protocol::Array(result))
|
|
}
|
|
Err(e) => Ok(Protocol::err(&format!("ERR {}", e.0))),
|
|
}
|
|
}
|
|
|
|
async fn ttl_cmd(server: &Server, key: &str) -> Result<Protocol, DBError> {
|
|
match server.current_storage()?.ttl(key) {
|
|
Ok(ttl) => Ok(Protocol::SimpleString(ttl.to_string())),
|
|
Err(e) => Ok(Protocol::err(&e.0)),
|
|
}
|
|
}
|
|
|
|
async fn exists_cmd(server: &Server, key: &str) -> Result<Protocol, DBError> {
|
|
match server.current_storage()?.exists(key) {
|
|
Ok(exists) => Ok(Protocol::SimpleString(if exists { "1" } else { "0" }.to_string())),
|
|
Err(e) => Ok(Protocol::err(&e.0)),
|
|
}
|
|
}
|
|
|
|
async fn client_setname_cmd(server: &mut Server, name: &str) -> Result<Protocol, DBError> {
|
|
server.client_name = Some(name.to_string());
|
|
Ok(Protocol::SimpleString("OK".to_string()))
|
|
}
|
|
|
|
async fn client_getname_cmd(server: &Server) -> Result<Protocol, DBError> {
|
|
match &server.client_name {
|
|
Some(name) => Ok(Protocol::BulkString(name.clone())),
|
|
None => Ok(Protocol::Null),
|
|
}
|
|
}
|