From bec9b20ec7a6cc30b4a8d279adcc857e944e612c Mon Sep 17 00:00:00 2001 From: despiegk Date: Sat, 16 Aug 2025 08:41:19 +0200 Subject: [PATCH] ... --- src/cmd.rs | 188 ++++++++++++++++++- src/protocol.rs | 78 +++----- src/server.rs | 36 ++-- src/storage.rs | 424 ++++++++++++++++++++++++++++++++++++++++++- tests/redis_tests.rs | 7 +- 5 files changed, 660 insertions(+), 73 deletions(-) diff --git a/src/cmd.rs b/src/cmd.rs index 9b8f6c2..3fbd367 100644 --- a/src/cmd.rs +++ b/src/cmd.rs @@ -39,13 +39,20 @@ pub enum Cmd { // List commands LPush(String, Vec), RPush(String, Vec), + LPop(String, Option), + RPop(String, Option), + LLen(String), + LRem(String, i64, String), + LTrim(String, i64, i64), + LIndex(String, i64), + LRange(String, i64, i64), Unknow(String), } impl Cmd { - pub fn from(s: &str) -> Result<(Self, Protocol), DBError> { - let protocol = Protocol::from(s)?; - match protocol.clone().0 { + pub fn from(s: &str) -> Result<(Self, Protocol, &str), DBError> { + let (protocol, remaining) = Protocol::from(s)?; + match protocol.clone() { Protocol::Array(p) => { let cmd = p.into_iter().map(|x| x.decode()).collect::>(); if cmd.is_empty() { @@ -303,14 +310,85 @@ impl Cmd { Cmd::Client(vec![]) } } + "lpush" => { + if cmd.len() < 3 { + return Err(DBError(format!("wrong number of arguments for LPUSH command"))); + } + Cmd::LPush(cmd[1].clone(), cmd[2..].to_vec()) + } + "rpush" => { + if cmd.len() < 3 { + return Err(DBError(format!("wrong number of arguments for RPUSH command"))); + } + Cmd::RPush(cmd[1].clone(), cmd[2..].to_vec()) + } + "lpop" => { + if cmd.len() < 2 || cmd.len() > 3 { + return Err(DBError(format!("wrong number of arguments for LPOP command"))); + } + let count = if cmd.len() == 3 { + Some(cmd[2].parse::().map_err(|_| DBError("ERR value is not an integer or out of range".to_string()))?) + } else { + None + }; + Cmd::LPop(cmd[1].clone(), count) + } + "rpop" => { + if cmd.len() < 2 || cmd.len() > 3 { + return Err(DBError(format!("wrong number of arguments for RPOP command"))); + } + let count = if cmd.len() == 3 { + Some(cmd[2].parse::().map_err(|_| DBError("ERR value is not an integer or out of range".to_string()))?) + } else { + None + }; + Cmd::RPop(cmd[1].clone(), count) + } + "llen" => { + if cmd.len() != 2 { + return Err(DBError(format!("wrong number of arguments for LLEN command"))); + } + Cmd::LLen(cmd[1].clone()) + } + "lrem" => { + if cmd.len() != 4 { + return Err(DBError(format!("wrong number of arguments for LREM command"))); + } + let count = cmd[2].parse::().map_err(|_| DBError("ERR value is not an integer or out of range".to_string()))?; + Cmd::LRem(cmd[1].clone(), count, cmd[3].clone()) + } + "ltrim" => { + if cmd.len() != 4 { + return Err(DBError(format!("wrong number of arguments for LTRIM command"))); + } + let start = cmd[2].parse::().map_err(|_| DBError("ERR value is not an integer or out of range".to_string()))?; + let stop = cmd[3].parse::().map_err(|_| DBError("ERR value is not an integer or out of range".to_string()))?; + Cmd::LTrim(cmd[1].clone(), start, stop) + } + "lindex" => { + if cmd.len() != 3 { + return Err(DBError(format!("wrong number of arguments for LINDEX command"))); + } + let index = cmd[2].parse::().map_err(|_| DBError("ERR value is not an integer or out of range".to_string()))?; + Cmd::LIndex(cmd[1].clone(), index) + } + "lrange" => { + if cmd.len() != 4 { + return Err(DBError(format!("wrong number of arguments for LRANGE command"))); + } + let start = cmd[2].parse::().map_err(|_| DBError("ERR value is not an integer or out of range".to_string()))?; + let stop = cmd[3].parse::().map_err(|_| DBError("ERR value is not an integer or out of range".to_string()))?; + Cmd::LRange(cmd[1].clone(), start, stop) + } _ => Cmd::Unknow(cmd[0].clone()), }, - protocol.0, + protocol, + remaining )) } _ => Err(DBError(format!( "fail to parse as cmd for {:?}", - protocol.0 + protocol ))), } } @@ -379,6 +457,16 @@ impl Cmd { Cmd::Client(_) => Ok(Protocol::SimpleString("OK".to_string())), Cmd::ClientSetName(name) => client_setname_cmd(server, name).await, Cmd::ClientGetName => client_getname_cmd(server).await, + // List commands + Cmd::LPush(key, elements) => lpush_cmd(server, key, elements).await, + Cmd::RPush(key, elements) => rpush_cmd(server, key, elements).await, + Cmd::LPop(key, count) => lpop_cmd(server, key, count).await, + Cmd::RPop(key, count) => rpop_cmd(server, key, count).await, + Cmd::LLen(key) => llen_cmd(server, key).await, + Cmd::LRem(key, count, element) => lrem_cmd(server, key, *count, element).await, + Cmd::LTrim(key, start, stop) => ltrim_cmd(server, key, *start, *stop).await, + Cmd::LIndex(key, index) => lindex_cmd(server, key, *index).await, + Cmd::LRange(key, start, stop) => lrange_cmd(server, key, *start, *stop).await, Cmd::Unknow(s) => { println!("\x1b[31;1munknown command: {}\x1b[0m", s); Ok(Protocol::err(&format!("ERR unknown command '{}'", s))) @@ -387,6 +475,96 @@ impl Cmd { } } +async fn lindex_cmd(server: &Server, key: &str, index: i64) -> Result { + match server.storage.lindex(key, index) { + Ok(Some(element)) => Ok(Protocol::BulkString(element)), + Ok(None) => Ok(Protocol::Null), + Err(e) => Ok(Protocol::err(&e.0)), + } +} + +async fn lrange_cmd(server: &Server, key: &str, start: i64, stop: i64) -> Result { + match server.storage.lrange(key, start, stop) { + Ok(elements) => Ok(Protocol::Array(elements.into_iter().map(Protocol::BulkString).collect())), + Err(e) => Ok(Protocol::err(&e.0)), + } +} + +async fn ltrim_cmd(server: &Server, key: &str, start: i64, stop: i64) -> Result { + match server.storage.ltrim(key, start, stop) { + Ok(_) => Ok(Protocol::SimpleString("OK".to_string())), + Err(e) => Ok(Protocol::err(&e.0)), + } +} + +async fn lrem_cmd(server: &Server, key: &str, count: i64, element: &str) -> Result { + match server.storage.lrem(key, count, element) { + Ok(removed_count) => Ok(Protocol::SimpleString(removed_count.to_string())), + Err(e) => Ok(Protocol::err(&e.0)), + } +} + +async fn llen_cmd(server: &Server, key: &str) -> Result { + match server.storage.llen(key) { + Ok(len) => Ok(Protocol::SimpleString(len.to_string())), + Err(e) => Ok(Protocol::err(&e.0)), + } +} + +async fn lpop_cmd(server: &Server, key: &str, count: &Option) -> Result { + match server.storage.lpop(key, *count) { + Ok(Some(elements)) => { + if count.is_some() { + Ok(Protocol::Array(elements.into_iter().map(Protocol::BulkString).collect())) + } else { + Ok(Protocol::BulkString(elements[0].clone())) + } + }, + Ok(None) => { + if count.is_some() { + Ok(Protocol::Array(vec![])) + } else { + Ok(Protocol::Null) + } + }, + Err(e) => Ok(Protocol::err(&e.0)), + } +} + +async fn rpop_cmd(server: &Server, key: &str, count: &Option) -> Result { + match server.storage.rpop(key, *count) { + Ok(Some(elements)) => { + if count.is_some() { + Ok(Protocol::Array(elements.into_iter().map(Protocol::BulkString).collect())) + } else { + Ok(Protocol::BulkString(elements[0].clone())) + } + }, + Ok(None) => { + if count.is_some() { + Ok(Protocol::Array(vec![])) + } else { + Ok(Protocol::Null) + } + }, + Err(e) => Ok(Protocol::err(&e.0)), + } +} + +async fn lpush_cmd(server: &Server, key: &str, elements: &[String]) -> Result { + match server.storage.lpush(key, elements.to_vec()) { + Ok(len) => Ok(Protocol::SimpleString(len.to_string())), + Err(e) => Ok(Protocol::err(&e.0)), + } +} + +async fn rpush_cmd(server: &Server, key: &str, elements: &[String]) -> Result { + match server.storage.rpush(key, elements.to_vec()) { + Ok(len) => Ok(Protocol::SimpleString(len.to_string())), + Err(e) => Ok(Protocol::err(&e.0)), + } +} + async fn exec_cmd( queued_cmd: &mut Option>, server: &mut Server, diff --git a/src/protocol.rs b/src/protocol.rs index c76402c..ad42309 100644 --- a/src/protocol.rs +++ b/src/protocol.rs @@ -17,7 +17,7 @@ impl fmt::Display for Protocol { } impl Protocol { - pub fn from(protocol: &str) -> Result<(Self, usize), DBError> { + pub fn from(protocol: &str) -> Result<(Self, &str), DBError> { let ret = match protocol.chars().nth(0) { Some('+') => Self::parse_simple_string_sfx(&protocol[1..]), Some('$') => Self::parse_bulk_string_sfx(&protocol[1..]), @@ -27,10 +27,7 @@ impl Protocol { protocol ))), }; - match ret { - Ok((p, s)) => Ok((p, s + 1)), - Err(e) => Err(e), - } + ret } pub fn from_vec(array: Vec<&str>) -> Self { @@ -91,9 +88,9 @@ impl Protocol { } } - fn parse_simple_string_sfx(protocol: &str) -> Result<(Self, usize), DBError> { + fn parse_simple_string_sfx(protocol: &str) -> Result<(Self, &str), DBError> { match protocol.find("\r\n") { - Some(x) => Ok((Self::SimpleString(protocol[..x].to_string()), x + 2)), + Some(x) => Ok((Self::SimpleString(protocol[..x].to_string()), &protocol[x + 2..])), _ => Err(DBError(format!( "[new simple string] unsupported protocol: {:?}", protocol @@ -101,27 +98,20 @@ impl Protocol { } } - fn parse_bulk_string_sfx(protocol: &str) -> Result<(Self, usize), DBError> { - if let Some(len) = protocol.find("\r\n") { - let size = Self::parse_usize(&protocol[..len])?; - if let Some(data_len) = protocol[len + 2..].find("\r\n") { - let s = Self::parse_string(&protocol[len + 2..len + 2 + data_len])?; - if size != s.len() { - Err(DBError(format!( - "[new bulk string] unmatched string length in prototocl {:?}", - protocol, - ))) - } else { - Ok(( - Protocol::BulkString(s.to_lowercase()), - len + 2 + data_len + 2, - )) - } - } else { - Err(DBError(format!( - "[new bulk string] unsupported protocol: {:?}", - protocol + fn parse_bulk_string_sfx(protocol: &str) -> Result<(Self, &str), DBError> { + if let Some(len_end) = protocol.find("\r\n") { + let size = Self::parse_usize(&protocol[..len_end])?; + let data_start = len_end + 2; + let data_end = data_start + size; + let s = Self::parse_string(&protocol[data_start..data_end])?; + + if protocol.len() < data_end + 2 || &protocol[data_end..data_end+2] != "\r\n" { + Err(DBError(format!( + "[new bulk string] unmatched string length in prototocl {:?}", + protocol, ))) + } else { + Ok((Protocol::BulkString(s), &protocol[data_end + 2..])) } } else { Err(DBError(format!( @@ -131,30 +121,22 @@ impl Protocol { } } - fn parse_array_sfx(s: &str) -> Result<(Self, usize), DBError> { - let mut offset = 0; - match s.find("\r\n") { - Some(x) => { - let array_len = s[..x].parse::()?; - offset += x + 2; - let mut vec = vec![]; - for _ in 0..array_len { - match Protocol::from(&s[offset..]) { - Ok((p, len)) => { - offset += len; - vec.push(p); - } - Err(e) => { - return Err(e); - } - } - } - Ok((Protocol::Array(vec), offset)) + fn parse_array_sfx(s: &str) -> Result<(Self, &str), DBError> { + if let Some(len_end) = s.find("\r\n") { + let array_len = s[..len_end].parse::()?; + let mut remaining = &s[len_end + 2..]; + let mut vec = vec![]; + for _ in 0..array_len { + let (p, rem) = Protocol::from(remaining)?; + vec.push(p); + remaining = rem; } - _ => Err(DBError(format!( + Ok((Protocol::Array(vec), remaining)) + } else { + Err(DBError(format!( "[new array] unsupported protocol: {:?}", s - ))), + ))) } } diff --git a/src/server.rs b/src/server.rs index 61b753c..c0c8e1e 100644 --- a/src/server.rs +++ b/src/server.rs @@ -41,20 +41,29 @@ impl Server { let mut queued_cmd: Option> = None; loop { - if let Ok(len) = stream.read(&mut buf).await { - if len == 0 { + let len = match stream.read(&mut buf).await { + Ok(0) => { println!("[handle] connection closed"); return Ok(()); } - - let s = str::from_utf8(&buf[..len])?; - let (cmd, protocol) = match Cmd::from(s) { - Ok((cmd, protocol)) => (cmd, protocol), + Ok(len) => len, + Err(e) => { + println!("[handle] read error: {:?}", e); + return Err(e.into()); + } + }; + + let mut s = str::from_utf8(&buf[..len])?; + while !s.is_empty() { + let (cmd, protocol, remaining) = match Cmd::from(s) { + Ok((cmd, protocol, remaining)) => (cmd, protocol, remaining), Err(e) => { println!("\x1b[31;1mprotocol error: {:?}\x1b[0m", e); - (Cmd::Unknow("protocol_error".to_string()), Protocol::err(&format!("protocol error: {}", e.0))) + (Cmd::Unknow("protocol_error".to_string()), Protocol::err(&format!("protocol error: {}", e.0)), "") } }; + s = remaining; + if self.option.debug { println!("\x1b[34;1mgot command: {:?}, protocol: {:?}\x1b[0m", cmd, protocol); } else { @@ -68,27 +77,22 @@ impl Server { .run(&mut self.clone(), protocol.clone(), &mut queued_cmd) .await .unwrap_or(Protocol::err("unknown cmd from server")); + if self.option.debug { println!("\x1b[34;1mqueued cmd {:?}\x1b[0m", queued_cmd); - } else { - print!("queued cmd {:?}", queued_cmd); - } - - if self.option.debug { println!("\x1b[32;1mgoing to send response {}\x1b[0m", res.encode()); } else { + print!("queued cmd {:?}", queued_cmd); println!("going to send response {}", res.encode()); } + _ = stream.write(res.encode().as_bytes()).await?; - + // If this was a QUIT command, close the connection if is_quit { println!("[handle] QUIT command received, closing connection"); return Ok(()); } - } else { - println!("[handle] going to break"); - break; } } Ok(()) diff --git a/src/storage.rs b/src/storage.rs index 6550c9a..7436cbc 100644 --- a/src/storage.rs +++ b/src/storage.rs @@ -12,6 +12,7 @@ use crate::error::DBError; const TYPES_TABLE: TableDefinition<&str, &str> = TableDefinition::new("types"); const STRINGS_TABLE: TableDefinition<&str, &[u8]> = TableDefinition::new("strings"); const HASHES_TABLE: TableDefinition<(&str, &str), &str> = TableDefinition::new("hashes"); +const LISTS_TABLE: TableDefinition<&str, &[u8]> = TableDefinition::new("lists"); const STREAMS_META_TABLE: TableDefinition<&str, &[u8]> = TableDefinition::new("streams_meta"); const STREAMS_DATA_TABLE: TableDefinition<(&str, &str), &[u8]> = TableDefinition::new("streams_data"); @@ -26,6 +27,12 @@ pub struct StreamEntry { pub fields: Vec<(String, String)>, } +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct ListValue { + pub elements: Vec, +} + + #[inline] pub fn now_in_millis() -> u128 { let start = SystemTime::now(); @@ -47,6 +54,7 @@ impl Storage { let _ = write_txn.open_table(TYPES_TABLE)?; let _ = write_txn.open_table(STRINGS_TABLE)?; let _ = write_txn.open_table(HASHES_TABLE)?; + let _ = write_txn.open_table(LISTS_TABLE)?; let _ = write_txn.open_table(STREAMS_META_TABLE)?; let _ = write_txn.open_table(STREAMS_DATA_TABLE)?; } @@ -143,6 +151,7 @@ impl Storage { let mut types_table = write_txn.open_table(TYPES_TABLE)?; let mut strings_table = write_txn.open_table(STRINGS_TABLE)?; let mut hashes_table = write_txn.open_table(HASHES_TABLE)?; + let mut lists_table = write_txn.open_table(LISTS_TABLE)?; // Remove from type table types_table.remove(key.as_str())?; @@ -165,6 +174,9 @@ impl Storage { for (hash_key, field) in to_remove { hashes_table.remove((hash_key.as_str(), field.as_str()))?; } + + // Remove from lists table + lists_table.remove(key.as_str())?; } write_txn.commit()?; @@ -633,4 +645,414 @@ impl Storage { None => Ok(false), } } -} + + // List operations + pub fn lpush(&self, key: &str, elements: Vec) -> Result { + let write_txn = self.db.begin_write()?; + let mut new_len = 0u64; + + { + let mut types_table = write_txn.open_table(TYPES_TABLE)?; + let mut lists_table = write_txn.open_table(LISTS_TABLE)?; + + let existing_type = match types_table.get(key)? { + Some(type_val) => Some(type_val.value().to_string()), + None => None, + }; + + match existing_type { + Some(ref type_str) if type_str != "list" => { + return Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())); + } + None => { + types_table.insert(key, "list")?; + } + _ => {} + } + + let mut list_value: ListValue = match lists_table.get(key)? { + Some(data) => bincode::deserialize(data.value())?, + None => ListValue { elements: Vec::new() }, + }; + + for element in elements.into_iter().rev() { + list_value.elements.insert(0, element); + } + new_len = list_value.elements.len() as u64; + + let serialized = bincode::serialize(&list_value)?; + lists_table.insert(key, serialized.as_slice())?; + } + + write_txn.commit()?; + Ok(new_len) + } + + pub fn rpush(&self, key: &str, elements: Vec) -> Result { + let write_txn = self.db.begin_write()?; + let mut new_len = 0u64; + + { + let mut types_table = write_txn.open_table(TYPES_TABLE)?; + let mut lists_table = write_txn.open_table(LISTS_TABLE)?; + + let existing_type = match types_table.get(key)? { + Some(type_val) => Some(type_val.value().to_string()), + None => None, + }; + + match existing_type { + Some(ref type_str) if type_str != "list" => { + return Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())); + } + None => { + types_table.insert(key, "list")?; + } + _ => {} + } + + let mut list_value: ListValue = match lists_table.get(key)? { + Some(data) => bincode::deserialize(data.value())?, + None => ListValue { elements: Vec::new() }, + }; + + for element in elements { + list_value.elements.push(element); + } + new_len = list_value.elements.len() as u64; + + let serialized = bincode::serialize(&list_value)?; + lists_table.insert(key, serialized.as_slice())?; + } + + write_txn.commit()?; + Ok(new_len) + } + + pub fn lpop(&self, key: &str, count: Option) -> Result>, DBError> { + let write_txn = self.db.begin_write()?; + let mut result_elements = Vec::new(); + + { + let mut types_table = write_txn.open_table(TYPES_TABLE)?; + let mut lists_table = write_txn.open_table(LISTS_TABLE)?; + + let existing_type = match types_table.get(key)? { + Some(type_val) => Some(type_val.value().to_string()), + None => None, + }; + + match existing_type { + Some(ref type_str) if type_str != "list" => { + return Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())); + } + Some(_) => { + let mut list_value: ListValue = match lists_table.get(key)? { + Some(data) => bincode::deserialize(data.value())?, + None => return Ok(None), // Key exists but list is empty (shouldn't happen if type is "list") + }; + + let num_to_pop = count.unwrap_or(1) as usize; + for _ in 0..num_to_pop { + if !list_value.elements.is_empty() { + result_elements.push(list_value.elements.remove(0)); + } else { + break; + } + } + + if list_value.elements.is_empty() { + lists_table.remove(key)?; + types_table.remove(key)?; + } else { + let serialized = bincode::serialize(&list_value)?; + lists_table.insert(key, serialized.as_slice())?; + } + } + None => return Ok(None), + } + } + + write_txn.commit()?; + if result_elements.is_empty() { + Ok(None) + } else { + Ok(Some(result_elements)) + } + } + + pub fn rpop(&self, key: &str, count: Option) -> Result>, DBError> { + let write_txn = self.db.begin_write()?; + let mut result_elements = Vec::new(); + + { + let mut types_table = write_txn.open_table(TYPES_TABLE)?; + let mut lists_table = write_txn.open_table(LISTS_TABLE)?; + + let existing_type = match types_table.get(key)? { + Some(type_val) => Some(type_val.value().to_string()), + None => None, + }; + + match existing_type { + Some(ref type_str) if type_str != "list" => { + return Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())); + } + Some(_) => { + let mut list_value: ListValue = match lists_table.get(key)? { + Some(data) => bincode::deserialize(data.value())?, + None => return Ok(None), + }; + + let num_to_pop = count.unwrap_or(1) as usize; + for _ in 0..num_to_pop { + if let Some(element) = list_value.elements.pop() { + result_elements.push(element); + } else { + break; + } + } + + if list_value.elements.is_empty() { + lists_table.remove(key)?; + types_table.remove(key)?; + } else { + let serialized = bincode::serialize(&list_value)?; + lists_table.insert(key, serialized.as_slice())?; + } + } + None => return Ok(None), + } + } + + write_txn.commit()?; + if result_elements.is_empty() { + Ok(None) + } else { + Ok(Some(result_elements)) + } + } + + pub fn llen(&self, key: &str) -> Result { + let read_txn = self.db.begin_read()?; + + let types_table = read_txn.open_table(TYPES_TABLE)?; + match types_table.get(key)? { + Some(type_val) if type_val.value() == "list" => { + let lists_table = read_txn.open_table(LISTS_TABLE)?; + match lists_table.get(key)? { + Some(data) => { + let list_value: ListValue = bincode::deserialize(data.value())?; + Ok(list_value.elements.len() as u64) + } + None => Ok(0), // Key exists but list is empty + } + } + Some(_) => Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())), + None => Ok(0), // Key does not exist + } + } + + pub fn lrem(&self, key: &str, count: i64, element: &str) -> Result { + let write_txn = self.db.begin_write()?; + let mut removed_count = 0u64; + + { + let mut types_table = write_txn.open_table(TYPES_TABLE)?; + let mut lists_table = write_txn.open_table(LISTS_TABLE)?; + + let existing_type = match types_table.get(key)? { + Some(type_val) => Some(type_val.value().to_string()), + None => None, + }; + + match existing_type { + Some(ref type_str) if type_str != "list" => { + return Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())); + } + Some(_) => { + let mut list_value: ListValue = match lists_table.get(key)? { + Some(data) => bincode::deserialize(data.value())?, + None => return Ok(0), + }; + + let initial_len = list_value.elements.len(); + + if count > 0 { + let mut i = 0; + let mut removed = 0; + while i < list_value.elements.len() && removed < count { + if list_value.elements[i] == element { + list_value.elements.remove(i); + removed += 1; + } else { + i += 1; + } + } + } else if count < 0 { + let mut i = list_value.elements.len() as i32 - 1; + let mut removed = 0; + while i >= 0 && removed < -count { + if list_value.elements[i as usize] == element { + list_value.elements.remove(i as usize); + removed += 1; + } + i -= 1; + } + } else { // count == 0 + list_value.elements.retain(|el| el != element); + } + + removed_count = (initial_len - list_value.elements.len()) as u64; + + if list_value.elements.is_empty() { + lists_table.remove(key)?; + types_table.remove(key)?; + } else { + let serialized = bincode::serialize(&list_value)?; + lists_table.insert(key, serialized.as_slice())?; + } + } + None => return Ok(0), + } + } + + write_txn.commit()?; + Ok(removed_count) + } + + pub fn ltrim(&self, key: &str, start: i64, stop: i64) -> Result<(), DBError> { + let write_txn = self.db.begin_write()?; + + { + let mut types_table = write_txn.open_table(TYPES_TABLE)?; + let mut lists_table = write_txn.open_table(LISTS_TABLE)?; + + let existing_type = match types_table.get(key)? { + Some(type_val) => Some(type_val.value().to_string()), + None => None, + }; + + match existing_type { + Some(ref type_str) if type_str != "list" => { + return Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())); + } + Some(_) => { + let mut list_value: ListValue = match lists_table.get(key)? { + Some(data) => bincode::deserialize(data.value())?, + None => return Ok(()), + }; + + let len = list_value.elements.len() as i64; + let mut start = start; + let mut stop = stop; + + if start < 0 { + start += len; + } + if stop < 0 { + stop += len; + } + + if start < 0 { + start = 0; + } + + if start > stop || start >= len { + list_value.elements.clear(); + } else { + if stop >= len { + stop = len - 1; + } + let start = start as usize; + let stop = stop as usize; + list_value.elements = list_value.elements.drain(start..=stop).collect(); + } + + if list_value.elements.is_empty() { + lists_table.remove(key)?; + types_table.remove(key)?; + } else { + let serialized = bincode::serialize(&list_value)?; + lists_table.insert(key, serialized.as_slice())?; + } + } + None => {} + } + } + + write_txn.commit()?; + Ok(()) + } + + pub fn lindex(&self, key: &str, index: i64) -> Result, DBError> { + let read_txn = self.db.begin_read()?; + + let types_table = read_txn.open_table(TYPES_TABLE)?; + match types_table.get(key)? { + Some(type_val) if type_val.value() == "list" => { + let lists_table = read_txn.open_table(LISTS_TABLE)?; + match lists_table.get(key)? { + Some(data) => { + let list_value: ListValue = bincode::deserialize(data.value())?; + let len = list_value.elements.len() as i64; + let mut index = index; + if index < 0 { + index += len; + } + if index < 0 || index >= len { + Ok(None) + } else { + Ok(list_value.elements.get(index as usize).cloned()) + } + } + None => Ok(None), + } + } + Some(_) => Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())), + None => Ok(None), + } + } + + pub fn lrange(&self, key: &str, start: i64, stop: i64) -> Result, DBError> { + let read_txn = self.db.begin_read()?; + + let types_table = read_txn.open_table(TYPES_TABLE)?; + match types_table.get(key)? { + Some(type_val) if type_val.value() == "list" => { + let lists_table = read_txn.open_table(LISTS_TABLE)?; + match lists_table.get(key)? { + Some(data) => { + let list_value: ListValue = bincode::deserialize(data.value())?; + let len = list_value.elements.len() as i64; + let mut start = start; + let mut stop = stop; + + if start < 0 { + start += len; + } + if stop < 0 { + stop += len; + } + + if start < 0 { + start = 0; + } + + if start > stop || start >= len { + Ok(Vec::new()) + } else { + if stop >= len { + stop = len - 1; + } + Ok(list_value.elements[start as usize..=stop as usize].to_vec()) + } + } + None => Ok(Vec::new()), + } + } + Some(_) => Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())), + None => Ok(Vec::new()), + } + } +} \ No newline at end of file diff --git a/tests/redis_tests.rs b/tests/redis_tests.rs index 94e425f..423b1b7 100644 --- a/tests/redis_tests.rs +++ b/tests/redis_tests.rs @@ -12,7 +12,8 @@ async fn start_test_server(test_name: &str) -> (Server, u16) { let port = PORT_COUNTER.fetch_add(1, Ordering::SeqCst); let test_dir = format!("/tmp/herodb_test_{}", test_name); - // Create test directory + // Clean up and create test directory + let _ = std::fs::remove_dir_all(&test_dir); std::fs::create_dir_all(&test_dir).unwrap(); let option = DBOption { @@ -196,7 +197,7 @@ async fn test_hash_operations() { let response = send_command(&mut stream, "*3\r\n$7\r\nHEXISTS\r\n$4\r\nhash\r\n$6\r\nfield1\r\n").await; assert!(response.contains("1")); - let response = send_command(&mut stream, "*3\r\n$7\r\nHEXISTS\r\n$4\r\nhash\r\n$8\r\nnoexist\r\n").await; + let response = send_command(&mut stream, "*3\r\n$7\r\nHEXISTS\r\n$4\r\nhash\r\n$7\r\nnoexist\r\n").await; assert!(response.contains("0")); // Test HDEL @@ -442,7 +443,7 @@ async fn test_type_command() { assert!(response.contains("hash")); // Test non-existent key - let response = send_command(&mut stream, "*2\r\n$4\r\nTYPE\r\n$8\r\nnoexist\r\n").await; + let response = send_command(&mut stream, "*2\r\n$4\r\nTYPE\r\n$7\r\nnoexist\r\n").await; assert!(response.contains("none")); }