diff --git a/src/cmd.rs b/src/cmd.rs index bb98e38..d15441c 100644 --- a/src/cmd.rs +++ b/src/cmd.rs @@ -415,81 +415,74 @@ impl Cmd { } } - pub async fn run( - &self, - server: &mut Server, - protocol: Protocol, - queued_cmd: &mut Option>, - ) -> Result { + pub async fn run(self, server: &mut Server) -> Result { // Handle queued commands for transactions - if queued_cmd.is_some() + if server.queued_cmd.is_some() && !matches!(self, Cmd::Exec) && !matches!(self, Cmd::Multi) && !matches!(self, Cmd::Discard) { - queued_cmd - .as_mut() - .unwrap() - .push((self.clone(), protocol.clone())); + let protocol = self.clone().to_protocol(); + server.queued_cmd.as_mut().unwrap().push((self, protocol)); return Ok(Protocol::SimpleString("QUEUED".to_string())); } match self { - Cmd::Select(db) => select_cmd(server, *db).await, + Cmd::Select(db) => select_cmd(server, db).await, Cmd::Ping => Ok(Protocol::SimpleString("PONG".to_string())), - Cmd::Echo(s) => Ok(Protocol::SimpleString(s.clone())), - Cmd::Get(k) => get_cmd(server, k).await, - Cmd::Set(k, v) => set_cmd(server, k, v).await, - Cmd::SetPx(k, v, x) => set_px_cmd(server, k, v, x).await, - Cmd::SetEx(k, v, x) => set_ex_cmd(server, k, v, x).await, - Cmd::Del(k) => del_cmd(server, k).await, - Cmd::ConfigGet(name) => config_get_cmd(name, server), + Cmd::Echo(s) => Ok(Protocol::BulkString(s)), + Cmd::Get(k) => get_cmd(server, &k).await, + Cmd::Set(k, v) => set_cmd(server, &k, &v).await, + Cmd::SetPx(k, v, x) => set_px_cmd(server, &k, &v, &x).await, + Cmd::SetEx(k, v, x) => set_ex_cmd(server, &k, &v, &x).await, + Cmd::Del(k) => del_cmd(server, &k).await, + Cmd::ConfigGet(name) => config_get_cmd(&name, server), Cmd::Keys => keys_cmd(server).await, - Cmd::Info(section) => info_cmd(server, section).await, - Cmd::Type(k) => type_cmd(server, k).await, - Cmd::Incr(key) => incr_cmd(server, key).await, + Cmd::Info(section) => info_cmd(server, §ion).await, + Cmd::Type(k) => type_cmd(server, &k).await, + Cmd::Incr(key) => incr_cmd(server, &key).await, Cmd::Multi => { - *queued_cmd = Some(Vec::<(Cmd, Protocol)>::new()); + server.queued_cmd = Some(Vec::<(Cmd, Protocol)>::new()); Ok(Protocol::SimpleString("OK".to_string())) } - Cmd::Exec => exec_cmd(queued_cmd, server).await, + Cmd::Exec => exec_cmd(server).await, Cmd::Discard => { - if queued_cmd.is_some() { - *queued_cmd = None; + if server.queued_cmd.is_some() { + server.queued_cmd = None; Ok(Protocol::SimpleString("OK".to_string())) } else { Ok(Protocol::err("ERR DISCARD without MULTI")) } } // Hash commands - Cmd::HSet(key, pairs) => hset_cmd(server, key, pairs).await, - Cmd::HGet(key, field) => hget_cmd(server, key, field).await, - Cmd::HGetAll(key) => hgetall_cmd(server, key).await, - Cmd::HDel(key, fields) => hdel_cmd(server, key, fields).await, - Cmd::HExists(key, field) => hexists_cmd(server, key, field).await, - Cmd::HKeys(key) => hkeys_cmd(server, key).await, - Cmd::HVals(key) => hvals_cmd(server, key).await, - Cmd::HLen(key) => hlen_cmd(server, key).await, - Cmd::HMGet(key, fields) => hmget_cmd(server, key, fields).await, - Cmd::HSetNx(key, field, value) => hsetnx_cmd(server, key, field, value).await, - Cmd::HScan(key, cursor, pattern, count) => hscan_cmd(server, key, cursor, pattern.as_deref(), count).await, - Cmd::Scan(cursor, pattern, count) => scan_cmd(server, cursor, pattern.as_deref(), count).await, - Cmd::Ttl(key) => ttl_cmd(server, key).await, - Cmd::Exists(key) => exists_cmd(server, key).await, + Cmd::HSet(key, pairs) => hset_cmd(server, &key, &pairs).await, + Cmd::HGet(key, field) => hget_cmd(server, &key, &field).await, + Cmd::HGetAll(key) => hgetall_cmd(server, &key).await, + Cmd::HDel(key, fields) => hdel_cmd(server, &key, &fields).await, + Cmd::HExists(key, field) => hexists_cmd(server, &key, &field).await, + Cmd::HKeys(key) => hkeys_cmd(server, &key).await, + Cmd::HVals(key) => hvals_cmd(server, &key).await, + Cmd::HLen(key) => hlen_cmd(server, &key).await, + Cmd::HMGet(key, fields) => hmget_cmd(server, &key, &fields).await, + Cmd::HSetNx(key, field, value) => hsetnx_cmd(server, &key, &field, &value).await, + Cmd::HScan(key, cursor, pattern, count) => hscan_cmd(server, &key, &cursor, pattern.as_deref(), &count).await, + Cmd::Scan(cursor, pattern, count) => scan_cmd(server, &cursor, pattern.as_deref(), &count).await, + Cmd::Ttl(key) => ttl_cmd(server, &key).await, + Cmd::Exists(key) => exists_cmd(server, &key).await, Cmd::Quit => Ok(Protocol::SimpleString("OK".to_string())), Cmd::Client(_) => Ok(Protocol::SimpleString("OK".to_string())), - Cmd::ClientSetName(name) => client_setname_cmd(server, name).await, + Cmd::ClientSetName(name) => client_setname_cmd(server, &name).await, Cmd::ClientGetName => client_getname_cmd(server).await, // List commands - Cmd::LPush(key, elements) => lpush_cmd(server, key, elements).await, - Cmd::RPush(key, elements) => rpush_cmd(server, key, elements).await, - Cmd::LPop(key, count) => lpop_cmd(server, key, count).await, - Cmd::RPop(key, count) => rpop_cmd(server, key, count).await, - Cmd::LLen(key) => llen_cmd(server, key).await, - Cmd::LRem(key, count, element) => lrem_cmd(server, key, *count, element).await, - Cmd::LTrim(key, start, stop) => ltrim_cmd(server, key, *start, *stop).await, - Cmd::LIndex(key, index) => lindex_cmd(server, key, *index).await, - Cmd::LRange(key, start, stop) => lrange_cmd(server, key, *start, *stop).await, + Cmd::LPush(key, elements) => lpush_cmd(server, &key, &elements).await, + Cmd::RPush(key, elements) => rpush_cmd(server, &key, &elements).await, + Cmd::LPop(key, count) => lpop_cmd(server, &key, &count).await, + Cmd::RPop(key, count) => rpop_cmd(server, &key, &count).await, + Cmd::LLen(key) => llen_cmd(server, &key).await, + Cmd::LRem(key, count, element) => lrem_cmd(server, &key, count, &element).await, + Cmd::LTrim(key, start, stop) => ltrim_cmd(server, &key, start, stop).await, + Cmd::LIndex(key, index) => lindex_cmd(server, &key, index).await, + Cmd::LRange(key, start, stop) => lrange_cmd(server, &key, start, stop).await, Cmd::FlushDb => flushdb_cmd(server).await, Cmd::Unknow(s) => { println!("\x1b[31;1munknown command: {}\x1b[0m", s); @@ -497,6 +490,17 @@ impl Cmd { } } } + + pub fn to_protocol(self) -> Protocol { + match self { + Cmd::Select(db) => Protocol::Array(vec![Protocol::BulkString("select".to_string()), Protocol::BulkString(db.to_string())]), + Cmd::Ping => Protocol::Array(vec![Protocol::BulkString("ping".to_string())]), + Cmd::Echo(s) => Protocol::Array(vec![Protocol::BulkString("echo".to_string()), Protocol::BulkString(s)]), + Cmd::Get(k) => Protocol::Array(vec![Protocol::BulkString("get".to_string()), Protocol::BulkString(k)]), + Cmd::Set(k, v) => Protocol::Array(vec![Protocol::BulkString("set".to_string()), Protocol::BulkString(k), Protocol::BulkString(v)]), + _ => Protocol::SimpleString("...".to_string()) + } + } } async fn flushdb_cmd(server: &mut Server) -> Result { @@ -605,21 +609,19 @@ async fn rpush_cmd(server: &Server, key: &str, elements: &[String]) -> Result>, - server: &mut Server, -) -> Result { - if queued_cmd.is_some() { - let mut vec = Vec::new(); - for (cmd, protocol) in queued_cmd.as_ref().unwrap() { - let res = Box::pin(cmd.run(server, protocol.clone(), &mut None)).await?; - vec.push(res); - } - *queued_cmd = None; - Ok(Protocol::Array(vec)) +async fn exec_cmd(server: &mut Server) -> Result { + let cmds = if let Some(cmds) = server.queued_cmd.take() { + cmds } else { - Ok(Protocol::err("ERR EXEC without MULTI")) + return Ok(Protocol::err("ERR EXEC without MULTI")); + }; + + let mut vec = Vec::new(); + for (cmd, _) in cmds { + let res = cmd.run(server).await?; + vec.push(res); } + Ok(Protocol::Array(vec)) } async fn incr_cmd(server: &Server, key: &String) -> Result { @@ -834,31 +836,25 @@ async fn hsetnx_cmd(server: &Server, key: &str, field: &str, value: &str) -> Res } async fn scan_cmd(server: &Server, cursor: &u64, pattern: Option<&str>, count: &Option) -> Result { - match server.current_storage()?.scan(*cursor, pattern, *count) { - Ok((next_cursor, keys)) => { - let mut result = Vec::new(); - result.push(Protocol::BulkString(next_cursor.to_string())); - result.push(Protocol::Array( - keys.into_iter().map(Protocol::BulkString).collect(), - )); - Ok(Protocol::Array(result)) - } - Err(e) => Ok(Protocol::err(&e.0)), - } + server.current_storage()?.scan(*cursor, pattern, *count).map(|(next_cursor, keys)| { + let mut result = Vec::new(); + result.push(Protocol::BulkString(next_cursor.to_string())); + result.push(Protocol::Array( + keys.into_iter().map(Protocol::BulkString).collect(), + )); + Protocol::Array(result) + }) } async fn hscan_cmd(server: &Server, key: &str, cursor: &u64, pattern: Option<&str>, count: &Option) -> Result { - match server.current_storage()?.hscan(key, *cursor, pattern, *count) { - Ok((next_cursor, fields)) => { - let mut result = Vec::new(); - result.push(Protocol::BulkString(next_cursor.to_string())); - result.push(Protocol::Array( - fields.into_iter().map(Protocol::BulkString).collect(), - )); - Ok(Protocol::Array(result)) - } - Err(e) => Ok(Protocol::err(&e.0)), - } + server.current_storage()?.hscan(key, *cursor, pattern, *count).map(|(next_cursor, fields)| { + let mut result = Vec::new(); + result.push(Protocol::BulkString(next_cursor.to_string())); + result.push(Protocol::Array( + fields.into_iter().map(Protocol::BulkString).collect(), + )); + Protocol::Array(result) + }) } async fn ttl_cmd(server: &Server, key: &str) -> Result { diff --git a/src/server.rs b/src/server.rs index 036b82b..6a389af 100644 --- a/src/server.rs +++ b/src/server.rs @@ -16,6 +16,7 @@ pub struct Server { pub option: options::DBOption, pub client_name: Option, pub selected_db: u64, // Changed from usize to u64 + pub queued_cmd: Option>, } impl Server { @@ -25,6 +26,7 @@ impl Server { option, client_name: None, selected_db: 0, + queued_cmd: None, } } @@ -61,7 +63,6 @@ impl Server { mut stream: tokio::net::TcpStream, ) -> Result<(), DBError> { let mut buf = [0; 512]; - let mut queued_cmd: Option> = None; loop { let len = match stream.read(&mut buf).await { @@ -96,16 +97,13 @@ impl Server { // Check if this is a QUIT command before processing let is_quit = matches!(cmd, Cmd::Quit); - let res = cmd - .run(&mut self.clone(), protocol.clone(), &mut queued_cmd) - .await - .unwrap_or(Protocol::err("unknown cmd from server")); + let res = cmd.run(self).await.unwrap_or(Protocol::err("unknown cmd from server")); if self.option.debug { - println!("\x1b[34;1mqueued cmd {:?}\x1b[0m", queued_cmd); + println!("\x1b[34;1mqueued cmd {:?}\x1b[0m", self.queued_cmd); println!("\x1b[32;1mgoing to send response {}\x1b[0m", res.encode()); } else { - print!("queued cmd {:?}", queued_cmd); + print!("queued cmd {:?}", self.queued_cmd); println!("going to send response {}", res.encode()); } diff --git a/src/storage.rs b/src/storage.rs index 0ff2d6c..36f0c43 100644 --- a/src/storage.rs +++ b/src/storage.rs @@ -736,9 +736,9 @@ impl Storage { current_cursor += 1; } - // If we've reached the end of iteration, return cursor 0 to indicate completion - let next_cursor = if iter.next().is_none() { 0 } else { current_cursor }; - + // If we've reached the end of the iteration, return cursor 0, otherwise return the next cursor position + let next_cursor = if returned_keys < count { 0 } else { current_cursor }; + Ok((next_cursor, keys)) }