...
This commit is contained in:
		
							
								
								
									
										188
									
								
								src/cmd.rs
									
									
									
									
									
								
							
							
						
						
									
										188
									
								
								src/cmd.rs
									
									
									
									
									
								
							| @@ -39,13 +39,20 @@ pub enum Cmd { | ||||
|     // List commands | ||||
|     LPush(String, Vec<String>), | ||||
|     RPush(String, Vec<String>), | ||||
|     LPop(String, Option<u64>), | ||||
|     RPop(String, Option<u64>), | ||||
|     LLen(String), | ||||
|     LRem(String, i64, String), | ||||
|     LTrim(String, i64, i64), | ||||
|     LIndex(String, i64), | ||||
|     LRange(String, i64, i64), | ||||
|     Unknow(String), | ||||
| } | ||||
|  | ||||
| impl Cmd { | ||||
|     pub fn from(s: &str) -> Result<(Self, Protocol), DBError> { | ||||
|         let protocol = Protocol::from(s)?; | ||||
|         match protocol.clone().0 { | ||||
|     pub fn from(s: &str) -> Result<(Self, Protocol, &str), DBError> { | ||||
|         let (protocol, remaining) = Protocol::from(s)?; | ||||
|         match protocol.clone() { | ||||
|             Protocol::Array(p) => { | ||||
|                 let cmd = p.into_iter().map(|x| x.decode()).collect::<Vec<_>>(); | ||||
|                 if cmd.is_empty() { | ||||
| @@ -303,14 +310,85 @@ impl Cmd { | ||||
|                                 Cmd::Client(vec![]) | ||||
|                             } | ||||
|                         } | ||||
|                         "lpush" => { | ||||
|                             if cmd.len() < 3 { | ||||
|                                 return Err(DBError(format!("wrong number of arguments for LPUSH command"))); | ||||
|                             } | ||||
|                             Cmd::LPush(cmd[1].clone(), cmd[2..].to_vec()) | ||||
|                         } | ||||
|                         "rpush" => { | ||||
|                             if cmd.len() < 3 { | ||||
|                                 return Err(DBError(format!("wrong number of arguments for RPUSH command"))); | ||||
|                             } | ||||
|                             Cmd::RPush(cmd[1].clone(), cmd[2..].to_vec()) | ||||
|                         } | ||||
|                         "lpop" => { | ||||
|                             if cmd.len() < 2 || cmd.len() > 3 { | ||||
|                                 return Err(DBError(format!("wrong number of arguments for LPOP command"))); | ||||
|                             } | ||||
|                             let count = if cmd.len() == 3 { | ||||
|                                 Some(cmd[2].parse::<u64>().map_err(|_| DBError("ERR value is not an integer or out of range".to_string()))?) | ||||
|                             } else { | ||||
|                                 None | ||||
|                             }; | ||||
|                             Cmd::LPop(cmd[1].clone(), count) | ||||
|                         } | ||||
|                         "rpop" => { | ||||
|                             if cmd.len() < 2 || cmd.len() > 3 { | ||||
|                                 return Err(DBError(format!("wrong number of arguments for RPOP command"))); | ||||
|                             } | ||||
|                             let count = if cmd.len() == 3 { | ||||
|                                 Some(cmd[2].parse::<u64>().map_err(|_| DBError("ERR value is not an integer or out of range".to_string()))?) | ||||
|                             } else { | ||||
|                                 None | ||||
|                             }; | ||||
|                             Cmd::RPop(cmd[1].clone(), count) | ||||
|                         } | ||||
|                         "llen" => { | ||||
|                             if cmd.len() != 2 { | ||||
|                                 return Err(DBError(format!("wrong number of arguments for LLEN command"))); | ||||
|                             } | ||||
|                             Cmd::LLen(cmd[1].clone()) | ||||
|                         } | ||||
|                         "lrem" => { | ||||
|                             if cmd.len() != 4 { | ||||
|                                 return Err(DBError(format!("wrong number of arguments for LREM command"))); | ||||
|                             } | ||||
|                             let count = cmd[2].parse::<i64>().map_err(|_| DBError("ERR value is not an integer or out of range".to_string()))?; | ||||
|                             Cmd::LRem(cmd[1].clone(), count, cmd[3].clone()) | ||||
|                         } | ||||
|                         "ltrim" => { | ||||
|                             if cmd.len() != 4 { | ||||
|                                 return Err(DBError(format!("wrong number of arguments for LTRIM command"))); | ||||
|                             } | ||||
|                             let start = cmd[2].parse::<i64>().map_err(|_| DBError("ERR value is not an integer or out of range".to_string()))?; | ||||
|                             let stop = cmd[3].parse::<i64>().map_err(|_| DBError("ERR value is not an integer or out of range".to_string()))?; | ||||
|                             Cmd::LTrim(cmd[1].clone(), start, stop) | ||||
|                         } | ||||
|                         "lindex" => { | ||||
|                             if cmd.len() != 3 { | ||||
|                                 return Err(DBError(format!("wrong number of arguments for LINDEX command"))); | ||||
|                             } | ||||
|                             let index = cmd[2].parse::<i64>().map_err(|_| DBError("ERR value is not an integer or out of range".to_string()))?; | ||||
|                             Cmd::LIndex(cmd[1].clone(), index) | ||||
|                         } | ||||
|                         "lrange" => { | ||||
|                             if cmd.len() != 4 { | ||||
|                                 return Err(DBError(format!("wrong number of arguments for LRANGE command"))); | ||||
|                             } | ||||
|                             let start = cmd[2].parse::<i64>().map_err(|_| DBError("ERR value is not an integer or out of range".to_string()))?; | ||||
|                             let stop = cmd[3].parse::<i64>().map_err(|_| DBError("ERR value is not an integer or out of range".to_string()))?; | ||||
|                             Cmd::LRange(cmd[1].clone(), start, stop) | ||||
|                         } | ||||
|                         _ => Cmd::Unknow(cmd[0].clone()), | ||||
|                     }, | ||||
|                     protocol.0, | ||||
|                     protocol, | ||||
|                     remaining | ||||
|                 )) | ||||
|             } | ||||
|             _ => Err(DBError(format!( | ||||
|                 "fail to parse as cmd for {:?}", | ||||
|                 protocol.0 | ||||
|                 protocol | ||||
|             ))), | ||||
|         } | ||||
|     } | ||||
| @@ -379,6 +457,16 @@ impl Cmd { | ||||
|             Cmd::Client(_) => Ok(Protocol::SimpleString("OK".to_string())), | ||||
|             Cmd::ClientSetName(name) => client_setname_cmd(server, name).await, | ||||
|             Cmd::ClientGetName => client_getname_cmd(server).await, | ||||
|             // List commands | ||||
|             Cmd::LPush(key, elements) => lpush_cmd(server, key, elements).await, | ||||
|             Cmd::RPush(key, elements) => rpush_cmd(server, key, elements).await, | ||||
|             Cmd::LPop(key, count) => lpop_cmd(server, key, count).await, | ||||
|             Cmd::RPop(key, count) => rpop_cmd(server, key, count).await, | ||||
|             Cmd::LLen(key) => llen_cmd(server, key).await, | ||||
|             Cmd::LRem(key, count, element) => lrem_cmd(server, key, *count, element).await, | ||||
|             Cmd::LTrim(key, start, stop) => ltrim_cmd(server, key, *start, *stop).await, | ||||
|             Cmd::LIndex(key, index) => lindex_cmd(server, key, *index).await, | ||||
|             Cmd::LRange(key, start, stop) => lrange_cmd(server, key, *start, *stop).await, | ||||
|             Cmd::Unknow(s) => { | ||||
|                 println!("\x1b[31;1munknown command: {}\x1b[0m", s); | ||||
|                 Ok(Protocol::err(&format!("ERR unknown command '{}'", s))) | ||||
| @@ -387,6 +475,96 @@ impl Cmd { | ||||
|     } | ||||
| } | ||||
|  | ||||
| async fn lindex_cmd(server: &Server, key: &str, index: i64) -> Result<Protocol, DBError> { | ||||
|     match server.storage.lindex(key, index) { | ||||
|         Ok(Some(element)) => Ok(Protocol::BulkString(element)), | ||||
|         Ok(None) => Ok(Protocol::Null), | ||||
|         Err(e) => Ok(Protocol::err(&e.0)), | ||||
|     } | ||||
| } | ||||
|  | ||||
| async fn lrange_cmd(server: &Server, key: &str, start: i64, stop: i64) -> Result<Protocol, DBError> { | ||||
|     match server.storage.lrange(key, start, stop) { | ||||
|         Ok(elements) => Ok(Protocol::Array(elements.into_iter().map(Protocol::BulkString).collect())), | ||||
|         Err(e) => Ok(Protocol::err(&e.0)), | ||||
|     } | ||||
| } | ||||
|  | ||||
| async fn ltrim_cmd(server: &Server, key: &str, start: i64, stop: i64) -> Result<Protocol, DBError> { | ||||
|     match server.storage.ltrim(key, start, stop) { | ||||
|         Ok(_) => Ok(Protocol::SimpleString("OK".to_string())), | ||||
|         Err(e) => Ok(Protocol::err(&e.0)), | ||||
|     } | ||||
| } | ||||
|  | ||||
| async fn lrem_cmd(server: &Server, key: &str, count: i64, element: &str) -> Result<Protocol, DBError> { | ||||
|     match server.storage.lrem(key, count, element) { | ||||
|         Ok(removed_count) => Ok(Protocol::SimpleString(removed_count.to_string())), | ||||
|         Err(e) => Ok(Protocol::err(&e.0)), | ||||
|     } | ||||
| } | ||||
|  | ||||
| async fn llen_cmd(server: &Server, key: &str) -> Result<Protocol, DBError> { | ||||
|     match server.storage.llen(key) { | ||||
|         Ok(len) => Ok(Protocol::SimpleString(len.to_string())), | ||||
|         Err(e) => Ok(Protocol::err(&e.0)), | ||||
|     } | ||||
| } | ||||
|  | ||||
| async fn lpop_cmd(server: &Server, key: &str, count: &Option<u64>) -> Result<Protocol, DBError> { | ||||
|     match server.storage.lpop(key, *count) { | ||||
|         Ok(Some(elements)) => { | ||||
|             if count.is_some() { | ||||
|                 Ok(Protocol::Array(elements.into_iter().map(Protocol::BulkString).collect())) | ||||
|             } else { | ||||
|                 Ok(Protocol::BulkString(elements[0].clone())) | ||||
|             } | ||||
|         }, | ||||
|         Ok(None) => { | ||||
|             if count.is_some() { | ||||
|                 Ok(Protocol::Array(vec![])) | ||||
|             } else { | ||||
|                 Ok(Protocol::Null) | ||||
|             } | ||||
|         }, | ||||
|         Err(e) => Ok(Protocol::err(&e.0)), | ||||
|     } | ||||
| } | ||||
|  | ||||
| async fn rpop_cmd(server: &Server, key: &str, count: &Option<u64>) -> Result<Protocol, DBError> { | ||||
|     match server.storage.rpop(key, *count) { | ||||
|         Ok(Some(elements)) => { | ||||
|             if count.is_some() { | ||||
|                 Ok(Protocol::Array(elements.into_iter().map(Protocol::BulkString).collect())) | ||||
|             } else { | ||||
|                 Ok(Protocol::BulkString(elements[0].clone())) | ||||
|             } | ||||
|         }, | ||||
|         Ok(None) => { | ||||
|             if count.is_some() { | ||||
|                 Ok(Protocol::Array(vec![])) | ||||
|             } else { | ||||
|                 Ok(Protocol::Null) | ||||
|             } | ||||
|         }, | ||||
|         Err(e) => Ok(Protocol::err(&e.0)), | ||||
|     } | ||||
| } | ||||
|  | ||||
| async fn lpush_cmd(server: &Server, key: &str, elements: &[String]) -> Result<Protocol, DBError> { | ||||
|     match server.storage.lpush(key, elements.to_vec()) { | ||||
|         Ok(len) => Ok(Protocol::SimpleString(len.to_string())), | ||||
|         Err(e) => Ok(Protocol::err(&e.0)), | ||||
|     } | ||||
| } | ||||
|  | ||||
| async fn rpush_cmd(server: &Server, key: &str, elements: &[String]) -> Result<Protocol, DBError> { | ||||
|     match server.storage.rpush(key, elements.to_vec()) { | ||||
|         Ok(len) => Ok(Protocol::SimpleString(len.to_string())), | ||||
|         Err(e) => Ok(Protocol::err(&e.0)), | ||||
|     } | ||||
| } | ||||
|  | ||||
| async fn exec_cmd( | ||||
|     queued_cmd: &mut Option<Vec<(Cmd, Protocol)>>, | ||||
|     server: &mut Server, | ||||
|   | ||||
| @@ -17,7 +17,7 @@ impl fmt::Display for Protocol { | ||||
| } | ||||
|  | ||||
| impl Protocol { | ||||
|     pub fn from(protocol: &str) -> Result<(Self, usize), DBError> { | ||||
|     pub fn from(protocol: &str) -> Result<(Self, &str), DBError> { | ||||
|         let ret = match protocol.chars().nth(0) { | ||||
|             Some('+') => Self::parse_simple_string_sfx(&protocol[1..]), | ||||
|             Some('$') => Self::parse_bulk_string_sfx(&protocol[1..]), | ||||
| @@ -27,10 +27,7 @@ impl Protocol { | ||||
|                 protocol | ||||
|             ))), | ||||
|         }; | ||||
|         match ret { | ||||
|             Ok((p, s)) => Ok((p, s + 1)), | ||||
|             Err(e) => Err(e), | ||||
|         } | ||||
|         ret | ||||
|     } | ||||
|  | ||||
|     pub fn from_vec(array: Vec<&str>) -> Self { | ||||
| @@ -91,9 +88,9 @@ impl Protocol { | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     fn parse_simple_string_sfx(protocol: &str) -> Result<(Self, usize), DBError> { | ||||
|     fn parse_simple_string_sfx(protocol: &str) -> Result<(Self, &str), DBError> { | ||||
|         match protocol.find("\r\n") { | ||||
|             Some(x) => Ok((Self::SimpleString(protocol[..x].to_string()), x + 2)), | ||||
|             Some(x) => Ok((Self::SimpleString(protocol[..x].to_string()), &protocol[x + 2..])), | ||||
|             _ => Err(DBError(format!( | ||||
|                 "[new simple string] unsupported protocol: {:?}", | ||||
|                 protocol | ||||
| @@ -101,27 +98,20 @@ impl Protocol { | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     fn parse_bulk_string_sfx(protocol: &str) -> Result<(Self, usize), DBError> { | ||||
|         if let Some(len) = protocol.find("\r\n") { | ||||
|             let size = Self::parse_usize(&protocol[..len])?; | ||||
|             if let Some(data_len) = protocol[len + 2..].find("\r\n") { | ||||
|                 let s = Self::parse_string(&protocol[len + 2..len + 2 + data_len])?; | ||||
|                 if size != s.len() { | ||||
|                     Err(DBError(format!( | ||||
|                         "[new bulk string] unmatched string length in prototocl {:?}", | ||||
|                         protocol, | ||||
|                     ))) | ||||
|                 } else { | ||||
|                     Ok(( | ||||
|                         Protocol::BulkString(s.to_lowercase()), | ||||
|                         len + 2 + data_len + 2, | ||||
|                     )) | ||||
|                 } | ||||
|             } else { | ||||
|                 Err(DBError(format!( | ||||
|                     "[new bulk string] unsupported protocol: {:?}", | ||||
|                     protocol | ||||
|     fn parse_bulk_string_sfx(protocol: &str) -> Result<(Self, &str), DBError> { | ||||
|         if let Some(len_end) = protocol.find("\r\n") { | ||||
|             let size = Self::parse_usize(&protocol[..len_end])?; | ||||
|             let data_start = len_end + 2; | ||||
|             let data_end = data_start + size; | ||||
|             let s = Self::parse_string(&protocol[data_start..data_end])?; | ||||
|  | ||||
|             if protocol.len() < data_end + 2 || &protocol[data_end..data_end+2] != "\r\n" { | ||||
|                  Err(DBError(format!( | ||||
|                     "[new bulk string] unmatched string length in prototocl {:?}", | ||||
|                     protocol, | ||||
|                 ))) | ||||
|             } else { | ||||
|                 Ok((Protocol::BulkString(s), &protocol[data_end + 2..])) | ||||
|             } | ||||
|         } else { | ||||
|             Err(DBError(format!( | ||||
| @@ -131,30 +121,22 @@ impl Protocol { | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     fn parse_array_sfx(s: &str) -> Result<(Self, usize), DBError> { | ||||
|         let mut offset = 0; | ||||
|         match s.find("\r\n") { | ||||
|             Some(x) => { | ||||
|                 let array_len = s[..x].parse::<usize>()?; | ||||
|                 offset += x + 2; | ||||
|                 let mut vec = vec![]; | ||||
|                 for _ in 0..array_len { | ||||
|                     match Protocol::from(&s[offset..]) { | ||||
|                         Ok((p, len)) => { | ||||
|                             offset += len; | ||||
|                             vec.push(p); | ||||
|                         } | ||||
|                         Err(e) => { | ||||
|                             return Err(e); | ||||
|                         } | ||||
|                     } | ||||
|                 } | ||||
|                 Ok((Protocol::Array(vec), offset)) | ||||
|     fn parse_array_sfx(s: &str) -> Result<(Self, &str), DBError> { | ||||
|         if let Some(len_end) = s.find("\r\n") { | ||||
|             let array_len = s[..len_end].parse::<usize>()?; | ||||
|             let mut remaining = &s[len_end + 2..]; | ||||
|             let mut vec = vec![]; | ||||
|             for _ in 0..array_len { | ||||
|                 let (p, rem) = Protocol::from(remaining)?; | ||||
|                 vec.push(p); | ||||
|                 remaining = rem; | ||||
|             } | ||||
|             _ => Err(DBError(format!( | ||||
|             Ok((Protocol::Array(vec), remaining)) | ||||
|         } else { | ||||
|             Err(DBError(format!( | ||||
|                 "[new array] unsupported protocol: {:?}", | ||||
|                 s | ||||
|             ))), | ||||
|             ))) | ||||
|         } | ||||
|     } | ||||
|  | ||||
|   | ||||
| @@ -41,20 +41,29 @@ impl Server { | ||||
|         let mut queued_cmd: Option<Vec<(Cmd, Protocol)>> = None; | ||||
|          | ||||
|         loop { | ||||
|             if let Ok(len) = stream.read(&mut buf).await { | ||||
|                 if len == 0 { | ||||
|             let len = match stream.read(&mut buf).await { | ||||
|                 Ok(0) => { | ||||
|                     println!("[handle] connection closed"); | ||||
|                     return Ok(()); | ||||
|                 } | ||||
|                  | ||||
|                 let s = str::from_utf8(&buf[..len])?; | ||||
|                 let (cmd, protocol) = match Cmd::from(s) { | ||||
|                     Ok((cmd, protocol)) => (cmd, protocol), | ||||
|                 Ok(len) => len, | ||||
|                 Err(e) => { | ||||
|                     println!("[handle] read error: {:?}", e); | ||||
|                     return Err(e.into()); | ||||
|                 } | ||||
|             }; | ||||
|  | ||||
|             let mut s = str::from_utf8(&buf[..len])?; | ||||
|             while !s.is_empty() { | ||||
|                 let (cmd, protocol, remaining) = match Cmd::from(s) { | ||||
|                     Ok((cmd, protocol, remaining)) => (cmd, protocol, remaining), | ||||
|                     Err(e) => { | ||||
|                         println!("\x1b[31;1mprotocol error: {:?}\x1b[0m", e); | ||||
|                         (Cmd::Unknow("protocol_error".to_string()), Protocol::err(&format!("protocol error: {}", e.0))) | ||||
|                         (Cmd::Unknow("protocol_error".to_string()), Protocol::err(&format!("protocol error: {}", e.0)), "") | ||||
|                     } | ||||
|                 }; | ||||
|                 s = remaining; | ||||
|  | ||||
|                 if self.option.debug { | ||||
|                     println!("\x1b[34;1mgot command: {:?}, protocol: {:?}\x1b[0m", cmd, protocol); | ||||
|                 } else { | ||||
| @@ -68,27 +77,22 @@ impl Server { | ||||
|                     .run(&mut self.clone(), protocol.clone(), &mut queued_cmd) | ||||
|                     .await | ||||
|                     .unwrap_or(Protocol::err("unknown cmd from server")); | ||||
|                  | ||||
|                 if self.option.debug { | ||||
|                     println!("\x1b[34;1mqueued cmd {:?}\x1b[0m", queued_cmd); | ||||
|                 } else { | ||||
|                     print!("queued cmd {:?}", queued_cmd); | ||||
|                 } | ||||
|  | ||||
|                 if self.option.debug { | ||||
|                     println!("\x1b[32;1mgoing to send response {}\x1b[0m", res.encode()); | ||||
|                 } else { | ||||
|                     print!("queued cmd {:?}", queued_cmd); | ||||
|                     println!("going to send response {}", res.encode()); | ||||
|                 } | ||||
|  | ||||
|                 _ = stream.write(res.encode().as_bytes()).await?; | ||||
|                  | ||||
|  | ||||
|                 // If this was a QUIT command, close the connection | ||||
|                 if is_quit { | ||||
|                     println!("[handle] QUIT command received, closing connection"); | ||||
|                     return Ok(()); | ||||
|                 } | ||||
|             } else { | ||||
|                 println!("[handle] going to break"); | ||||
|                 break; | ||||
|             } | ||||
|         } | ||||
|         Ok(()) | ||||
|   | ||||
							
								
								
									
										424
									
								
								src/storage.rs
									
									
									
									
									
								
							
							
						
						
									
										424
									
								
								src/storage.rs
									
									
									
									
									
								
							| @@ -12,6 +12,7 @@ use crate::error::DBError; | ||||
| const TYPES_TABLE: TableDefinition<&str, &str> = TableDefinition::new("types"); | ||||
| const STRINGS_TABLE: TableDefinition<&str, &[u8]> = TableDefinition::new("strings"); | ||||
| const HASHES_TABLE: TableDefinition<(&str, &str), &str> = TableDefinition::new("hashes"); | ||||
| const LISTS_TABLE: TableDefinition<&str, &[u8]> = TableDefinition::new("lists"); | ||||
| const STREAMS_META_TABLE: TableDefinition<&str, &[u8]> = TableDefinition::new("streams_meta"); | ||||
| const STREAMS_DATA_TABLE: TableDefinition<(&str, &str), &[u8]> = TableDefinition::new("streams_data"); | ||||
|  | ||||
| @@ -26,6 +27,12 @@ pub struct StreamEntry { | ||||
|     pub fields: Vec<(String, String)>, | ||||
| } | ||||
|  | ||||
| #[derive(Serialize, Deserialize, Debug, Clone)] | ||||
| pub struct ListValue { | ||||
|     pub elements: Vec<String>, | ||||
| } | ||||
|  | ||||
|  | ||||
| #[inline] | ||||
| pub fn now_in_millis() -> u128 { | ||||
|     let start = SystemTime::now(); | ||||
| @@ -47,6 +54,7 @@ impl Storage { | ||||
|             let _ = write_txn.open_table(TYPES_TABLE)?; | ||||
|             let _ = write_txn.open_table(STRINGS_TABLE)?; | ||||
|             let _ = write_txn.open_table(HASHES_TABLE)?; | ||||
|             let _ = write_txn.open_table(LISTS_TABLE)?; | ||||
|             let _ = write_txn.open_table(STREAMS_META_TABLE)?; | ||||
|             let _ = write_txn.open_table(STREAMS_DATA_TABLE)?; | ||||
|         } | ||||
| @@ -143,6 +151,7 @@ impl Storage { | ||||
|             let mut types_table = write_txn.open_table(TYPES_TABLE)?; | ||||
|             let mut strings_table = write_txn.open_table(STRINGS_TABLE)?; | ||||
|             let mut hashes_table = write_txn.open_table(HASHES_TABLE)?; | ||||
|             let mut lists_table = write_txn.open_table(LISTS_TABLE)?; | ||||
|              | ||||
|             // Remove from type table | ||||
|             types_table.remove(key.as_str())?; | ||||
| @@ -165,6 +174,9 @@ impl Storage { | ||||
|             for (hash_key, field) in to_remove { | ||||
|                 hashes_table.remove((hash_key.as_str(), field.as_str()))?; | ||||
|             } | ||||
|  | ||||
|             // Remove from lists table | ||||
|             lists_table.remove(key.as_str())?; | ||||
|         } | ||||
|          | ||||
|         write_txn.commit()?; | ||||
| @@ -633,4 +645,414 @@ impl Storage { | ||||
|             None => Ok(false), | ||||
|         } | ||||
|     } | ||||
| } | ||||
|      | ||||
|     // List operations | ||||
|     pub fn lpush(&self, key: &str, elements: Vec<String>) -> Result<u64, DBError> { | ||||
|         let write_txn = self.db.begin_write()?; | ||||
|         let mut new_len = 0u64; | ||||
|  | ||||
|         { | ||||
|             let mut types_table = write_txn.open_table(TYPES_TABLE)?; | ||||
|             let mut lists_table = write_txn.open_table(LISTS_TABLE)?; | ||||
|  | ||||
|             let existing_type = match types_table.get(key)? { | ||||
|                 Some(type_val) => Some(type_val.value().to_string()), | ||||
|                 None => None, | ||||
|             }; | ||||
|  | ||||
|             match existing_type { | ||||
|                 Some(ref type_str) if type_str != "list" => { | ||||
|                     return Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())); | ||||
|                 } | ||||
|                 None => { | ||||
|                     types_table.insert(key, "list")?; | ||||
|                 } | ||||
|                 _ => {} | ||||
|             } | ||||
|  | ||||
|             let mut list_value: ListValue = match lists_table.get(key)? { | ||||
|                 Some(data) => bincode::deserialize(data.value())?, | ||||
|                 None => ListValue { elements: Vec::new() }, | ||||
|             }; | ||||
|  | ||||
|             for element in elements.into_iter().rev() { | ||||
|                 list_value.elements.insert(0, element); | ||||
|             } | ||||
|             new_len = list_value.elements.len() as u64; | ||||
|  | ||||
|             let serialized = bincode::serialize(&list_value)?; | ||||
|             lists_table.insert(key, serialized.as_slice())?; | ||||
|         } | ||||
|  | ||||
|         write_txn.commit()?; | ||||
|         Ok(new_len) | ||||
|     } | ||||
|  | ||||
|     pub fn rpush(&self, key: &str, elements: Vec<String>) -> Result<u64, DBError> { | ||||
|         let write_txn = self.db.begin_write()?; | ||||
|         let mut new_len = 0u64; | ||||
|  | ||||
|         { | ||||
|             let mut types_table = write_txn.open_table(TYPES_TABLE)?; | ||||
|             let mut lists_table = write_txn.open_table(LISTS_TABLE)?; | ||||
|  | ||||
|             let existing_type = match types_table.get(key)? { | ||||
|                 Some(type_val) => Some(type_val.value().to_string()), | ||||
|                 None => None, | ||||
|             }; | ||||
|  | ||||
|             match existing_type { | ||||
|                 Some(ref type_str) if type_str != "list" => { | ||||
|                     return Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())); | ||||
|                 } | ||||
|                 None => { | ||||
|                     types_table.insert(key, "list")?; | ||||
|                 } | ||||
|                 _ => {} | ||||
|             } | ||||
|  | ||||
|             let mut list_value: ListValue = match lists_table.get(key)? { | ||||
|                 Some(data) => bincode::deserialize(data.value())?, | ||||
|                 None => ListValue { elements: Vec::new() }, | ||||
|             }; | ||||
|  | ||||
|             for element in elements { | ||||
|                 list_value.elements.push(element); | ||||
|             } | ||||
|             new_len = list_value.elements.len() as u64; | ||||
|  | ||||
|             let serialized = bincode::serialize(&list_value)?; | ||||
|             lists_table.insert(key, serialized.as_slice())?; | ||||
|         } | ||||
|  | ||||
|         write_txn.commit()?; | ||||
|         Ok(new_len) | ||||
|     } | ||||
|  | ||||
|     pub fn lpop(&self, key: &str, count: Option<u64>) -> Result<Option<Vec<String>>, DBError> { | ||||
|         let write_txn = self.db.begin_write()?; | ||||
|         let mut result_elements = Vec::new(); | ||||
|  | ||||
|         { | ||||
|             let mut types_table = write_txn.open_table(TYPES_TABLE)?; | ||||
|             let mut lists_table = write_txn.open_table(LISTS_TABLE)?; | ||||
|  | ||||
|             let existing_type = match types_table.get(key)? { | ||||
|                 Some(type_val) => Some(type_val.value().to_string()), | ||||
|                 None => None, | ||||
|             }; | ||||
|  | ||||
|             match existing_type { | ||||
|                 Some(ref type_str) if type_str != "list" => { | ||||
|                     return Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())); | ||||
|                 } | ||||
|                 Some(_) => { | ||||
|                     let mut list_value: ListValue = match lists_table.get(key)? { | ||||
|                         Some(data) => bincode::deserialize(data.value())?, | ||||
|                         None => return Ok(None), // Key exists but list is empty (shouldn't happen if type is "list") | ||||
|                     }; | ||||
|  | ||||
|                     let num_to_pop = count.unwrap_or(1) as usize; | ||||
|                     for _ in 0..num_to_pop { | ||||
|                         if !list_value.elements.is_empty() { | ||||
|                             result_elements.push(list_value.elements.remove(0)); | ||||
|                         } else { | ||||
|                             break; | ||||
|                         } | ||||
|                     } | ||||
|  | ||||
|                     if list_value.elements.is_empty() { | ||||
|                         lists_table.remove(key)?; | ||||
|                         types_table.remove(key)?; | ||||
|                     } else { | ||||
|                         let serialized = bincode::serialize(&list_value)?; | ||||
|                         lists_table.insert(key, serialized.as_slice())?; | ||||
|                     } | ||||
|                 } | ||||
|                 None => return Ok(None), | ||||
|             } | ||||
|         } | ||||
|  | ||||
|         write_txn.commit()?; | ||||
|         if result_elements.is_empty() { | ||||
|             Ok(None) | ||||
|         } else { | ||||
|             Ok(Some(result_elements)) | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     pub fn rpop(&self, key: &str, count: Option<u64>) -> Result<Option<Vec<String>>, DBError> { | ||||
|         let write_txn = self.db.begin_write()?; | ||||
|         let mut result_elements = Vec::new(); | ||||
|  | ||||
|         { | ||||
|             let mut types_table = write_txn.open_table(TYPES_TABLE)?; | ||||
|             let mut lists_table = write_txn.open_table(LISTS_TABLE)?; | ||||
|  | ||||
|             let existing_type = match types_table.get(key)? { | ||||
|                 Some(type_val) => Some(type_val.value().to_string()), | ||||
|                 None => None, | ||||
|             }; | ||||
|  | ||||
|             match existing_type { | ||||
|                 Some(ref type_str) if type_str != "list" => { | ||||
|                     return Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())); | ||||
|                 } | ||||
|                 Some(_) => { | ||||
|                     let mut list_value: ListValue = match lists_table.get(key)? { | ||||
|                         Some(data) => bincode::deserialize(data.value())?, | ||||
|                         None => return Ok(None), | ||||
|                     }; | ||||
|  | ||||
|                     let num_to_pop = count.unwrap_or(1) as usize; | ||||
|                     for _ in 0..num_to_pop { | ||||
|                         if let Some(element) = list_value.elements.pop() { | ||||
|                             result_elements.push(element); | ||||
|                         } else { | ||||
|                             break; | ||||
|                         } | ||||
|                     } | ||||
|  | ||||
|                     if list_value.elements.is_empty() { | ||||
|                         lists_table.remove(key)?; | ||||
|                         types_table.remove(key)?; | ||||
|                     } else { | ||||
|                         let serialized = bincode::serialize(&list_value)?; | ||||
|                         lists_table.insert(key, serialized.as_slice())?; | ||||
|                     } | ||||
|                 } | ||||
|                 None => return Ok(None), | ||||
|             } | ||||
|         } | ||||
|  | ||||
|         write_txn.commit()?; | ||||
|         if result_elements.is_empty() { | ||||
|             Ok(None) | ||||
|         } else { | ||||
|             Ok(Some(result_elements)) | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     pub fn llen(&self, key: &str) -> Result<u64, DBError> { | ||||
|         let read_txn = self.db.begin_read()?; | ||||
|          | ||||
|         let types_table = read_txn.open_table(TYPES_TABLE)?; | ||||
|         match types_table.get(key)? { | ||||
|             Some(type_val) if type_val.value() == "list" => { | ||||
|                 let lists_table = read_txn.open_table(LISTS_TABLE)?; | ||||
|                 match lists_table.get(key)? { | ||||
|                     Some(data) => { | ||||
|                         let list_value: ListValue = bincode::deserialize(data.value())?; | ||||
|                         Ok(list_value.elements.len() as u64) | ||||
|                     } | ||||
|                     None => Ok(0), // Key exists but list is empty | ||||
|                 } | ||||
|             } | ||||
|             Some(_) => Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())), | ||||
|             None => Ok(0), // Key does not exist | ||||
|         } | ||||
|     } | ||||
|      | ||||
|     pub fn lrem(&self, key: &str, count: i64, element: &str) -> Result<u64, DBError> { | ||||
|         let write_txn = self.db.begin_write()?; | ||||
|         let mut removed_count = 0u64; | ||||
|  | ||||
|         { | ||||
|             let mut types_table = write_txn.open_table(TYPES_TABLE)?; | ||||
|             let mut lists_table = write_txn.open_table(LISTS_TABLE)?; | ||||
|  | ||||
|             let existing_type = match types_table.get(key)? { | ||||
|                 Some(type_val) => Some(type_val.value().to_string()), | ||||
|                 None => None, | ||||
|             }; | ||||
|  | ||||
|             match existing_type { | ||||
|                 Some(ref type_str) if type_str != "list" => { | ||||
|                     return Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())); | ||||
|                 } | ||||
|                 Some(_) => { | ||||
|                     let mut list_value: ListValue = match lists_table.get(key)? { | ||||
|                         Some(data) => bincode::deserialize(data.value())?, | ||||
|                         None => return Ok(0), | ||||
|                     }; | ||||
|  | ||||
|                     let initial_len = list_value.elements.len(); | ||||
|                      | ||||
|                     if count > 0 { | ||||
|                         let mut i = 0; | ||||
|                         let mut removed = 0; | ||||
|                         while i < list_value.elements.len() && removed < count { | ||||
|                             if list_value.elements[i] == element { | ||||
|                                 list_value.elements.remove(i); | ||||
|                                 removed += 1; | ||||
|                             } else { | ||||
|                                 i += 1; | ||||
|                             } | ||||
|                         } | ||||
|                     } else if count < 0 { | ||||
|                         let mut i = list_value.elements.len() as i32 - 1; | ||||
|                         let mut removed = 0; | ||||
|                         while i >= 0 && removed < -count { | ||||
|                             if list_value.elements[i as usize] == element { | ||||
|                                 list_value.elements.remove(i as usize); | ||||
|                                 removed += 1; | ||||
|                             } | ||||
|                             i -= 1; | ||||
|                         } | ||||
|                     } else { // count == 0 | ||||
|                         list_value.elements.retain(|el| el != element); | ||||
|                     } | ||||
|  | ||||
|                     removed_count = (initial_len - list_value.elements.len()) as u64; | ||||
|  | ||||
|                     if list_value.elements.is_empty() { | ||||
|                         lists_table.remove(key)?; | ||||
|                         types_table.remove(key)?; | ||||
|                     } else { | ||||
|                         let serialized = bincode::serialize(&list_value)?; | ||||
|                         lists_table.insert(key, serialized.as_slice())?; | ||||
|                     } | ||||
|                 } | ||||
|                 None => return Ok(0), | ||||
|             } | ||||
|         } | ||||
|  | ||||
|         write_txn.commit()?; | ||||
|         Ok(removed_count) | ||||
|     } | ||||
|  | ||||
|     pub fn ltrim(&self, key: &str, start: i64, stop: i64) -> Result<(), DBError> { | ||||
|         let write_txn = self.db.begin_write()?; | ||||
|          | ||||
|         { | ||||
|             let mut types_table = write_txn.open_table(TYPES_TABLE)?; | ||||
|             let mut lists_table = write_txn.open_table(LISTS_TABLE)?; | ||||
|  | ||||
|             let existing_type = match types_table.get(key)? { | ||||
|                 Some(type_val) => Some(type_val.value().to_string()), | ||||
|                 None => None, | ||||
|             }; | ||||
|              | ||||
|             match existing_type { | ||||
|                 Some(ref type_str) if type_str != "list" => { | ||||
|                     return Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())); | ||||
|                 } | ||||
|                 Some(_) => { | ||||
|                     let mut list_value: ListValue = match lists_table.get(key)? { | ||||
|                         Some(data) => bincode::deserialize(data.value())?, | ||||
|                         None => return Ok(()), | ||||
|                     }; | ||||
|  | ||||
|                     let len = list_value.elements.len() as i64; | ||||
|                     let mut start = start; | ||||
|                     let mut stop = stop; | ||||
|  | ||||
|                     if start < 0 { | ||||
|                         start += len; | ||||
|                     } | ||||
|                     if stop < 0 { | ||||
|                         stop += len; | ||||
|                     } | ||||
|  | ||||
|                     if start < 0 { | ||||
|                         start = 0; | ||||
|                     } | ||||
|  | ||||
|                     if start > stop || start >= len { | ||||
|                         list_value.elements.clear(); | ||||
|                     } else { | ||||
|                         if stop >= len { | ||||
|                             stop = len - 1; | ||||
|                         } | ||||
|                         let start = start as usize; | ||||
|                         let stop = stop as usize; | ||||
|                         list_value.elements = list_value.elements.drain(start..=stop).collect(); | ||||
|                     } | ||||
|  | ||||
|                     if list_value.elements.is_empty() { | ||||
|                         lists_table.remove(key)?; | ||||
|                         types_table.remove(key)?; | ||||
|                     } else { | ||||
|                         let serialized = bincode::serialize(&list_value)?; | ||||
|                         lists_table.insert(key, serialized.as_slice())?; | ||||
|                     } | ||||
|                 } | ||||
|                 None => {} | ||||
|             } | ||||
|         } | ||||
|          | ||||
|         write_txn.commit()?; | ||||
|         Ok(()) | ||||
|     } | ||||
|  | ||||
|     pub fn lindex(&self, key: &str, index: i64) -> Result<Option<String>, DBError> { | ||||
|         let read_txn = self.db.begin_read()?; | ||||
|          | ||||
|         let types_table = read_txn.open_table(TYPES_TABLE)?; | ||||
|         match types_table.get(key)? { | ||||
|             Some(type_val) if type_val.value() == "list" => { | ||||
|                 let lists_table = read_txn.open_table(LISTS_TABLE)?; | ||||
|                 match lists_table.get(key)? { | ||||
|                     Some(data) => { | ||||
|                         let list_value: ListValue = bincode::deserialize(data.value())?; | ||||
|                         let len = list_value.elements.len() as i64; | ||||
|                         let mut index = index; | ||||
|                         if index < 0 { | ||||
|                             index += len; | ||||
|                         } | ||||
|                         if index < 0 || index >= len { | ||||
|                             Ok(None) | ||||
|                         } else { | ||||
|                             Ok(list_value.elements.get(index as usize).cloned()) | ||||
|                         } | ||||
|                     } | ||||
|                     None => Ok(None), | ||||
|                 } | ||||
|             } | ||||
|             Some(_) => Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())), | ||||
|             None => Ok(None), | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     pub fn lrange(&self, key: &str, start: i64, stop: i64) -> Result<Vec<String>, DBError> { | ||||
|         let read_txn = self.db.begin_read()?; | ||||
|          | ||||
|         let types_table = read_txn.open_table(TYPES_TABLE)?; | ||||
|         match types_table.get(key)? { | ||||
|             Some(type_val) if type_val.value() == "list" => { | ||||
|                 let lists_table = read_txn.open_table(LISTS_TABLE)?; | ||||
|                 match lists_table.get(key)? { | ||||
|                     Some(data) => { | ||||
|                         let list_value: ListValue = bincode::deserialize(data.value())?; | ||||
|                         let len = list_value.elements.len() as i64; | ||||
|                         let mut start = start; | ||||
|                         let mut stop = stop; | ||||
|  | ||||
|                         if start < 0 { | ||||
|                             start += len; | ||||
|                         } | ||||
|                         if stop < 0 { | ||||
|                             stop += len; | ||||
|                         } | ||||
|  | ||||
|                         if start < 0 { | ||||
|                             start = 0; | ||||
|                         } | ||||
|                          | ||||
|                         if start > stop || start >= len { | ||||
|                             Ok(Vec::new()) | ||||
|                         } else { | ||||
|                             if stop >= len { | ||||
|                                 stop = len - 1; | ||||
|                             } | ||||
|                             Ok(list_value.elements[start as usize..=stop as usize].to_vec()) | ||||
|                         } | ||||
|                     } | ||||
|                     None => Ok(Vec::new()), | ||||
|                 } | ||||
|             } | ||||
|             Some(_) => Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())), | ||||
|             None => Ok(Vec::new()), | ||||
|         } | ||||
|     } | ||||
| } | ||||
| @@ -12,7 +12,8 @@ async fn start_test_server(test_name: &str) -> (Server, u16) { | ||||
|     let port = PORT_COUNTER.fetch_add(1, Ordering::SeqCst); | ||||
|     let test_dir = format!("/tmp/herodb_test_{}", test_name); | ||||
|      | ||||
|     // Create test directory | ||||
|     // Clean up and create test directory | ||||
|     let _ = std::fs::remove_dir_all(&test_dir); | ||||
|     std::fs::create_dir_all(&test_dir).unwrap(); | ||||
|      | ||||
|     let option = DBOption { | ||||
| @@ -196,7 +197,7 @@ async fn test_hash_operations() { | ||||
|     let response = send_command(&mut stream, "*3\r\n$7\r\nHEXISTS\r\n$4\r\nhash\r\n$6\r\nfield1\r\n").await; | ||||
|     assert!(response.contains("1")); | ||||
|      | ||||
|     let response = send_command(&mut stream, "*3\r\n$7\r\nHEXISTS\r\n$4\r\nhash\r\n$8\r\nnoexist\r\n").await; | ||||
|     let response = send_command(&mut stream, "*3\r\n$7\r\nHEXISTS\r\n$4\r\nhash\r\n$7\r\nnoexist\r\n").await; | ||||
|     assert!(response.contains("0")); | ||||
|      | ||||
|     // Test HDEL | ||||
| @@ -442,7 +443,7 @@ async fn test_type_command() { | ||||
|     assert!(response.contains("hash")); | ||||
|      | ||||
|     // Test non-existent key | ||||
|     let response = send_command(&mut stream, "*2\r\n$4\r\nTYPE\r\n$8\r\nnoexist\r\n").await; | ||||
|     let response = send_command(&mut stream, "*2\r\n$4\r\nTYPE\r\n$7\r\nnoexist\r\n").await; | ||||
|     assert!(response.contains("none")); | ||||
| } | ||||
|  | ||||
|   | ||||
		Reference in New Issue
	
	Block a user