This commit is contained in:
2025-08-16 08:41:19 +02:00
parent ad255a9f51
commit bec9b20ec7
5 changed files with 660 additions and 73 deletions

View File

@@ -39,13 +39,20 @@ pub enum Cmd {
// List commands // List commands
LPush(String, Vec<String>), LPush(String, Vec<String>),
RPush(String, Vec<String>), RPush(String, Vec<String>),
LPop(String, Option<u64>),
RPop(String, Option<u64>),
LLen(String),
LRem(String, i64, String),
LTrim(String, i64, i64),
LIndex(String, i64),
LRange(String, i64, i64),
Unknow(String), Unknow(String),
} }
impl Cmd { impl Cmd {
pub fn from(s: &str) -> Result<(Self, Protocol), DBError> { pub fn from(s: &str) -> Result<(Self, Protocol, &str), DBError> {
let protocol = Protocol::from(s)?; let (protocol, remaining) = Protocol::from(s)?;
match protocol.clone().0 { match protocol.clone() {
Protocol::Array(p) => { Protocol::Array(p) => {
let cmd = p.into_iter().map(|x| x.decode()).collect::<Vec<_>>(); let cmd = p.into_iter().map(|x| x.decode()).collect::<Vec<_>>();
if cmd.is_empty() { if cmd.is_empty() {
@@ -303,14 +310,85 @@ impl Cmd {
Cmd::Client(vec![]) Cmd::Client(vec![])
} }
} }
"lpush" => {
if cmd.len() < 3 {
return Err(DBError(format!("wrong number of arguments for LPUSH command")));
}
Cmd::LPush(cmd[1].clone(), cmd[2..].to_vec())
}
"rpush" => {
if cmd.len() < 3 {
return Err(DBError(format!("wrong number of arguments for RPUSH command")));
}
Cmd::RPush(cmd[1].clone(), cmd[2..].to_vec())
}
"lpop" => {
if cmd.len() < 2 || cmd.len() > 3 {
return Err(DBError(format!("wrong number of arguments for LPOP command")));
}
let count = if cmd.len() == 3 {
Some(cmd[2].parse::<u64>().map_err(|_| DBError("ERR value is not an integer or out of range".to_string()))?)
} else {
None
};
Cmd::LPop(cmd[1].clone(), count)
}
"rpop" => {
if cmd.len() < 2 || cmd.len() > 3 {
return Err(DBError(format!("wrong number of arguments for RPOP command")));
}
let count = if cmd.len() == 3 {
Some(cmd[2].parse::<u64>().map_err(|_| DBError("ERR value is not an integer or out of range".to_string()))?)
} else {
None
};
Cmd::RPop(cmd[1].clone(), count)
}
"llen" => {
if cmd.len() != 2 {
return Err(DBError(format!("wrong number of arguments for LLEN command")));
}
Cmd::LLen(cmd[1].clone())
}
"lrem" => {
if cmd.len() != 4 {
return Err(DBError(format!("wrong number of arguments for LREM command")));
}
let count = cmd[2].parse::<i64>().map_err(|_| DBError("ERR value is not an integer or out of range".to_string()))?;
Cmd::LRem(cmd[1].clone(), count, cmd[3].clone())
}
"ltrim" => {
if cmd.len() != 4 {
return Err(DBError(format!("wrong number of arguments for LTRIM command")));
}
let start = cmd[2].parse::<i64>().map_err(|_| DBError("ERR value is not an integer or out of range".to_string()))?;
let stop = cmd[3].parse::<i64>().map_err(|_| DBError("ERR value is not an integer or out of range".to_string()))?;
Cmd::LTrim(cmd[1].clone(), start, stop)
}
"lindex" => {
if cmd.len() != 3 {
return Err(DBError(format!("wrong number of arguments for LINDEX command")));
}
let index = cmd[2].parse::<i64>().map_err(|_| DBError("ERR value is not an integer or out of range".to_string()))?;
Cmd::LIndex(cmd[1].clone(), index)
}
"lrange" => {
if cmd.len() != 4 {
return Err(DBError(format!("wrong number of arguments for LRANGE command")));
}
let start = cmd[2].parse::<i64>().map_err(|_| DBError("ERR value is not an integer or out of range".to_string()))?;
let stop = cmd[3].parse::<i64>().map_err(|_| DBError("ERR value is not an integer or out of range".to_string()))?;
Cmd::LRange(cmd[1].clone(), start, stop)
}
_ => Cmd::Unknow(cmd[0].clone()), _ => Cmd::Unknow(cmd[0].clone()),
}, },
protocol.0, protocol,
remaining
)) ))
} }
_ => Err(DBError(format!( _ => Err(DBError(format!(
"fail to parse as cmd for {:?}", "fail to parse as cmd for {:?}",
protocol.0 protocol
))), ))),
} }
} }
@@ -379,6 +457,16 @@ impl Cmd {
Cmd::Client(_) => Ok(Protocol::SimpleString("OK".to_string())), Cmd::Client(_) => Ok(Protocol::SimpleString("OK".to_string())),
Cmd::ClientSetName(name) => client_setname_cmd(server, name).await, Cmd::ClientSetName(name) => client_setname_cmd(server, name).await,
Cmd::ClientGetName => client_getname_cmd(server).await, Cmd::ClientGetName => client_getname_cmd(server).await,
// List commands
Cmd::LPush(key, elements) => lpush_cmd(server, key, elements).await,
Cmd::RPush(key, elements) => rpush_cmd(server, key, elements).await,
Cmd::LPop(key, count) => lpop_cmd(server, key, count).await,
Cmd::RPop(key, count) => rpop_cmd(server, key, count).await,
Cmd::LLen(key) => llen_cmd(server, key).await,
Cmd::LRem(key, count, element) => lrem_cmd(server, key, *count, element).await,
Cmd::LTrim(key, start, stop) => ltrim_cmd(server, key, *start, *stop).await,
Cmd::LIndex(key, index) => lindex_cmd(server, key, *index).await,
Cmd::LRange(key, start, stop) => lrange_cmd(server, key, *start, *stop).await,
Cmd::Unknow(s) => { Cmd::Unknow(s) => {
println!("\x1b[31;1munknown command: {}\x1b[0m", s); println!("\x1b[31;1munknown command: {}\x1b[0m", s);
Ok(Protocol::err(&format!("ERR unknown command '{}'", s))) Ok(Protocol::err(&format!("ERR unknown command '{}'", s)))
@@ -387,6 +475,96 @@ impl Cmd {
} }
} }
async fn lindex_cmd(server: &Server, key: &str, index: i64) -> Result<Protocol, DBError> {
match server.storage.lindex(key, index) {
Ok(Some(element)) => Ok(Protocol::BulkString(element)),
Ok(None) => Ok(Protocol::Null),
Err(e) => Ok(Protocol::err(&e.0)),
}
}
async fn lrange_cmd(server: &Server, key: &str, start: i64, stop: i64) -> Result<Protocol, DBError> {
match server.storage.lrange(key, start, stop) {
Ok(elements) => Ok(Protocol::Array(elements.into_iter().map(Protocol::BulkString).collect())),
Err(e) => Ok(Protocol::err(&e.0)),
}
}
async fn ltrim_cmd(server: &Server, key: &str, start: i64, stop: i64) -> Result<Protocol, DBError> {
match server.storage.ltrim(key, start, stop) {
Ok(_) => Ok(Protocol::SimpleString("OK".to_string())),
Err(e) => Ok(Protocol::err(&e.0)),
}
}
async fn lrem_cmd(server: &Server, key: &str, count: i64, element: &str) -> Result<Protocol, DBError> {
match server.storage.lrem(key, count, element) {
Ok(removed_count) => Ok(Protocol::SimpleString(removed_count.to_string())),
Err(e) => Ok(Protocol::err(&e.0)),
}
}
async fn llen_cmd(server: &Server, key: &str) -> Result<Protocol, DBError> {
match server.storage.llen(key) {
Ok(len) => Ok(Protocol::SimpleString(len.to_string())),
Err(e) => Ok(Protocol::err(&e.0)),
}
}
async fn lpop_cmd(server: &Server, key: &str, count: &Option<u64>) -> Result<Protocol, DBError> {
match server.storage.lpop(key, *count) {
Ok(Some(elements)) => {
if count.is_some() {
Ok(Protocol::Array(elements.into_iter().map(Protocol::BulkString).collect()))
} else {
Ok(Protocol::BulkString(elements[0].clone()))
}
},
Ok(None) => {
if count.is_some() {
Ok(Protocol::Array(vec![]))
} else {
Ok(Protocol::Null)
}
},
Err(e) => Ok(Protocol::err(&e.0)),
}
}
async fn rpop_cmd(server: &Server, key: &str, count: &Option<u64>) -> Result<Protocol, DBError> {
match server.storage.rpop(key, *count) {
Ok(Some(elements)) => {
if count.is_some() {
Ok(Protocol::Array(elements.into_iter().map(Protocol::BulkString).collect()))
} else {
Ok(Protocol::BulkString(elements[0].clone()))
}
},
Ok(None) => {
if count.is_some() {
Ok(Protocol::Array(vec![]))
} else {
Ok(Protocol::Null)
}
},
Err(e) => Ok(Protocol::err(&e.0)),
}
}
async fn lpush_cmd(server: &Server, key: &str, elements: &[String]) -> Result<Protocol, DBError> {
match server.storage.lpush(key, elements.to_vec()) {
Ok(len) => Ok(Protocol::SimpleString(len.to_string())),
Err(e) => Ok(Protocol::err(&e.0)),
}
}
async fn rpush_cmd(server: &Server, key: &str, elements: &[String]) -> Result<Protocol, DBError> {
match server.storage.rpush(key, elements.to_vec()) {
Ok(len) => Ok(Protocol::SimpleString(len.to_string())),
Err(e) => Ok(Protocol::err(&e.0)),
}
}
async fn exec_cmd( async fn exec_cmd(
queued_cmd: &mut Option<Vec<(Cmd, Protocol)>>, queued_cmd: &mut Option<Vec<(Cmd, Protocol)>>,
server: &mut Server, server: &mut Server,

View File

@@ -17,7 +17,7 @@ impl fmt::Display for Protocol {
} }
impl Protocol { impl Protocol {
pub fn from(protocol: &str) -> Result<(Self, usize), DBError> { pub fn from(protocol: &str) -> Result<(Self, &str), DBError> {
let ret = match protocol.chars().nth(0) { let ret = match protocol.chars().nth(0) {
Some('+') => Self::parse_simple_string_sfx(&protocol[1..]), Some('+') => Self::parse_simple_string_sfx(&protocol[1..]),
Some('$') => Self::parse_bulk_string_sfx(&protocol[1..]), Some('$') => Self::parse_bulk_string_sfx(&protocol[1..]),
@@ -27,10 +27,7 @@ impl Protocol {
protocol protocol
))), ))),
}; };
match ret { ret
Ok((p, s)) => Ok((p, s + 1)),
Err(e) => Err(e),
}
} }
pub fn from_vec(array: Vec<&str>) -> Self { pub fn from_vec(array: Vec<&str>) -> Self {
@@ -91,9 +88,9 @@ impl Protocol {
} }
} }
fn parse_simple_string_sfx(protocol: &str) -> Result<(Self, usize), DBError> { fn parse_simple_string_sfx(protocol: &str) -> Result<(Self, &str), DBError> {
match protocol.find("\r\n") { match protocol.find("\r\n") {
Some(x) => Ok((Self::SimpleString(protocol[..x].to_string()), x + 2)), Some(x) => Ok((Self::SimpleString(protocol[..x].to_string()), &protocol[x + 2..])),
_ => Err(DBError(format!( _ => Err(DBError(format!(
"[new simple string] unsupported protocol: {:?}", "[new simple string] unsupported protocol: {:?}",
protocol protocol
@@ -101,27 +98,20 @@ impl Protocol {
} }
} }
fn parse_bulk_string_sfx(protocol: &str) -> Result<(Self, usize), DBError> { fn parse_bulk_string_sfx(protocol: &str) -> Result<(Self, &str), DBError> {
if let Some(len) = protocol.find("\r\n") { if let Some(len_end) = protocol.find("\r\n") {
let size = Self::parse_usize(&protocol[..len])?; let size = Self::parse_usize(&protocol[..len_end])?;
if let Some(data_len) = protocol[len + 2..].find("\r\n") { let data_start = len_end + 2;
let s = Self::parse_string(&protocol[len + 2..len + 2 + data_len])?; let data_end = data_start + size;
if size != s.len() { let s = Self::parse_string(&protocol[data_start..data_end])?;
if protocol.len() < data_end + 2 || &protocol[data_end..data_end+2] != "\r\n" {
Err(DBError(format!( Err(DBError(format!(
"[new bulk string] unmatched string length in prototocl {:?}", "[new bulk string] unmatched string length in prototocl {:?}",
protocol, protocol,
))) )))
} else { } else {
Ok(( Ok((Protocol::BulkString(s), &protocol[data_end + 2..]))
Protocol::BulkString(s.to_lowercase()),
len + 2 + data_len + 2,
))
}
} else {
Err(DBError(format!(
"[new bulk string] unsupported protocol: {:?}",
protocol
)))
} }
} else { } else {
Err(DBError(format!( Err(DBError(format!(
@@ -131,30 +121,22 @@ impl Protocol {
} }
} }
fn parse_array_sfx(s: &str) -> Result<(Self, usize), DBError> { fn parse_array_sfx(s: &str) -> Result<(Self, &str), DBError> {
let mut offset = 0; if let Some(len_end) = s.find("\r\n") {
match s.find("\r\n") { let array_len = s[..len_end].parse::<usize>()?;
Some(x) => { let mut remaining = &s[len_end + 2..];
let array_len = s[..x].parse::<usize>()?;
offset += x + 2;
let mut vec = vec![]; let mut vec = vec![];
for _ in 0..array_len { for _ in 0..array_len {
match Protocol::from(&s[offset..]) { let (p, rem) = Protocol::from(remaining)?;
Ok((p, len)) => {
offset += len;
vec.push(p); vec.push(p);
remaining = rem;
} }
Err(e) => { Ok((Protocol::Array(vec), remaining))
return Err(e); } else {
} Err(DBError(format!(
}
}
Ok((Protocol::Array(vec), offset))
}
_ => Err(DBError(format!(
"[new array] unsupported protocol: {:?}", "[new array] unsupported protocol: {:?}",
s s
))), )))
} }
} }

View File

@@ -41,20 +41,29 @@ impl Server {
let mut queued_cmd: Option<Vec<(Cmd, Protocol)>> = None; let mut queued_cmd: Option<Vec<(Cmd, Protocol)>> = None;
loop { loop {
if let Ok(len) = stream.read(&mut buf).await { let len = match stream.read(&mut buf).await {
if len == 0 { Ok(0) => {
println!("[handle] connection closed"); println!("[handle] connection closed");
return Ok(()); return Ok(());
} }
Ok(len) => len,
let s = str::from_utf8(&buf[..len])?;
let (cmd, protocol) = match Cmd::from(s) {
Ok((cmd, protocol)) => (cmd, protocol),
Err(e) => { Err(e) => {
println!("\x1b[31;1mprotocol error: {:?}\x1b[0m", e); println!("[handle] read error: {:?}", e);
(Cmd::Unknow("protocol_error".to_string()), Protocol::err(&format!("protocol error: {}", e.0))) return Err(e.into());
} }
}; };
let mut s = str::from_utf8(&buf[..len])?;
while !s.is_empty() {
let (cmd, protocol, remaining) = match Cmd::from(s) {
Ok((cmd, protocol, remaining)) => (cmd, protocol, remaining),
Err(e) => {
println!("\x1b[31;1mprotocol error: {:?}\x1b[0m", e);
(Cmd::Unknow("protocol_error".to_string()), Protocol::err(&format!("protocol error: {}", e.0)), "")
}
};
s = remaining;
if self.option.debug { if self.option.debug {
println!("\x1b[34;1mgot command: {:?}, protocol: {:?}\x1b[0m", cmd, protocol); println!("\x1b[34;1mgot command: {:?}, protocol: {:?}\x1b[0m", cmd, protocol);
} else { } else {
@@ -68,17 +77,15 @@ impl Server {
.run(&mut self.clone(), protocol.clone(), &mut queued_cmd) .run(&mut self.clone(), protocol.clone(), &mut queued_cmd)
.await .await
.unwrap_or(Protocol::err("unknown cmd from server")); .unwrap_or(Protocol::err("unknown cmd from server"));
if self.option.debug {
println!("\x1b[34;1mqueued cmd {:?}\x1b[0m", queued_cmd);
} else {
print!("queued cmd {:?}", queued_cmd);
}
if self.option.debug { if self.option.debug {
println!("\x1b[34;1mqueued cmd {:?}\x1b[0m", queued_cmd);
println!("\x1b[32;1mgoing to send response {}\x1b[0m", res.encode()); println!("\x1b[32;1mgoing to send response {}\x1b[0m", res.encode());
} else { } else {
print!("queued cmd {:?}", queued_cmd);
println!("going to send response {}", res.encode()); println!("going to send response {}", res.encode());
} }
_ = stream.write(res.encode().as_bytes()).await?; _ = stream.write(res.encode().as_bytes()).await?;
// If this was a QUIT command, close the connection // If this was a QUIT command, close the connection
@@ -86,9 +93,6 @@ impl Server {
println!("[handle] QUIT command received, closing connection"); println!("[handle] QUIT command received, closing connection");
return Ok(()); return Ok(());
} }
} else {
println!("[handle] going to break");
break;
} }
} }
Ok(()) Ok(())

View File

@@ -12,6 +12,7 @@ use crate::error::DBError;
const TYPES_TABLE: TableDefinition<&str, &str> = TableDefinition::new("types"); const TYPES_TABLE: TableDefinition<&str, &str> = TableDefinition::new("types");
const STRINGS_TABLE: TableDefinition<&str, &[u8]> = TableDefinition::new("strings"); const STRINGS_TABLE: TableDefinition<&str, &[u8]> = TableDefinition::new("strings");
const HASHES_TABLE: TableDefinition<(&str, &str), &str> = TableDefinition::new("hashes"); const HASHES_TABLE: TableDefinition<(&str, &str), &str> = TableDefinition::new("hashes");
const LISTS_TABLE: TableDefinition<&str, &[u8]> = TableDefinition::new("lists");
const STREAMS_META_TABLE: TableDefinition<&str, &[u8]> = TableDefinition::new("streams_meta"); const STREAMS_META_TABLE: TableDefinition<&str, &[u8]> = TableDefinition::new("streams_meta");
const STREAMS_DATA_TABLE: TableDefinition<(&str, &str), &[u8]> = TableDefinition::new("streams_data"); const STREAMS_DATA_TABLE: TableDefinition<(&str, &str), &[u8]> = TableDefinition::new("streams_data");
@@ -26,6 +27,12 @@ pub struct StreamEntry {
pub fields: Vec<(String, String)>, pub fields: Vec<(String, String)>,
} }
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct ListValue {
pub elements: Vec<String>,
}
#[inline] #[inline]
pub fn now_in_millis() -> u128 { pub fn now_in_millis() -> u128 {
let start = SystemTime::now(); let start = SystemTime::now();
@@ -47,6 +54,7 @@ impl Storage {
let _ = write_txn.open_table(TYPES_TABLE)?; let _ = write_txn.open_table(TYPES_TABLE)?;
let _ = write_txn.open_table(STRINGS_TABLE)?; let _ = write_txn.open_table(STRINGS_TABLE)?;
let _ = write_txn.open_table(HASHES_TABLE)?; let _ = write_txn.open_table(HASHES_TABLE)?;
let _ = write_txn.open_table(LISTS_TABLE)?;
let _ = write_txn.open_table(STREAMS_META_TABLE)?; let _ = write_txn.open_table(STREAMS_META_TABLE)?;
let _ = write_txn.open_table(STREAMS_DATA_TABLE)?; let _ = write_txn.open_table(STREAMS_DATA_TABLE)?;
} }
@@ -143,6 +151,7 @@ impl Storage {
let mut types_table = write_txn.open_table(TYPES_TABLE)?; let mut types_table = write_txn.open_table(TYPES_TABLE)?;
let mut strings_table = write_txn.open_table(STRINGS_TABLE)?; let mut strings_table = write_txn.open_table(STRINGS_TABLE)?;
let mut hashes_table = write_txn.open_table(HASHES_TABLE)?; let mut hashes_table = write_txn.open_table(HASHES_TABLE)?;
let mut lists_table = write_txn.open_table(LISTS_TABLE)?;
// Remove from type table // Remove from type table
types_table.remove(key.as_str())?; types_table.remove(key.as_str())?;
@@ -165,6 +174,9 @@ impl Storage {
for (hash_key, field) in to_remove { for (hash_key, field) in to_remove {
hashes_table.remove((hash_key.as_str(), field.as_str()))?; hashes_table.remove((hash_key.as_str(), field.as_str()))?;
} }
// Remove from lists table
lists_table.remove(key.as_str())?;
} }
write_txn.commit()?; write_txn.commit()?;
@@ -633,4 +645,414 @@ impl Storage {
None => Ok(false), None => Ok(false),
} }
} }
// List operations
pub fn lpush(&self, key: &str, elements: Vec<String>) -> Result<u64, DBError> {
let write_txn = self.db.begin_write()?;
let mut new_len = 0u64;
{
let mut types_table = write_txn.open_table(TYPES_TABLE)?;
let mut lists_table = write_txn.open_table(LISTS_TABLE)?;
let existing_type = match types_table.get(key)? {
Some(type_val) => Some(type_val.value().to_string()),
None => None,
};
match existing_type {
Some(ref type_str) if type_str != "list" => {
return Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string()));
}
None => {
types_table.insert(key, "list")?;
}
_ => {}
}
let mut list_value: ListValue = match lists_table.get(key)? {
Some(data) => bincode::deserialize(data.value())?,
None => ListValue { elements: Vec::new() },
};
for element in elements.into_iter().rev() {
list_value.elements.insert(0, element);
}
new_len = list_value.elements.len() as u64;
let serialized = bincode::serialize(&list_value)?;
lists_table.insert(key, serialized.as_slice())?;
}
write_txn.commit()?;
Ok(new_len)
}
pub fn rpush(&self, key: &str, elements: Vec<String>) -> Result<u64, DBError> {
let write_txn = self.db.begin_write()?;
let mut new_len = 0u64;
{
let mut types_table = write_txn.open_table(TYPES_TABLE)?;
let mut lists_table = write_txn.open_table(LISTS_TABLE)?;
let existing_type = match types_table.get(key)? {
Some(type_val) => Some(type_val.value().to_string()),
None => None,
};
match existing_type {
Some(ref type_str) if type_str != "list" => {
return Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string()));
}
None => {
types_table.insert(key, "list")?;
}
_ => {}
}
let mut list_value: ListValue = match lists_table.get(key)? {
Some(data) => bincode::deserialize(data.value())?,
None => ListValue { elements: Vec::new() },
};
for element in elements {
list_value.elements.push(element);
}
new_len = list_value.elements.len() as u64;
let serialized = bincode::serialize(&list_value)?;
lists_table.insert(key, serialized.as_slice())?;
}
write_txn.commit()?;
Ok(new_len)
}
pub fn lpop(&self, key: &str, count: Option<u64>) -> Result<Option<Vec<String>>, DBError> {
let write_txn = self.db.begin_write()?;
let mut result_elements = Vec::new();
{
let mut types_table = write_txn.open_table(TYPES_TABLE)?;
let mut lists_table = write_txn.open_table(LISTS_TABLE)?;
let existing_type = match types_table.get(key)? {
Some(type_val) => Some(type_val.value().to_string()),
None => None,
};
match existing_type {
Some(ref type_str) if type_str != "list" => {
return Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string()));
}
Some(_) => {
let mut list_value: ListValue = match lists_table.get(key)? {
Some(data) => bincode::deserialize(data.value())?,
None => return Ok(None), // Key exists but list is empty (shouldn't happen if type is "list")
};
let num_to_pop = count.unwrap_or(1) as usize;
for _ in 0..num_to_pop {
if !list_value.elements.is_empty() {
result_elements.push(list_value.elements.remove(0));
} else {
break;
}
}
if list_value.elements.is_empty() {
lists_table.remove(key)?;
types_table.remove(key)?;
} else {
let serialized = bincode::serialize(&list_value)?;
lists_table.insert(key, serialized.as_slice())?;
}
}
None => return Ok(None),
}
}
write_txn.commit()?;
if result_elements.is_empty() {
Ok(None)
} else {
Ok(Some(result_elements))
}
}
pub fn rpop(&self, key: &str, count: Option<u64>) -> Result<Option<Vec<String>>, DBError> {
let write_txn = self.db.begin_write()?;
let mut result_elements = Vec::new();
{
let mut types_table = write_txn.open_table(TYPES_TABLE)?;
let mut lists_table = write_txn.open_table(LISTS_TABLE)?;
let existing_type = match types_table.get(key)? {
Some(type_val) => Some(type_val.value().to_string()),
None => None,
};
match existing_type {
Some(ref type_str) if type_str != "list" => {
return Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string()));
}
Some(_) => {
let mut list_value: ListValue = match lists_table.get(key)? {
Some(data) => bincode::deserialize(data.value())?,
None => return Ok(None),
};
let num_to_pop = count.unwrap_or(1) as usize;
for _ in 0..num_to_pop {
if let Some(element) = list_value.elements.pop() {
result_elements.push(element);
} else {
break;
}
}
if list_value.elements.is_empty() {
lists_table.remove(key)?;
types_table.remove(key)?;
} else {
let serialized = bincode::serialize(&list_value)?;
lists_table.insert(key, serialized.as_slice())?;
}
}
None => return Ok(None),
}
}
write_txn.commit()?;
if result_elements.is_empty() {
Ok(None)
} else {
Ok(Some(result_elements))
}
}
pub fn llen(&self, key: &str) -> Result<u64, DBError> {
let read_txn = self.db.begin_read()?;
let types_table = read_txn.open_table(TYPES_TABLE)?;
match types_table.get(key)? {
Some(type_val) if type_val.value() == "list" => {
let lists_table = read_txn.open_table(LISTS_TABLE)?;
match lists_table.get(key)? {
Some(data) => {
let list_value: ListValue = bincode::deserialize(data.value())?;
Ok(list_value.elements.len() as u64)
}
None => Ok(0), // Key exists but list is empty
}
}
Some(_) => Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())),
None => Ok(0), // Key does not exist
}
}
pub fn lrem(&self, key: &str, count: i64, element: &str) -> Result<u64, DBError> {
let write_txn = self.db.begin_write()?;
let mut removed_count = 0u64;
{
let mut types_table = write_txn.open_table(TYPES_TABLE)?;
let mut lists_table = write_txn.open_table(LISTS_TABLE)?;
let existing_type = match types_table.get(key)? {
Some(type_val) => Some(type_val.value().to_string()),
None => None,
};
match existing_type {
Some(ref type_str) if type_str != "list" => {
return Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string()));
}
Some(_) => {
let mut list_value: ListValue = match lists_table.get(key)? {
Some(data) => bincode::deserialize(data.value())?,
None => return Ok(0),
};
let initial_len = list_value.elements.len();
if count > 0 {
let mut i = 0;
let mut removed = 0;
while i < list_value.elements.len() && removed < count {
if list_value.elements[i] == element {
list_value.elements.remove(i);
removed += 1;
} else {
i += 1;
}
}
} else if count < 0 {
let mut i = list_value.elements.len() as i32 - 1;
let mut removed = 0;
while i >= 0 && removed < -count {
if list_value.elements[i as usize] == element {
list_value.elements.remove(i as usize);
removed += 1;
}
i -= 1;
}
} else { // count == 0
list_value.elements.retain(|el| el != element);
}
removed_count = (initial_len - list_value.elements.len()) as u64;
if list_value.elements.is_empty() {
lists_table.remove(key)?;
types_table.remove(key)?;
} else {
let serialized = bincode::serialize(&list_value)?;
lists_table.insert(key, serialized.as_slice())?;
}
}
None => return Ok(0),
}
}
write_txn.commit()?;
Ok(removed_count)
}
pub fn ltrim(&self, key: &str, start: i64, stop: i64) -> Result<(), DBError> {
let write_txn = self.db.begin_write()?;
{
let mut types_table = write_txn.open_table(TYPES_TABLE)?;
let mut lists_table = write_txn.open_table(LISTS_TABLE)?;
let existing_type = match types_table.get(key)? {
Some(type_val) => Some(type_val.value().to_string()),
None => None,
};
match existing_type {
Some(ref type_str) if type_str != "list" => {
return Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string()));
}
Some(_) => {
let mut list_value: ListValue = match lists_table.get(key)? {
Some(data) => bincode::deserialize(data.value())?,
None => return Ok(()),
};
let len = list_value.elements.len() as i64;
let mut start = start;
let mut stop = stop;
if start < 0 {
start += len;
}
if stop < 0 {
stop += len;
}
if start < 0 {
start = 0;
}
if start > stop || start >= len {
list_value.elements.clear();
} else {
if stop >= len {
stop = len - 1;
}
let start = start as usize;
let stop = stop as usize;
list_value.elements = list_value.elements.drain(start..=stop).collect();
}
if list_value.elements.is_empty() {
lists_table.remove(key)?;
types_table.remove(key)?;
} else {
let serialized = bincode::serialize(&list_value)?;
lists_table.insert(key, serialized.as_slice())?;
}
}
None => {}
}
}
write_txn.commit()?;
Ok(())
}
pub fn lindex(&self, key: &str, index: i64) -> Result<Option<String>, DBError> {
let read_txn = self.db.begin_read()?;
let types_table = read_txn.open_table(TYPES_TABLE)?;
match types_table.get(key)? {
Some(type_val) if type_val.value() == "list" => {
let lists_table = read_txn.open_table(LISTS_TABLE)?;
match lists_table.get(key)? {
Some(data) => {
let list_value: ListValue = bincode::deserialize(data.value())?;
let len = list_value.elements.len() as i64;
let mut index = index;
if index < 0 {
index += len;
}
if index < 0 || index >= len {
Ok(None)
} else {
Ok(list_value.elements.get(index as usize).cloned())
}
}
None => Ok(None),
}
}
Some(_) => Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())),
None => Ok(None),
}
}
pub fn lrange(&self, key: &str, start: i64, stop: i64) -> Result<Vec<String>, DBError> {
let read_txn = self.db.begin_read()?;
let types_table = read_txn.open_table(TYPES_TABLE)?;
match types_table.get(key)? {
Some(type_val) if type_val.value() == "list" => {
let lists_table = read_txn.open_table(LISTS_TABLE)?;
match lists_table.get(key)? {
Some(data) => {
let list_value: ListValue = bincode::deserialize(data.value())?;
let len = list_value.elements.len() as i64;
let mut start = start;
let mut stop = stop;
if start < 0 {
start += len;
}
if stop < 0 {
stop += len;
}
if start < 0 {
start = 0;
}
if start > stop || start >= len {
Ok(Vec::new())
} else {
if stop >= len {
stop = len - 1;
}
Ok(list_value.elements[start as usize..=stop as usize].to_vec())
}
}
None => Ok(Vec::new()),
}
}
Some(_) => Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())),
None => Ok(Vec::new()),
}
}
} }

