This commit is contained in:
2025-08-16 10:28:28 +02:00
parent 246304b9fa
commit 5eab3b080c
3 changed files with 54 additions and 33 deletions

View File

@@ -610,18 +610,20 @@ async fn rpush_cmd(server: &Server, key: &str, elements: &[String]) -> Result<Pr
} }
async fn exec_cmd(server: &mut Server) -> Result<Protocol, DBError> { async fn exec_cmd(server: &mut Server) -> Result<Protocol, DBError> {
// Move the queued commands out of `server` so we drop the borrow immediately.
let cmds = if let Some(cmds) = server.queued_cmd.take() { let cmds = if let Some(cmds) = server.queued_cmd.take() {
cmds cmds
} else { } else {
return Ok(Protocol::err("ERR EXEC without MULTI")); return Ok(Protocol::err("ERR EXEC without MULTI"));
}; };
let mut vec = Vec::new(); let mut out = Vec::new();
for (cmd, _) in cmds { for (cmd, _) in cmds {
let res = cmd.run(server).await?; // Use Box::pin to handle recursion in async function
vec.push(res); let res = Box::pin(cmd.run(server)).await?;
out.push(res);
} }
Ok(Protocol::Array(vec)) Ok(Protocol::Array(out))
} }
async fn incr_cmd(server: &Server, key: &String) -> Result<Protocol, DBError> { async fn incr_cmd(server: &Server, key: &String) -> Result<Protocol, DBError> {
@@ -835,26 +837,39 @@ async fn hsetnx_cmd(server: &Server, key: &str, field: &str, value: &str) -> Res
} }
} }
async fn scan_cmd(server: &Server, cursor: &u64, pattern: Option<&str>, count: &Option<u64>) -> Result<Protocol, DBError> { async fn scan_cmd(
server.current_storage()?.scan(*cursor, pattern, *count).map(|(next_cursor, keys)| { server: &Server,
let mut result = Vec::new(); cursor: &u64,
result.push(Protocol::BulkString(next_cursor.to_string())); pattern: Option<&str>,
result.push(Protocol::Array( count: &Option<u64>
keys.into_iter().map(Protocol::BulkString).collect(), ) -> Result<Protocol, DBError> {
)); match server.current_storage()?.scan(*cursor, pattern, *count) {
Protocol::Array(result) Ok((next_cursor, keys)) => {
}) let mut result = Vec::new();
result.push(Protocol::BulkString(next_cursor.to_string()));
result.push(Protocol::Array(keys.into_iter().map(Protocol::BulkString).collect()));
Ok(Protocol::Array(result))
}
Err(e) => Ok(Protocol::err(&format!("ERR {}", e.0))),
}
} }
async fn hscan_cmd(server: &Server, key: &str, cursor: &u64, pattern: Option<&str>, count: &Option<u64>) -> Result<Protocol, DBError> { async fn hscan_cmd(
server.current_storage()?.hscan(key, *cursor, pattern, *count).map(|(next_cursor, fields)| { server: &Server,
let mut result = Vec::new(); key: &str,
result.push(Protocol::BulkString(next_cursor.to_string())); cursor: &u64,
result.push(Protocol::Array( pattern: Option<&str>,
fields.into_iter().map(Protocol::BulkString).collect(), count: &Option<u64>
)); ) -> Result<Protocol, DBError> {
Protocol::Array(result) match server.current_storage()?.hscan(key, *cursor, pattern, *count) {
}) Ok((next_cursor, fields)) => {
let mut result = Vec::new();
result.push(Protocol::BulkString(next_cursor.to_string()));
result.push(Protocol::Array(fields.into_iter().map(Protocol::BulkString).collect()));
Ok(Protocol::Array(result))
}
Err(e) => Ok(Protocol::err(&format!("ERR {}", e.0))),
}
} }
async fn ttl_cmd(server: &Server, key: &str) -> Result<Protocol, DBError> { async fn ttl_cmd(server: &Server, key: &str) -> Result<Protocol, DBError> {

View File

@@ -8,6 +8,7 @@ pub enum Protocol {
BulkString(String), BulkString(String),
Null, Null,
Array(Vec<Protocol>), Array(Vec<Protocol>),
Error(String), // NEW
} }
impl fmt::Display for Protocol { impl fmt::Display for Protocol {
@@ -45,7 +46,7 @@ impl Protocol {
#[inline] #[inline]
pub fn err(msg: &str) -> Self { pub fn err(msg: &str) -> Self {
Protocol::SimpleString(msg.to_string()) Protocol::Error(msg.to_string())
} }
#[inline] #[inline]
@@ -69,22 +70,19 @@ impl Protocol {
Protocol::BulkString(s) => s.to_string(), Protocol::BulkString(s) => s.to_string(),
Protocol::Null => "".to_string(), Protocol::Null => "".to_string(),
Protocol::Array(s) => s.iter().map(|x| x.decode()).collect::<Vec<_>>().join(" "), Protocol::Array(s) => s.iter().map(|x| x.decode()).collect::<Vec<_>>().join(" "),
Protocol::Error(s) => s.to_string(),
} }
} }
pub fn encode(&self) -> String { pub fn encode(&self) -> String {
match self { match self {
Protocol::SimpleString(s) => format!("+{}\r\n", s), Protocol::SimpleString(s) => format!("+{}\r\n", s),
Protocol::BulkString(s) => format!("${}\r\n{}\r\n", s.len(), s), Protocol::BulkString(s) => format!("${}\r\n{}\r\n", s.len(), s),
Protocol::Array(ss) => { Protocol::Array(ss) => {
format!("*{}\r\n", ss.len()) format!("*{}\r\n", ss.len()) + &ss.iter().map(|x| x.encode()).collect::<String>()
+ ss.iter()
.map(|x| x.encode())
.collect::<Vec<_>>()
.join("")
.as_str()
} }
Protocol::Null => "$-1\r\n".to_string(), Protocol::Null => "$-1\r\n".to_string(),
Protocol::Error(s) => format!("-{}\r\n", s), // proper RESP error
} }
} }

View File

@@ -97,7 +97,15 @@ impl Server {
// Check if this is a QUIT command before processing // Check if this is a QUIT command before processing
let is_quit = matches!(cmd, Cmd::Quit); let is_quit = matches!(cmd, Cmd::Quit);
let res = cmd.run(self).await.unwrap_or(Protocol::err("unknown cmd from server")); let res = match cmd.run(self).await {
Ok(p) => p,
Err(e) => {
if self.option.debug {
eprintln!("[run error] {:?}", e);
}
Protocol::err(&format!("ERR {}", e.0))
}
};
if self.option.debug { if self.option.debug {
println!("\x1b[34;1mqueued cmd {:?}\x1b[0m", self.queued_cmd); println!("\x1b[34;1mqueued cmd {:?}\x1b[0m", self.queued_cmd);