diff --git a/src/cmd.rs b/src/cmd.rs index d15441c..9b6bbbc 100644 --- a/src/cmd.rs +++ b/src/cmd.rs @@ -610,18 +610,20 @@ async fn rpush_cmd(server: &Server, key: &str, elements: &[String]) -> Result Result { + // Move the queued commands out of `server` so we drop the borrow immediately. let cmds = if let Some(cmds) = server.queued_cmd.take() { cmds } else { return Ok(Protocol::err("ERR EXEC without MULTI")); }; - let mut vec = Vec::new(); + let mut out = Vec::new(); for (cmd, _) in cmds { - let res = cmd.run(server).await?; - vec.push(res); + // Use Box::pin to handle recursion in async function + let res = Box::pin(cmd.run(server)).await?; + out.push(res); } - Ok(Protocol::Array(vec)) + Ok(Protocol::Array(out)) } async fn incr_cmd(server: &Server, key: &String) -> Result { @@ -835,26 +837,39 @@ async fn hsetnx_cmd(server: &Server, key: &str, field: &str, value: &str) -> Res } } -async fn scan_cmd(server: &Server, cursor: &u64, pattern: Option<&str>, count: &Option) -> Result { - server.current_storage()?.scan(*cursor, pattern, *count).map(|(next_cursor, keys)| { - let mut result = Vec::new(); - result.push(Protocol::BulkString(next_cursor.to_string())); - result.push(Protocol::Array( - keys.into_iter().map(Protocol::BulkString).collect(), - )); - Protocol::Array(result) - }) +async fn scan_cmd( + server: &Server, + cursor: &u64, + pattern: Option<&str>, + count: &Option +) -> Result { + match server.current_storage()?.scan(*cursor, pattern, *count) { + Ok((next_cursor, keys)) => { + let mut result = Vec::new(); + result.push(Protocol::BulkString(next_cursor.to_string())); + result.push(Protocol::Array(keys.into_iter().map(Protocol::BulkString).collect())); + Ok(Protocol::Array(result)) + } + Err(e) => Ok(Protocol::err(&format!("ERR {}", e.0))), + } } -async fn hscan_cmd(server: &Server, key: &str, cursor: &u64, pattern: Option<&str>, count: &Option) -> Result { - server.current_storage()?.hscan(key, *cursor, pattern, *count).map(|(next_cursor, fields)| { - let mut result = Vec::new(); - result.push(Protocol::BulkString(next_cursor.to_string())); - result.push(Protocol::Array( - fields.into_iter().map(Protocol::BulkString).collect(), - )); - Protocol::Array(result) - }) +async fn hscan_cmd( + server: &Server, + key: &str, + cursor: &u64, + pattern: Option<&str>, + count: &Option +) -> Result { + match server.current_storage()?.hscan(key, *cursor, pattern, *count) { + Ok((next_cursor, fields)) => { + let mut result = Vec::new(); + result.push(Protocol::BulkString(next_cursor.to_string())); + result.push(Protocol::Array(fields.into_iter().map(Protocol::BulkString).collect())); + Ok(Protocol::Array(result)) + } + Err(e) => Ok(Protocol::err(&format!("ERR {}", e.0))), + } } async fn ttl_cmd(server: &Server, key: &str) -> Result { diff --git a/src/protocol.rs b/src/protocol.rs index ad42309..c9a2255 100644 --- a/src/protocol.rs +++ b/src/protocol.rs @@ -8,6 +8,7 @@ pub enum Protocol { BulkString(String), Null, Array(Vec), + Error(String), // NEW } impl fmt::Display for Protocol { @@ -45,7 +46,7 @@ impl Protocol { #[inline] pub fn err(msg: &str) -> Self { - Protocol::SimpleString(msg.to_string()) + Protocol::Error(msg.to_string()) } #[inline] @@ -69,22 +70,19 @@ impl Protocol { Protocol::BulkString(s) => s.to_string(), Protocol::Null => "".to_string(), Protocol::Array(s) => s.iter().map(|x| x.decode()).collect::>().join(" "), + Protocol::Error(s) => s.to_string(), } } pub fn encode(&self) -> String { match self { Protocol::SimpleString(s) => format!("+{}\r\n", s), - Protocol::BulkString(s) => format!("${}\r\n{}\r\n", s.len(), s), - Protocol::Array(ss) => { - format!("*{}\r\n", ss.len()) - + ss.iter() - .map(|x| x.encode()) - .collect::>() - .join("") - .as_str() + Protocol::BulkString(s) => format!("${}\r\n{}\r\n", s.len(), s), + Protocol::Array(ss) => { + format!("*{}\r\n", ss.len()) + &ss.iter().map(|x| x.encode()).collect::() } - Protocol::Null => "$-1\r\n".to_string(), + Protocol::Null => "$-1\r\n".to_string(), + Protocol::Error(s) => format!("-{}\r\n", s), // proper RESP error } } diff --git a/src/server.rs b/src/server.rs index 6a389af..8d294b0 100644 --- a/src/server.rs +++ b/src/server.rs @@ -97,7 +97,15 @@ impl Server { // Check if this is a QUIT command before processing let is_quit = matches!(cmd, Cmd::Quit); - let res = cmd.run(self).await.unwrap_or(Protocol::err("unknown cmd from server")); + let res = match cmd.run(self).await { + Ok(p) => p, + Err(e) => { + if self.option.debug { + eprintln!("[run error] {:?}", e); + } + Protocol::err(&format!("ERR {}", e.0)) + } + }; if self.option.debug { println!("\x1b[34;1mqueued cmd {:?}\x1b[0m", self.queued_cmd);