View File

@@ -12,7 +12,8 @@ async fn start_test_server(test_name: &str) -> (Server, u16) {
let port = PORT_COUNTER.fetch_add(1, Ordering::SeqCst); let port = PORT_COUNTER.fetch_add(1, Ordering::SeqCst);
let test_dir = format!("/tmp/herodb_test_{}", test_name); let test_dir = format!("/tmp/herodb_test_{}", test_name);
// Create test directory // Clean up and create test directory
let _ = std::fs::remove_dir_all(&test_dir);
std::fs::create_dir_all(&test_dir).unwrap(); std::fs::create_dir_all(&test_dir).unwrap();
let option = DBOption { let option = DBOption {
@@ -196,7 +197,7 @@ async fn test_hash_operations() {
let response = send_command(&mut stream, "*3\r\n$7\r\nHEXISTS\r\n$4\r\nhash\r\n$6\r\nfield1\r\n").await; let response = send_command(&mut stream, "*3\r\n$7\r\nHEXISTS\r\n$4\r\nhash\r\n$6\r\nfield1\r\n").await;
assert!(response.contains("1")); assert!(response.contains("1"));
let response = send_command(&mut stream, "*3\r\n$7\r\nHEXISTS\r\n$4\r\nhash\r\n$8\r\nnoexist\r\n").await; let response = send_command(&mut stream, "*3\r\n$7\r\nHEXISTS\r\n$4\r\nhash\r\n$7\r\nnoexist\r\n").await;
assert!(response.contains("0")); assert!(response.contains("0"));
// Test HDEL // Test HDEL
@@ -442,7 +443,7 @@ async fn test_type_command() {
assert!(response.contains("hash")); assert!(response.contains("hash"));
// Test non-existent key // Test non-existent key
let response = send_command(&mut stream, "*2\r\n$4\r\nTYPE\r\n$8\r\nnoexist\r\n").await; let response = send_command(&mut stream, "*2\r\n$4\r\nTYPE\r\n$7\r\nnoexist\r\n").await;
assert!(response.contains("none")); assert!(response.contains("none"));
} }