From de2be4a7854e035b0dd78daaf2a0980371571466 Mon Sep 17 00:00:00 2001 From: despiegk Date: Sat, 16 Aug 2025 07:18:55 +0200 Subject: [PATCH] ... --- Cargo.lock | 43 ++- Cargo.toml | 7 +- src/cmd.rs | 645 ++++++++++++++------------------------ src/error.rs | 45 ++- src/lib.rs | 2 - src/main.rs | 44 +-- src/options.rs | 10 - src/rdb.rs | 201 ------------ src/replication_client.rs | 155 --------- src/server.rs | 116 +------ src/storage.rs | 445 ++++++++++++++++++++++++-- test_herodb.sh | 302 ++++++++++++++++++ 12 files changed, 1060 insertions(+), 955 deletions(-) delete mode 100644 src/rdb.rs delete mode 100644 src/replication_client.rs create mode 100755 test_herodb.sh diff --git a/Cargo.lock b/Cargo.lock index 5049203..4c63baf 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1,6 +1,6 @@ # This file is automatically @generated by Cargo. # It is not intended for manual editing. -version = 3 +version = 4 [[package]] name = "addr2line" @@ -93,6 +93,15 @@ dependencies = [ "rustc-demangle", ] +[[package]] +name = "bincode" +version = "1.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b1f45e9417d87227c7a56d22e471c6206462cba514c7590c09aff4cf6d1ddcad" +dependencies = [ + "serde", +] + [[package]] name = "bitflags" version = "2.6.0" @@ -396,15 +405,27 @@ dependencies = [ "proc-macro2", ] +[[package]] +name = "redb" +version = "2.6.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "59b38b05028f398f08bea4691640503ec25fcb60b82fb61ce1f8fd1f4fccd3f7" +dependencies = [ + "libc", +] + [[package]] name = "redis-rs" version = "0.0.1" dependencies = [ "anyhow", + "bincode", "byteorder", "bytes", "clap", "futures", + "redb", + "serde", "thiserror", "tokio", ] @@ -430,6 +451,26 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" +[[package]] +name = "serde" +version = "1.0.210" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c8e3592472072e6e22e0a54d5904d9febf8508f65fb8552499a1abc7d1078c3a" +dependencies = [ + "serde_derive", +] + +[[package]] +name = "serde_derive" +version = "1.0.210" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "243902eda00fad750862fc144cea25caca5e20d615af0a81bee94ca738f1df1f" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "signal-hook-registry" version = "1.4.2" diff --git a/Cargo.toml b/Cargo.toml index 91402a9..6af3454 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,10 +6,13 @@ edition = "2021" [dependencies] anyhow = "1.0.59" -bytes = "1.3.0" +bytes = "1.3.0" thiserror = "1.0.32" -tokio = { version = "1.23.0", features = ["full"] } +tokio = { version = "1.23.0", features = ["full"] } clap = { version = "4.5.20", features = ["derive"] } byteorder = "1.4.3" futures = "0.3" +redb = "2.1.3" +serde = { version = "1.0", features = ["derive"] } +bincode = "1.3.3" diff --git a/src/cmd.rs b/src/cmd.rs index 4a15c84..437331b 100644 --- a/src/cmd.rs +++ b/src/cmd.rs @@ -1,8 +1,4 @@ -use std::{collections::BTreeMap, ops::Bound, time::Duration, u64}; - -use tokio::sync::mpsc; - -use crate::{error::DBError, protocol::Protocol, server::Server, storage::now_in_millis}; +use crate::{error::DBError, protocol::Protocol, server::Server}; #[derive(Debug, Clone)] pub enum Cmd { @@ -16,17 +12,23 @@ pub enum Cmd { ConfigGet(String), Info(Option), Del(String), - Replconf(String), - Psync, Type(String), - Xadd(String, String, Vec<(String, String)>), - Xrange(String, String, String), - Xread(Vec, Vec, Option), Incr(String), Multi, Exec, - Unknow, Discard, + // Hash commands + HSet(String, Vec<(String, String)>), + HGet(String, String), + HGetAll(String), + HDel(String, Vec), + HExists(String, String), + HKeys(String), + HVals(String), + HLen(String), + HMGet(String, Vec), + HSetNx(String, String, String), + Unknow, } impl Cmd { @@ -39,14 +41,14 @@ impl Cmd { return Err(DBError("cmd length is 0".to_string())); } Ok(( - match cmd[0].as_str() { + match cmd[0].to_lowercase().as_str() { "echo" => Cmd::Echo(cmd[1].clone()), "ping" => Cmd::Ping, "get" => Cmd::Get(cmd[1].clone()), "set" => { - if cmd.len() == 5 && cmd[3] == "px" { + if cmd.len() == 5 && cmd[3].to_lowercase() == "px" { Cmd::SetPx(cmd[1].clone(), cmd[2].clone(), cmd[4].parse().unwrap()) - } else if cmd.len() == 5 && cmd[3] == "ex" { + } else if cmd.len() == 5 && cmd[3].to_lowercase() == "ex" { Cmd::SetEx(cmd[1].clone(), cmd[2].clone(), cmd[4].parse().unwrap()) } else if cmd.len() == 3 { Cmd::Set(cmd[1].clone(), cmd[2].clone()) @@ -55,7 +57,7 @@ impl Cmd { } } "config" => { - if cmd.len() != 3 || cmd[1] != "get" { + if cmd.len() != 3 || cmd[1].to_lowercase() != "get" { return Err(DBError(format!("unsupported cmd {:?}", cmd))); } else { Cmd::ConfigGet(cmd[2].clone()) @@ -76,18 +78,6 @@ impl Cmd { }; Cmd::Info(section) } - "replconf" => { - if cmd.len() < 3 { - return Err(DBError(format!("unsupported cmd {:?}", cmd))); - } - Cmd::Replconf(cmd[1].clone()) - } - "psync" => { - if cmd.len() != 3 { - return Err(DBError(format!("unsupported cmd {:?}", cmd))); - } - Cmd::Psync - } "del" => { if cmd.len() != 2 { return Err(DBError(format!("unsupported cmd {:?}", cmd))); @@ -100,44 +90,6 @@ impl Cmd { } Cmd::Type(cmd[1].clone()) } - "xadd" => { - if cmd.len() < 5 { - return Err(DBError(format!("unsupported cmd {:?}", cmd))); - } - - let mut key_value = Vec::<(String, String)>::new(); - let mut i = 3; - while i < cmd.len() - 1 { - key_value.push((cmd[i].clone(), cmd[i + 1].clone())); - i += 2; - } - Cmd::Xadd(cmd[1].clone(), cmd[2].clone(), key_value) - } - "xrange" => { - if cmd.len() != 4 { - return Err(DBError(format!("unsupported cmd {:?}", cmd))); - } - Cmd::Xrange(cmd[1].clone(), cmd[2].clone(), cmd[3].clone()) - } - "xread" => { - if cmd.len() < 4 || cmd.len() % 2 != 0 { - return Err(DBError(format!("unsupported cmd {:?}", cmd))); - } - let mut offset = 2; - // block cmd - let mut block = None; - if cmd[1] == "block" { - offset += 2; - if let Ok(block_time) = cmd[2].parse() { - block = Some(block_time); - } else { - return Err(DBError(format!("unsupported cmd {:?}", cmd))); - } - } - let cmd2 = &cmd[offset..]; - let len2 = cmd2.len() / 2; - Cmd::Xread(cmd2[0..len2].to_vec(), cmd2[len2..].to_vec(), block) - } "incr" => { if cmd.len() != 2 { return Err(DBError(format!("unsupported cmd {:?}", cmd))); @@ -157,6 +109,73 @@ impl Cmd { Cmd::Exec } "discard" => Cmd::Discard, + // Hash commands + "hset" => { + if cmd.len() < 4 || (cmd.len() - 2) % 2 != 0 { + return Err(DBError(format!("wrong number of arguments for HSET command"))); + } + let mut pairs = Vec::new(); + let mut i = 2; + while i < cmd.len() - 1 { + pairs.push((cmd[i].clone(), cmd[i + 1].clone())); + i += 2; + } + Cmd::HSet(cmd[1].clone(), pairs) + } + "hget" => { + if cmd.len() != 3 { + return Err(DBError(format!("wrong number of arguments for HGET command"))); + } + Cmd::HGet(cmd[1].clone(), cmd[2].clone()) + } + "hgetall" => { + if cmd.len() != 2 { + return Err(DBError(format!("wrong number of arguments for HGETALL command"))); + } + Cmd::HGetAll(cmd[1].clone()) + } + "hdel" => { + if cmd.len() < 3 { + return Err(DBError(format!("wrong number of arguments for HDEL command"))); + } + Cmd::HDel(cmd[1].clone(), cmd[2..].to_vec()) + } + "hexists" => { + if cmd.len() != 3 { + return Err(DBError(format!("wrong number of arguments for HEXISTS command"))); + } + Cmd::HExists(cmd[1].clone(), cmd[2].clone()) + } + "hkeys" => { + if cmd.len() != 2 { + return Err(DBError(format!("wrong number of arguments for HKEYS command"))); + } + Cmd::HKeys(cmd[1].clone()) + } + "hvals" => { + if cmd.len() != 2 { + return Err(DBError(format!("wrong number of arguments for HVALS command"))); + } + Cmd::HVals(cmd[1].clone()) + } + "hlen" => { + if cmd.len() != 2 { + return Err(DBError(format!("wrong number of arguments for HLEN command"))); + } + Cmd::HLen(cmd[1].clone()) + } + "hmget" => { + if cmd.len() < 3 { + return Err(DBError(format!("wrong number of arguments for HMGET command"))); + } + Cmd::HMGet(cmd[1].clone(), cmd[2..].to_vec()) + } + "hsetnx" => { + if cmd.len() != 4 { + return Err(DBError(format!("wrong number of arguments for HSETNX command"))); + } + Cmd::HSetNx(cmd[1].clone(), cmd[2].clone(), cmd[3].clone()) + } _ => Cmd::Unknow, }, protocol.0, @@ -171,13 +190,11 @@ impl Cmd { pub async fn run( &self, - server: &mut Server, + server: &Server, protocol: Protocol, - is_rep_con: bool, queued_cmd: &mut Option>, ) -> Result { - // return if the command is a write command - let p = protocol.clone(); + // Handle queued commands for transactions if queued_cmd.is_some() && !matches!(self, Cmd::Exec) && !matches!(self, Cmd::Multi) @@ -189,71 +206,57 @@ impl Cmd { .push((self.clone(), protocol.clone())); return Ok(Protocol::SimpleString("QUEUED".to_string())); } - let ret = match self { + + match self { Cmd::Ping => Ok(Protocol::SimpleString("PONG".to_string())), Cmd::Echo(s) => Ok(Protocol::SimpleString(s.clone())), Cmd::Get(k) => get_cmd(server, k).await, - Cmd::Set(k, v) => set_cmd(server, k, v, protocol, is_rep_con).await, - Cmd::SetPx(k, v, x) => set_px_cmd(server, k, v, x, protocol, is_rep_con).await, - Cmd::SetEx(k, v, x) => set_ex_cmd(server, k, v, x, protocol, is_rep_con).await, - Cmd::Del(k) => del_cmd(server, k, protocol, is_rep_con).await, + Cmd::Set(k, v) => set_cmd(server, k, v).await, + Cmd::SetPx(k, v, x) => set_px_cmd(server, k, v, x).await, + Cmd::SetEx(k, v, x) => set_ex_cmd(server, k, v, x).await, + Cmd::Del(k) => del_cmd(server, k).await, Cmd::ConfigGet(name) => config_get_cmd(name, server), Cmd::Keys => keys_cmd(server).await, - Cmd::Info(section) => info_cmd(section, server), - Cmd::Replconf(sub_cmd) => replconf_cmd(sub_cmd, server), - Cmd::Psync => psync_cmd(server), + Cmd::Info(section) => info_cmd(section), Cmd::Type(k) => type_cmd(server, k).await, - Cmd::Xadd(stream_key, offset, kvps) => { - xadd_cmd( - offset.as_str(), - server, - stream_key.as_str(), - kvps, - protocol, - is_rep_con, - ) - .await - } - - Cmd::Xrange(stream_key, start, end) => xrange_cmd(server, stream_key, start, end).await, - Cmd::Xread(stream_keys, starts, block) => { - xread_cmd(starts, server, stream_keys, block).await - } Cmd::Incr(key) => incr_cmd(server, key).await, Cmd::Multi => { *queued_cmd = Some(Vec::<(Cmd, Protocol)>::new()); - Ok(Protocol::SimpleString("ok".to_string())) + Ok(Protocol::SimpleString("OK".to_string())) } - Cmd::Exec => exec_cmd(queued_cmd, server, is_rep_con).await, + Cmd::Exec => exec_cmd(queued_cmd, server).await, Cmd::Discard => { if queued_cmd.is_some() { *queued_cmd = None; - Ok(Protocol::SimpleString("ok".to_string())) + Ok(Protocol::SimpleString("OK".to_string())) } else { - Ok(Protocol::err("ERR Discard without MULTI")) + Ok(Protocol::err("ERR DISCARD without MULTI")) } } - Cmd::Unknow => Ok(Protocol::err("unknow cmd")), - }; - if ret.is_ok() { - server.offset.fetch_add( - p.encode().len() as u64, - std::sync::atomic::Ordering::Relaxed, - ); + // Hash commands + Cmd::HSet(key, pairs) => hset_cmd(server, key, pairs).await, + Cmd::HGet(key, field) => hget_cmd(server, key, field).await, + Cmd::HGetAll(key) => hgetall_cmd(server, key).await, + Cmd::HDel(key, fields) => hdel_cmd(server, key, fields).await, + Cmd::HExists(key, field) => hexists_cmd(server, key, field).await, + Cmd::HKeys(key) => hkeys_cmd(server, key).await, + Cmd::HVals(key) => hvals_cmd(server, key).await, + Cmd::HLen(key) => hlen_cmd(server, key).await, + Cmd::HMGet(key, fields) => hmget_cmd(server, key, fields).await, + Cmd::HSetNx(key, field, value) => hsetnx_cmd(server, key, field, value).await, + Cmd::Unknow => Ok(Protocol::err("unknown cmd")), } - ret } } async fn exec_cmd( queued_cmd: &mut Option>, - server: &mut Server, - is_rep_con: bool, + server: &Server, ) -> Result { if queued_cmd.is_some() { let mut vec = Vec::new(); for (cmd, protocol) in queued_cmd.as_ref().unwrap() { - let res = Box::pin(cmd.run(server, protocol.clone(), is_rep_con, &mut None)).await?; + let res = Box::pin(cmd.run(server, protocol.clone(), &mut None)).await?; vec.push(res); } *queued_cmd = None; @@ -263,22 +266,24 @@ async fn exec_cmd( } } -async fn incr_cmd(server: &mut Server, key: &String) -> Result { - let mut storage = server.storage.lock().await; - let v = storage.get(key); - // return 1 if key is missing - let v = v.map_or("1".to_string(), |v| v); - - if let Ok(x) = v.parse::() { - let v = (x + 1).to_string(); - storage.set(key.clone(), v.clone()); - Ok(Protocol::SimpleString(v)) - } else { - Ok(Protocol::err("ERR value is not an integer or out of range")) - } +async fn incr_cmd(server: &Server, key: &String) -> Result { + let current_value = server.storage.get(key)?; + + let new_value = match current_value { + Some(v) => { + match v.parse::() { + Ok(num) => num + 1, + Err(_) => return Ok(Protocol::err("ERR value is not an integer or out of range")), + } + } + None => 1, + }; + + server.storage.set(key.clone(), new_value.to_string())?; + Ok(Protocol::SimpleString(new_value.to_string())) } -fn config_get_cmd(name: &String, server: &mut Server) -> Result { +fn config_get_cmd(name: &String, server: &Server) -> Result { match name.as_str() { "dir" => Ok(Protocol::Array(vec![ Protocol::BulkString(name.clone()), @@ -286,336 +291,156 @@ fn config_get_cmd(name: &String, server: &mut Server) -> Result Ok(Protocol::Array(vec![ Protocol::BulkString(name.clone()), - Protocol::BulkString(server.option.db_file_name.clone()), + Protocol::BulkString("herodb.redb".to_string()), ])), _ => Err(DBError(format!("unsupported config {:?}", name))), } } -async fn keys_cmd(server: &mut Server) -> Result { - let keys = { server.storage.lock().await.keys() }; +async fn keys_cmd(server: &Server) -> Result { + let keys = server.storage.keys("*")?; Ok(Protocol::Array( keys.into_iter().map(Protocol::BulkString).collect(), )) } -fn info_cmd(section: &Option, server: &mut Server) -> Result { +fn info_cmd(section: &Option) -> Result { match section { Some(s) => match s.as_str() { - "replication" => Ok(Protocol::BulkString(format!( - "role:{}\nmaster_replid:{}\nmaster_repl_offset:{}\n", - server.option.replication.role, - server.option.replication.master_replid, - server.option.replication.master_repl_offset - ))), + "replication" => Ok(Protocol::BulkString( + "role:master\nmaster_replid:8371b4fb1155b71f4a04d3e1bc3e18c4a990aeea\nmaster_repl_offset:0\n".to_string() + )), _ => Err(DBError(format!("unsupported section {:?}", s))), }, - None => Ok(Protocol::BulkString("default".to_string())), + None => Ok(Protocol::BulkString("# Server\nredis_version:7.0.0\n".to_string())), } } -async fn xread_cmd( - starts: &[String], - server: &mut Server, - stream_keys: &[String], - block_millis: &Option, -) -> Result { - if let Some(t) = block_millis { - if t > &0 { - tokio::time::sleep(Duration::from_millis(*t)).await; - } else { - let (sender, mut receiver) = mpsc::channel(4); - { - let mut blocker = server.stream_reader_blocker.lock().await; - blocker.push(sender.clone()); - } - while let Some(_) = receiver.recv().await { - println!("get new xadd cmd, release block"); - // break; - } - } - } - let streams = server.streams.lock().await; - let mut ret = Vec::new(); - for (i, stream_key) in stream_keys.iter().enumerate() { - let stream = streams.get(stream_key); - if let Some(s) = stream { - let (offset_id, mut offset_seq, _) = split_offset(starts[i].as_str()); - offset_seq += 1; - let start = format!("{}-{}", offset_id, offset_seq); - let end = format!("{}-{}", u64::MAX - 1, 0); - - // query stream range - let range = s.range::((Bound::Included(&start), Bound::Included(&end))); - let mut array = Vec::new(); - for (k, v) in range { - array.push(Protocol::BulkString(k.clone())); - array.push(Protocol::from_vec( - v.iter() - .flat_map(|(a, b)| vec![a.as_str(), b.as_str()]) - .collect(), - )) - } - ret.push(Protocol::BulkString(stream_key.clone())); - ret.push(Protocol::Array(array)); - } - } - Ok(Protocol::Array(ret)) -} - -fn replconf_cmd(sub_cmd: &str, server: &mut Server) -> Result { - match sub_cmd { - "getack" => Ok(Protocol::from_vec(vec![ - "REPLCONF", - "ACK", - server - .offset - .load(std::sync::atomic::Ordering::Relaxed) - .to_string() - .as_str(), - ])), - _ => Ok(Protocol::SimpleString("OK".to_string())), +async fn type_cmd(server: &Server, k: &String) -> Result { + match server.storage.get_key_type(k)? { + Some(type_str) => Ok(Protocol::SimpleString(type_str)), + None => Ok(Protocol::SimpleString("none".to_string())), } } -async fn xrange_cmd( - server: &mut Server, - stream_key: &String, - start: &String, - end: &String, -) -> Result { - let streams = server.streams.lock().await; - let stream = streams.get(stream_key); - Ok(stream.map_or(Protocol::none(), |s| { - // support query with '-' - let start = if start == "-" { - "0".to_string() - } else { - start.clone() - }; - - // support query with '+' - let end = if end == "+" { - u64::MAX.to_string() - } else { - end.clone() - }; - - // query stream range - let range = s.range::((Bound::Included(&start), Bound::Included(&end))); - let mut array = Vec::new(); - for (k, v) in range { - array.push(Protocol::BulkString(k.clone())); - array.push(Protocol::from_vec( - v.iter() - .flat_map(|(a, b)| vec![a.as_str(), b.as_str()]) - .collect(), - )) - } - println!("after xrange: {:?}", array); - Protocol::Array(array) - })) -} - -async fn xadd_cmd( - offset: &str, - server: &mut Server, - stream_key: &str, - kvps: &Vec<(String, String)>, - protocol: Protocol, - is_rep_con: bool, -) -> Result { - let mut offset = offset.to_string(); - if offset == "*" { - offset = format!("{}-*", now_in_millis() as u64); - } - let (offset_id, mut offset_seq, has_wildcard) = split_offset(offset.as_str()); - if offset_id == 0 && offset_seq == 0 && !has_wildcard { - return Ok(Protocol::err( - "ERR The ID specified in XADD must be greater than 0-0", - )); - } - { - let mut streams = server.streams.lock().await; - let stream = streams - .entry(stream_key.to_string()) - .or_insert_with(BTreeMap::new); - - if let Some((last_offset, _)) = stream.last_key_value() { - let (last_offset_id, last_offset_seq, _) = split_offset(last_offset.as_str()); - if last_offset_id > offset_id - || (last_offset_id == offset_id && last_offset_seq >= offset_seq && !has_wildcard) - { - return Ok(Protocol::err("ERR The ID specified in XADD is equal or smaller than the target stream top item")); - } - - if last_offset_id == offset_id && last_offset_seq >= offset_seq && has_wildcard { - offset_seq = last_offset_seq + 1; - } - } - - let offset = format!("{}-{}", offset_id, offset_seq); - - let s = stream.entry(offset.clone()).or_insert_with(Vec::new); - for (key, value) in kvps { - s.push((key.clone(), value.clone())); - } - } - { - let mut blocker = server.stream_reader_blocker.lock().await; - for sender in blocker.iter() { - sender.send(()).await?; - } - blocker.clear(); - } - resp_and_replicate( - server, - Protocol::BulkString(offset.to_string()), - protocol, - is_rep_con, - ) - .await -} - -async fn type_cmd(server: &mut Server, k: &String) -> Result { - let v = { server.storage.lock().await.get(k) }; - if v.is_some() { - return Ok(Protocol::SimpleString("string".to_string())); - } - let streams = server.streams.lock().await; - let v = streams.get(k); - Ok(v.map_or(Protocol::none(), |_| { - Protocol::SimpleString("stream".to_string()) - })) -} - -fn psync_cmd(server: &mut Server) -> Result { - if server.is_master() { - Ok(Protocol::SimpleString(format!( - "FULLRESYNC {} 0", - server.option.replication.master_replid - ))) - } else { - Ok(Protocol::psync_on_slave_err()) - } -} - -async fn del_cmd( - server: &mut Server, - k: &str, - protocol: Protocol, - is_rep_con: bool, -) -> Result { - // offset - let _ = { - let mut s = server.storage.lock().await; - s.del(k.to_string()); - server - .offset - .fetch_add(1, std::sync::atomic::Ordering::Relaxed) - }; - resp_and_replicate(server, Protocol::ok(), protocol, is_rep_con).await +async fn del_cmd(server: &Server, k: &str) -> Result { + server.storage.del(k.to_string())?; + Ok(Protocol::SimpleString("1".to_string())) } async fn set_ex_cmd( - server: &mut Server, + server: &Server, k: &str, v: &str, x: &u128, - protocol: Protocol, - is_rep_con: bool, ) -> Result { - // offset - let _ = { - let mut s = server.storage.lock().await; - s.setx(k.to_string(), v.to_string(), *x * 1000); - server - .offset - .fetch_add(1, std::sync::atomic::Ordering::Relaxed) - }; - resp_and_replicate(server, Protocol::ok(), protocol, is_rep_con).await + server.storage.setx(k.to_string(), v.to_string(), *x * 1000)?; + Ok(Protocol::SimpleString("OK".to_string())) } async fn set_px_cmd( - server: &mut Server, + server: &Server, k: &str, v: &str, x: &u128, - protocol: Protocol, - is_rep_con: bool, ) -> Result { - // offset - let _ = { - let mut s = server.storage.lock().await; - s.setx(k.to_string(), v.to_string(), *x); - server - .offset - .fetch_add(1, std::sync::atomic::Ordering::Relaxed) - }; - resp_and_replicate(server, Protocol::ok(), protocol, is_rep_con).await + server.storage.setx(k.to_string(), v.to_string(), *x)?; + Ok(Protocol::SimpleString("OK".to_string())) } -async fn set_cmd( - server: &mut Server, - k: &str, - v: &str, - protocol: Protocol, - is_rep_con: bool, -) -> Result { - // offset - let _ = { - let mut s = server.storage.lock().await; - s.set(k.to_string(), v.to_string()); - server - .offset - .fetch_add(1, std::sync::atomic::Ordering::Relaxed) - + 1 - }; - resp_and_replicate(server, Protocol::ok(), protocol, is_rep_con).await +async fn set_cmd(server: &Server, k: &str, v: &str) -> Result { + server.storage.set(k.to_string(), v.to_string())?; + Ok(Protocol::SimpleString("OK".to_string())) } -async fn get_cmd(server: &mut Server, k: &str) -> Result { - let v = { - let mut s = server.storage.lock().await; - s.get(k) - }; - Ok(v.map_or(Protocol::Null, Protocol::SimpleString)) +async fn get_cmd(server: &Server, k: &str) -> Result { + let v = server.storage.get(k)?; + Ok(v.map_or(Protocol::Null, Protocol::BulkString)) } -async fn resp_and_replicate( - server: &mut Server, - resp: Protocol, - replication: Protocol, - is_rep_con: bool, -) -> Result { - if server.is_master() { - server - .master_repl_clients - .lock() - .await - .as_mut() - .unwrap() - .send_command(replication) - .await?; - Ok(resp) - } else if !is_rep_con { - Ok(Protocol::write_on_slave_err()) - } else { - Ok(resp) +// Hash command implementations +async fn hset_cmd(server: &Server, key: &str, pairs: &[(String, String)]) -> Result { + let new_fields = server.storage.hset(key, pairs)?; + Ok(Protocol::SimpleString(new_fields.to_string())) +} + +async fn hget_cmd(server: &Server, key: &str, field: &str) -> Result { + match server.storage.hget(key, field) { + Ok(Some(value)) => Ok(Protocol::BulkString(value)), + Ok(None) => Ok(Protocol::Null), + Err(e) => Ok(Protocol::err(&e.0)), } } -fn split_offset(offset: &str) -> (u64, u64, bool) { - let offset_split = offset.split('-').collect::>(); - let offset_id = offset_split[0].parse::().expect(&format!( - "ERR The ID specified in XADD must be a number: {}", - offset - )); - - if offset_split.len() == 1 || offset_split[1] == "*" { - return (offset_id, if offset_id == 0 { 1 } else { 0 }, true); +async fn hgetall_cmd(server: &Server, key: &str) -> Result { + match server.storage.hgetall(key) { + Ok(pairs) => { + let mut result = Vec::new(); + for (field, value) in pairs { + result.push(Protocol::BulkString(field)); + result.push(Protocol::BulkString(value)); + } + Ok(Protocol::Array(result)) + } + Err(e) => Ok(Protocol::err(&e.0)), + } +} + +async fn hdel_cmd(server: &Server, key: &str, fields: &[String]) -> Result { + match server.storage.hdel(key, fields) { + Ok(deleted) => Ok(Protocol::SimpleString(deleted.to_string())), + Err(e) => Ok(Protocol::err(&e.0)), + } +} + +async fn hexists_cmd(server: &Server, key: &str, field: &str) -> Result { + match server.storage.hexists(key, field) { + Ok(exists) => Ok(Protocol::SimpleString(if exists { "1" } else { "0" }.to_string())), + Err(e) => Ok(Protocol::err(&e.0)), + } +} + +async fn hkeys_cmd(server: &Server, key: &str) -> Result { + match server.storage.hkeys(key) { + Ok(keys) => Ok(Protocol::Array( + keys.into_iter().map(Protocol::BulkString).collect(), + )), + Err(e) => Ok(Protocol::err(&e.0)), + } +} + +async fn hvals_cmd(server: &Server, key: &str) -> Result { + match server.storage.hvals(key) { + Ok(values) => Ok(Protocol::Array( + values.into_iter().map(Protocol::BulkString).collect(), + )), + Err(e) => Ok(Protocol::err(&e.0)), + } +} + +async fn hlen_cmd(server: &Server, key: &str) -> Result { + match server.storage.hlen(key) { + Ok(len) => Ok(Protocol::SimpleString(len.to_string())), + Err(e) => Ok(Protocol::err(&e.0)), + } +} + +async fn hmget_cmd(server: &Server, key: &str, fields: &[String]) -> Result { + match server.storage.hmget(key, fields) { + Ok(values) => { + let result: Vec = values + .into_iter() + .map(|v| v.map_or(Protocol::Null, Protocol::BulkString)) + .collect(); + Ok(Protocol::Array(result)) + } + Err(e) => Ok(Protocol::err(&e.0)), + } +} + +async fn hsetnx_cmd(server: &Server, key: &str, field: &str, value: &str) -> Result { + match server.storage.hsetnx(key, field, value) { + Ok(was_set) => Ok(Protocol::SimpleString(if was_set { "1" } else { "0" }.to_string())), + Err(e) => Ok(Protocol::err(&e.0)), } - - let offset_seq = offset_split[1].parse::().unwrap(); - (offset_id, offset_seq, false) } diff --git a/src/error.rs b/src/error.rs index 36555d2..feed31e 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,6 +1,8 @@ use std::num::ParseIntError; use tokio::sync::mpsc; +use redb; +use bincode; use crate::protocol::Protocol; @@ -32,11 +34,48 @@ impl From for DBError { } } -impl From> for DBError { - fn from(item: mpsc::error::SendError<(Protocol, u64)>) -> Self { - DBError(item.to_string().clone()) +impl From for DBError { + fn from(item: redb::Error) -> Self { + DBError(item.to_string()) } } + +impl From for DBError { + fn from(item: redb::DatabaseError) -> Self { + DBError(item.to_string()) + } +} + +impl From for DBError { + fn from(item: redb::TransactionError) -> Self { + DBError(item.to_string()) + } +} + +impl From for DBError { + fn from(item: redb::TableError) -> Self { + DBError(item.to_string()) + } +} + +impl From for DBError { + fn from(item: redb::StorageError) -> Self { + DBError(item.to_string()) + } +} + +impl From for DBError { + fn from(item: redb::CommitError) -> Self { + DBError(item.to_string()) + } +} + +impl From> for DBError { + fn from(item: Box) -> Self { + DBError(item.to_string()) + } +} + impl From> for DBError { fn from(item: mpsc::error::SendError<()>) -> Self { DBError(item.to_string().clone()) diff --git a/src/lib.rs b/src/lib.rs index 6c3654e..cc46e75 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -2,7 +2,5 @@ mod cmd; pub mod error; pub mod options; mod protocol; -mod rdb; -mod replication_client; pub mod server; mod storage; diff --git a/src/main.rs b/src/main.rs index 964f4b3..12ebc7c 100644 --- a/src/main.rs +++ b/src/main.rs @@ -2,7 +2,7 @@ use tokio::net::TcpListener; -use redis_rs::{options::ReplicationOption, server}; +use redis_rs::server; use clap::Parser; @@ -14,17 +14,10 @@ struct Args { #[arg(long)] dir: String, - /// The name of the Redis DB file - #[arg(long)] - dbfilename: String, /// The port of the Redis server, default is 6379 if not specified #[arg(long)] port: Option, - - /// The address of the master Redis server, if the server is a replica. None if the server is a master. - #[arg(long)] - replicaof: Option, } #[tokio::main] @@ -42,42 +35,11 @@ async fn main() { // new DB option let option = redis_rs::options::DBOption { dir: args.dir, - db_file_name: args.dbfilename, port, - replication: ReplicationOption { - role: if let Some(_) = args.replicaof { - "slave".to_string() - } else { - "master".to_string() - }, - master_replid: "8371b4fb1155b71f4a04d3e1bc3e18c4a990aeea".to_string(), // should be a random string but hard code for now - master_repl_offset: 0, - replica_of: args.replicaof, - }, }; // new server - let mut server = server::Server::new(option).await; - - //start receive replication cmds for slave - if server.is_slave() { - let mut sc = server.clone(); - - let mut follower_repl_client = server.get_follower_repl_client().await.unwrap(); - follower_repl_client.ping_master().await.unwrap(); - follower_repl_client - .report_port(server.option.port) - .await - .unwrap(); - follower_repl_client.report_sync_protocol().await.unwrap(); - follower_repl_client.start_psync(&mut sc).await.unwrap(); - - tokio::spawn(async move { - if let Err(e) = sc.handle(follower_repl_client.stream, true).await { - println!("error: {:?}, will close the connection. Bye", e); - } - }); - } + let server = server::Server::new(option).await; // accept new connections loop { @@ -88,7 +50,7 @@ async fn main() { let mut sc = server.clone(); tokio::spawn(async move { - if let Err(e) = sc.handle(stream, false).await { + if let Err(e) = sc.handle(stream).await { println!("error: {:?}, will close the connection. Bye", e); } }); diff --git a/src/options.rs b/src/options.rs index 5bc533b..77afc2b 100644 --- a/src/options.rs +++ b/src/options.rs @@ -1,15 +1,5 @@ #[derive(Clone)] pub struct DBOption { pub dir: String, - pub db_file_name: String, - pub replication: ReplicationOption, pub port: u16, } - -#[derive(Clone)] -pub struct ReplicationOption { - pub role: String, - pub master_replid: String, - pub master_repl_offset: u64, - pub replica_of: Option, -} diff --git a/src/rdb.rs b/src/rdb.rs deleted file mode 100644 index 3a384fc..0000000 --- a/src/rdb.rs +++ /dev/null @@ -1,201 +0,0 @@ -// parse Redis RDB file format: https://rdb.fnordig.de/file_format.html - -use tokio::{ - fs, - io::{AsyncRead, AsyncReadExt, BufReader}, -}; - -use crate::{error::DBError, server::Server}; - -use futures::pin_mut; - -enum StringEncoding { - Raw, - I8, - I16, - I32, - LZF, -} - -// RDB file format. -const MAGIC: &[u8; 5] = b"REDIS"; -const META: u8 = 0xFA; -const DB_SELECT: u8 = 0xFE; -const TABLE_SIZE_INFO: u8 = 0xFB; -pub const EOF: u8 = 0xFF; - -pub async fn parse_rdb( - reader: &mut R, - server: &mut Server, -) -> Result<(), DBError> { - let mut storage = server.storage.lock().await; - parse_magic(reader).await?; - let _version = parse_version(reader).await?; - pin_mut!(reader); - loop { - let op = reader.read_u8().await?; - match op { - META => { - let _ = parse_aux(&mut *reader).await?; - let _ = parse_aux(&mut *reader).await?; - // just ignore the aux info for now - } - DB_SELECT => { - let (_, _) = parse_len(&mut *reader).await?; - // just ignore the db index for now - } - TABLE_SIZE_INFO => { - let size_no_expire = parse_len(&mut *reader).await?.0; - let size_expire = parse_len(&mut *reader).await?.0; - for _ in 0..size_no_expire { - let (k, v) = parse_no_expire_entry(&mut *reader).await?; - storage.set(k, v); - } - for _ in 0..size_expire { - let (k, v, expire_timestamp) = parse_expire_entry(&mut *reader).await?; - storage.setx(k, v, expire_timestamp); - } - } - EOF => { - // not verify crc for now - let _crc = reader.read_u64().await?; - break; - } - _ => return Err(DBError(format!("unexpected op: {}", op))), - } - } - Ok(()) -} - -pub async fn parse_rdb_file(f: &mut fs::File, server: &mut Server) -> Result<(), DBError> { - let mut reader = BufReader::new(f); - parse_rdb(&mut reader, server).await -} - -async fn parse_no_expire_entry( - input: &mut R, -) -> Result<(String, String), DBError> { - let b = input.read_u8().await?; - if b != 0 { - return Err(DBError(format!("unexpected key type: {}", b))); - } - let k = parse_aux(input).await?; - let v = parse_aux(input).await?; - Ok((k, v)) -} - -async fn parse_expire_entry( - input: &mut R, -) -> Result<(String, String, u128), DBError> { - let b = input.read_u8().await?; - match b { - 0xFC => { - // expire in milliseconds - let expire_stamp = input.read_u64_le().await?; - let (k, v) = parse_no_expire_entry(input).await?; - Ok((k, v, expire_stamp as u128)) - } - 0xFD => { - // expire in seconds - let expire_timestamp = input.read_u32_le().await?; - let (k, v) = parse_no_expire_entry(input).await?; - Ok((k, v, (expire_timestamp * 1000) as u128)) - } - _ => return Err(DBError(format!("unexpected expire type: {}", b))), - } -} - -async fn parse_magic(input: &mut R) -> Result<(), DBError> { - let mut magic = [0; 5]; - let size_read = input.read(&mut magic).await?; - if size_read != 5 { - Err(DBError("expected 5 chars for magic number".to_string())) - } else if magic.as_slice() == MAGIC { - Ok(()) - } else { - Err(DBError(format!( - "expected magic string {:?}, but got: {:?}", - MAGIC, magic - ))) - } -} - -async fn parse_version(input: &mut R) -> Result<[u8; 4], DBError> { - let mut version = [0; 4]; - let size_read = input.read(&mut version).await?; - if size_read != 4 { - Err(DBError("expected 4 chars for redis version".to_string())) - } else { - Ok(version) - } -} - -async fn parse_aux(input: &mut R) -> Result { - let (len, encoding) = parse_len(input).await?; - let s = parse_string(input, len, encoding).await?; - Ok(s) -} - -async fn parse_len(input: &mut R) -> Result<(u32, StringEncoding), DBError> { - let first = input.read_u8().await?; - match first & 0xC0 { - 0x00 => { - // The size is the remaining 6 bits of the byte. - Ok((first as u32, StringEncoding::Raw)) - } - 0x04 => { - // The size is the next 14 bits of the byte. - let second = input.read_u8().await?; - Ok(( - (((first & 0x3F) as u32) << 8 | second as u32) as u32, - StringEncoding::Raw, - )) - } - 0x80 => { - //Ignore the remaining 6 bits of the first byte. The size is the next 4 bytes, in big-endian - let second = input.read_u32().await?; - Ok((second, StringEncoding::Raw)) - } - 0xC0 => { - // The remaining 6 bits specify a type of string encoding. - match first { - 0xC0 => Ok((1, StringEncoding::I8)), - 0xC1 => Ok((2, StringEncoding::I16)), - 0xC2 => Ok((4, StringEncoding::I32)), - 0xC3 => Ok((0, StringEncoding::LZF)), // not supported yet - _ => Err(DBError(format!("unexpected string encoding: {}", first))), - } - } - _ => Err(DBError(format!("unexpected len prefix: {}", first))), - } -} - -async fn parse_string( - input: &mut R, - len: u32, - encoding: StringEncoding, -) -> Result { - match encoding { - StringEncoding::Raw => { - let mut s = vec![0; len as usize]; - input.read_exact(&mut s).await?; - Ok(String::from_utf8(s).unwrap()) - } - StringEncoding::I8 => { - let b = input.read_u8().await?; - Ok(b.to_string()) - } - StringEncoding::I16 => { - let b = input.read_u16_le().await?; - Ok(b.to_string()) - } - StringEncoding::I32 => { - let b = input.read_u32_le().await?; - Ok(b.to_string()) - } - StringEncoding::LZF => { - // not supported yet - Err(DBError("LZF encoding not supported yet".to_string())) - } - } -} diff --git a/src/replication_client.rs b/src/replication_client.rs deleted file mode 100644 index 572aae7..0000000 --- a/src/replication_client.rs +++ /dev/null @@ -1,155 +0,0 @@ -use std::{num::ParseIntError, sync::Arc}; - -use tokio::{ - io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader}, - net::TcpStream, - sync::Mutex, -}; - -use crate::{error::DBError, protocol::Protocol, rdb, server::Server}; - -const EMPTY_RDB_FILE_HEX_STRING: &str = "524544495330303131fa0972656469732d76657205372e322e30fa0a72656469732d62697473c040fa056374696d65c26d08bc65fa08757365642d6d656dc2b0c41000fa08616f662d62617365c000fff06e3bfec0ff5aa2"; - -pub struct FollowerReplicationClient { - pub stream: TcpStream, -} - -impl FollowerReplicationClient { - pub async fn new(addr: String) -> FollowerReplicationClient { - FollowerReplicationClient { - stream: TcpStream::connect(addr).await.unwrap(), - } - } - - pub async fn ping_master(self: &mut Self) -> Result<(), DBError> { - let protocol = Protocol::Array(vec![Protocol::BulkString("PING".to_string())]); - self.stream.write_all(protocol.encode().as_bytes()).await?; - - self.check_resp("PONG").await - } - - pub async fn report_port(self: &mut Self, port: u16) -> Result<(), DBError> { - let protocol = Protocol::from_vec(vec![ - "REPLCONF", - "listening-port", - port.to_string().as_str(), - ]); - self.stream.write_all(protocol.encode().as_bytes()).await?; - - self.check_resp("OK").await - } - - pub async fn report_sync_protocol(self: &mut Self) -> Result<(), DBError> { - let p = Protocol::from_vec(vec!["REPLCONF", "capa", "psync2"]); - self.stream.write_all(p.encode().as_bytes()).await?; - self.check_resp("OK").await - } - - pub async fn start_psync(self: &mut Self, server: &mut Server) -> Result<(), DBError> { - let p = Protocol::from_vec(vec!["PSYNC", "?", "-1"]); - self.stream.write_all(p.encode().as_bytes()).await?; - self.recv_rdb_file(server).await?; - Ok(()) - } - - pub async fn recv_rdb_file(self: &mut Self, server: &mut Server) -> Result<(), DBError> { - let mut reader = BufReader::new(&mut self.stream); - - let mut buf = Vec::new(); - let _ = reader.read_until(b'\n', &mut buf).await?; - buf.pop(); - buf.pop(); - - let replication_info = String::from_utf8(buf)?; - let replication_info = replication_info - .split_whitespace() - .map(|x| x.to_string()) - .collect::>(); - if replication_info.len() != 3 { - return Err(DBError(format!( - "expect 3 args but found {:?}", - replication_info - ))); - } - println!( - "Get replication info: {:?} {:?} {:?}", - replication_info[0], replication_info[1], replication_info[2] - ); - - let c = reader.read_u8().await?; - if c != b'$' { - return Err(DBError(format!("expect $ but found {}", c))); - } - let mut buf = Vec::new(); - reader.read_until(b'\n', &mut buf).await?; - buf.pop(); - buf.pop(); - let rdb_file_len = String::from_utf8(buf)?.parse::()?; - println!("rdb file len: {}", rdb_file_len); - - // receive rdb file content - rdb::parse_rdb(&mut reader, server).await?; - Ok(()) - } - - pub async fn check_resp(&mut self, expected: &str) -> Result<(), DBError> { - let mut buf = [0; 1024]; - let n_bytes = self.stream.read(&mut buf).await?; - println!( - "check resp: recv {:?}", - String::from_utf8(buf[..n_bytes].to_vec()).unwrap() - ); - let expect = Protocol::SimpleString(expected.to_string()).encode(); - if expect.as_bytes() != &buf[..n_bytes] { - return Err(DBError(format!( - "expect response {:?} but found {:?}", - expect, - &buf[..n_bytes] - ))); - } - Ok(()) - } -} - -#[derive(Clone)] -pub struct MasterReplicationClient { - pub streams: Arc>>, -} - -impl MasterReplicationClient { - pub fn new() -> MasterReplicationClient { - MasterReplicationClient { - streams: Arc::new(Mutex::new(Vec::new())), - } - } - - pub async fn send_rdb_file(&mut self, stream: &mut TcpStream) -> Result<(), DBError> { - let empty_rdb_file_bytes = (0..EMPTY_RDB_FILE_HEX_STRING.len()) - .step_by(2) - .map(|i| u8::from_str_radix(&EMPTY_RDB_FILE_HEX_STRING[i..i + 2], 16)) - .collect::, ParseIntError>>()?; - - println!("going to send rdb file"); - _ = stream.write("$".as_bytes()).await?; - _ = stream - .write(empty_rdb_file_bytes.len().to_string().as_bytes()) - .await?; - _ = stream.write_all("\r\n".as_bytes()).await?; - _ = stream.write_all(&empty_rdb_file_bytes).await?; - Ok(()) - } - - pub async fn add_stream(&mut self, stream: TcpStream) -> Result<(), DBError> { - let mut streams = self.streams.lock().await; - streams.push(stream); - Ok(()) - } - - pub async fn send_command(&mut self, protocol: Protocol) -> Result<(), DBError> { - let mut streams = self.streams.lock().await; - for stream in streams.iter_mut() { - stream.write_all(protocol.encode().as_bytes()).await?; - } - Ok(()) - } -} diff --git a/src/server.rs b/src/server.rs index 03a0b0a..e464149 100644 --- a/src/server.rs +++ b/src/server.rs @@ -1,143 +1,63 @@ use core::str; -use std::collections::BTreeMap; -use std::collections::HashMap; use std::path::PathBuf; -use std::sync::atomic::AtomicU64; use std::sync::Arc; -use tokio::fs::OpenOptions; use tokio::io::AsyncReadExt; use tokio::io::AsyncWriteExt; -use tokio::sync::mpsc::Sender; -use tokio::sync::Mutex; use crate::cmd::Cmd; use crate::error::DBError; use crate::options; use crate::protocol::Protocol; -use crate::rdb; -use crate::replication_client::FollowerReplicationClient; -use crate::replication_client::MasterReplicationClient; use crate::storage::Storage; -type Stream = BTreeMap>; - #[derive(Clone)] pub struct Server { - pub storage: Arc>, - pub streams: Arc>>, + pub storage: Arc, pub option: options::DBOption, - pub offset: Arc, - pub master_repl_clients: Arc>>, - pub stream_reader_blocker: Arc>>>, - master_addr: Option, } impl Server { pub async fn new(option: options::DBOption) -> Self { - let master_addr = match option.replication.role.as_str() { - "slave" => Some( - option - .replication - .replica_of - .clone() - .unwrap() - .replace(' ', ":"), - ), - _ => None, - }; + // Create database file path with fixed filename + let db_file_path = PathBuf::from(option.dir.clone()).join("herodb.redb"); + println!("will open db file path: {}", db_file_path.display()); - let is_master = option.replication.role == "master"; + // Initialize storage with redb + let storage = Storage::new(db_file_path).expect("Failed to initialize storage"); - let mut server = Server { - storage: Arc::new(Mutex::new(Storage::new())), - streams: Arc::new(Mutex::new(HashMap::new())), + Server { + storage: Arc::new(storage), option, - master_repl_clients: if is_master { - Arc::new(Mutex::new(Some(MasterReplicationClient::new()))) - } else { - Arc::new(Mutex::new(None)) - }, - offset: Arc::new(AtomicU64::new(0)), - stream_reader_blocker: Arc::new(Mutex::new(Vec::new())), - master_addr, - }; - - server.init().await.unwrap(); - server - } - - pub async fn init(&mut self) -> Result<(), DBError> { - // master initialization - if self.is_master() { - println!("Start as master\n"); - let db_file_path = - PathBuf::from(self.option.dir.clone()).join(self.option.db_file_name.clone()); - println!("will open db file path: {}", db_file_path.display()); - - // create empty db file if not exits - let mut file = OpenOptions::new() - .read(true) - .write(true) - .create(true) - .truncate(false) - .open(db_file_path.clone()) - .await?; - - if file.metadata().await?.len() != 0 { - rdb::parse_rdb_file(&mut file, self).await?; - } - } - Ok(()) - } - - pub async fn get_follower_repl_client(&mut self) -> Option { - if self.is_slave() { - Some(FollowerReplicationClient::new(self.master_addr.clone().unwrap()).await) - } else { - None } } pub async fn handle( &mut self, mut stream: tokio::net::TcpStream, - is_rep_conn: bool, ) -> Result<(), DBError> { let mut buf = [0; 512]; let mut queued_cmd: Option> = None; + loop { if let Ok(len) = stream.read(&mut buf).await { if len == 0 { println!("[handle] connection closed"); return Ok(()); } + let s = str::from_utf8(&buf[..len])?; let (cmd, protocol) = Cmd::from(s).unwrap_or((Cmd::Unknow, Protocol::err("unknow cmd"))); println!("got command: {:?}, protocol: {:?}", cmd, protocol); let res = cmd - .run(self, protocol, is_rep_conn, &mut queued_cmd) + .run(self, protocol, &mut queued_cmd) .await .unwrap_or(Protocol::err("unknow cmd")); - print!("queued 2 cmd {:?}", queued_cmd); + print!("queued cmd {:?}", queued_cmd); - // only send response to normal client, do not send response to replication client - if !is_rep_conn { - println!("going to send response {}", res.encode()); - _ = stream.write(res.encode().as_bytes()).await?; - } - - // send a full RDB file to slave - if self.is_master() { - if let Cmd::Psync = cmd { - let mut master_rep_client = self.master_repl_clients.lock().await; - let master_rep_client = master_rep_client.as_mut().unwrap(); - master_rep_client.send_rdb_file(&mut stream).await?; - master_rep_client.add_stream(stream).await?; - break; - } - } + println!("going to send response {}", res.encode()); + _ = stream.write(res.encode().as_bytes()).await?; } else { println!("[handle] going to break"); break; @@ -145,12 +65,4 @@ impl Server { } Ok(()) } - - pub fn is_slave(&self) -> bool { - self.option.replication.role == "slave" - } - - pub fn is_master(&self) -> bool { - !self.is_slave() - } } diff --git a/src/storage.rs b/src/storage.rs index f44e069..77c66ca 100644 --- a/src/storage.rs +++ b/src/storage.rs @@ -1,13 +1,29 @@ use std::{ - collections::HashMap, + path::Path, time::{SystemTime, UNIX_EPOCH}, }; -pub type ValueType = (String, Option); +use redb::{Database, Error, ReadableTable, Table, TableDefinition, WriteTransaction, ReadTransaction}; +use serde::{Deserialize, Serialize}; -pub struct Storage { - // key -> (value, (insert/update time, expire milli seconds)) - set: HashMap, +use crate::error::DBError; + +// Table definitions for different Redis data types +const TYPES_TABLE: TableDefinition<&str, &str> = TableDefinition::new("types"); +const STRINGS_TABLE: TableDefinition<&str, &[u8]> = TableDefinition::new("strings"); +const HASHES_TABLE: TableDefinition<(&str, &str), &str> = TableDefinition::new("hashes"); +const STREAMS_META_TABLE: TableDefinition<&str, &[u8]> = TableDefinition::new("streams_meta"); +const STREAMS_DATA_TABLE: TableDefinition<(&str, &str), &[u8]> = TableDefinition::new("streams_data"); + +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct StringValue { + pub value: String, + pub expires_at_ms: Option, +} + +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct StreamEntry { + pub fields: Vec<(String, String)>, } #[inline] @@ -17,43 +33,416 @@ pub fn now_in_millis() -> u128 { duration_since_epoch.as_millis() } +pub struct Storage { + db: Database, +} + impl Storage { - pub fn new() -> Self { - Storage { - set: HashMap::new(), + pub fn new(path: impl AsRef) -> Result { + let db = Database::create(path)?; + + // Create tables if they don't exist + let write_txn = db.begin_write()?; + { + let _ = write_txn.open_table(TYPES_TABLE)?; + let _ = write_txn.open_table(STRINGS_TABLE)?; + let _ = write_txn.open_table(HASHES_TABLE)?; + let _ = write_txn.open_table(STREAMS_META_TABLE)?; + let _ = write_txn.open_table(STREAMS_DATA_TABLE)?; + } + write_txn.commit()?; + + Ok(Storage { db }) + } + + pub fn get_key_type(&self, key: &str) -> Result, DBError> { + let read_txn = self.db.begin_read()?; + let table = read_txn.open_table(TYPES_TABLE)?; + + match table.get(key)? { + Some(type_val) => Ok(Some(type_val.value().to_string())), + None => Ok(None), } } - pub fn get(self: &mut Self, k: &str) -> Option { - match self.set.get(k) { - Some((ss, expire_timestamp)) => match expire_timestamp { - Some(expire_time_stamp) => { - if now_in_millis() > *expire_time_stamp { - self.set.remove(k); - None - } else { - Some(ss.clone()) + pub fn get(&self, key: &str) -> Result, DBError> { + let read_txn = self.db.begin_read()?; + + // Check if key exists and is of string type + let types_table = read_txn.open_table(TYPES_TABLE)?; + match types_table.get(key)? { + Some(type_val) if type_val.value() == "string" => { + let strings_table = read_txn.open_table(STRINGS_TABLE)?; + match strings_table.get(key)? { + Some(data) => { + let string_value: StringValue = bincode::deserialize(data.value())?; + + // Check if expired + if let Some(expires_at) = string_value.expires_at_ms { + if now_in_millis() > expires_at { + // Key expired, remove it + drop(read_txn); + self.del(key.to_string())?; + return Ok(None); + } + } + + Ok(Some(string_value.value)) + } + None => Ok(None), + } + } + _ => Ok(None), + } + } + + pub fn set(&self, key: String, value: String) -> Result<(), DBError> { + let write_txn = self.db.begin_write()?; + + { + let mut types_table = write_txn.open_table(TYPES_TABLE)?; + types_table.insert(key.as_str(), "string")?; + + let mut strings_table = write_txn.open_table(STRINGS_TABLE)?; + let string_value = StringValue { + value, + expires_at_ms: None, + }; + let serialized = bincode::serialize(&string_value)?; + strings_table.insert(key.as_str(), serialized.as_slice())?; + } + + write_txn.commit()?; + Ok(()) + } + + pub fn setx(&self, key: String, value: String, expire_ms: u128) -> Result<(), DBError> { + let write_txn = self.db.begin_write()?; + + { + let mut types_table = write_txn.open_table(TYPES_TABLE)?; + types_table.insert(key.as_str(), "string")?; + + let mut strings_table = write_txn.open_table(STRINGS_TABLE)?; + let string_value = StringValue { + value, + expires_at_ms: Some(expire_ms + now_in_millis()), + }; + let serialized = bincode::serialize(&string_value)?; + strings_table.insert(key.as_str(), serialized.as_slice())?; + } + + write_txn.commit()?; + Ok(()) + } + + pub fn del(&self, key: String) -> Result<(), DBError> { + let write_txn = self.db.begin_write()?; + + { + let mut types_table = write_txn.open_table(TYPES_TABLE)?; + let mut strings_table = write_txn.open_table(STRINGS_TABLE)?; + let mut hashes_table = write_txn.open_table(HASHES_TABLE)?; + + // Remove from type table + types_table.remove(key.as_str())?; + + // Remove from strings table + strings_table.remove(key.as_str())?; + + // Remove all hash fields for this key + let mut to_remove = Vec::new(); + let mut iter = hashes_table.iter()?; + while let Some(entry) = iter.next() { + let entry = entry?; + let (hash_key, field) = entry.0.value(); + if hash_key == key.as_str() { + to_remove.push((hash_key.to_string(), field.to_string())); + } + } + drop(iter); + + for (hash_key, field) in to_remove { + hashes_table.remove((hash_key.as_str(), field.as_str()))?; + } + } + + write_txn.commit()?; + Ok(()) + } + + pub fn keys(&self, pattern: &str) -> Result, DBError> { + let read_txn = self.db.begin_read()?; + let table = read_txn.open_table(TYPES_TABLE)?; + + let mut keys = Vec::new(); + let mut iter = table.iter()?; + while let Some(entry) = iter.next() { + let key = entry?.0.value().to_string(); + if pattern == "*" || key.contains(pattern) { + keys.push(key); + } + } + + Ok(keys) + } + + // Hash operations + pub fn hset(&self, key: &str, pairs: &[(String, String)]) -> Result { + let write_txn = self.db.begin_write()?; + let mut new_fields = 0u64; + + { + let mut types_table = write_txn.open_table(TYPES_TABLE)?; + let mut hashes_table = write_txn.open_table(HASHES_TABLE)?; + + // Check if key exists and is of correct type + let existing_type = match types_table.get(key)? { + Some(type_val) => Some(type_val.value().to_string()), + None => None, + }; + + match existing_type { + Some(ref type_str) if type_str != "hash" => { + return Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())); + } + None => { + // Set type to hash + types_table.insert(key, "hash")?; + } + _ => {} + } + + for (field, value) in pairs { + let existed = hashes_table.get((key, field.as_str()))?.is_some(); + hashes_table.insert((key, field.as_str()), value.as_str())?; + if !existed { + new_fields += 1; + } + } + } + + write_txn.commit()?; + Ok(new_fields) + } + + pub fn hget(&self, key: &str, field: &str) -> Result, DBError> { + let read_txn = self.db.begin_read()?; + + // Check type + let types_table = read_txn.open_table(TYPES_TABLE)?; + match types_table.get(key)? { + Some(type_val) if type_val.value() == "hash" => { + let hashes_table = read_txn.open_table(HASHES_TABLE)?; + match hashes_table.get((key, field))? { + Some(value) => Ok(Some(value.value().to_string())), + None => Ok(None), + } + } + Some(_) => Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())), + None => Ok(None), + } + } + + pub fn hgetall(&self, key: &str) -> Result, DBError> { + let read_txn = self.db.begin_read()?; + + // Check type + let types_table = read_txn.open_table(TYPES_TABLE)?; + match types_table.get(key)? { + Some(type_val) if type_val.value() == "hash" => { + let hashes_table = read_txn.open_table(HASHES_TABLE)?; + let mut result = Vec::new(); + + let mut iter = hashes_table.iter()?; + while let Some(entry) = iter.next() { + let entry = entry?; + let (hash_key, field) = entry.0.value(); + let value = entry.1.value(); + if hash_key == key { + result.push((field.to_string(), value.to_string())); } } - _ => Some(ss.clone()), - }, - _ => None, + + Ok(result) + } + Some(_) => Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())), + None => Ok(Vec::new()), } } - pub fn set(self: &mut Self, k: String, v: String) { - self.set.insert(k, (v, None)); + pub fn hdel(&self, key: &str, fields: &[String]) -> Result { + let write_txn = self.db.begin_write()?; + let mut deleted = 0u64; + + { + let types_table = write_txn.open_table(TYPES_TABLE)?; + let key_type = types_table.get(key)?; + match key_type { + Some(type_val) if type_val.value() == "hash" => { + let mut hashes_table = write_txn.open_table(HASHES_TABLE)?; + + for field in fields { + if hashes_table.remove((key, field.as_str()))?.is_some() { + deleted += 1; + } + } + } + Some(_) => return Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())), + None => {} + } + } + + write_txn.commit()?; + Ok(deleted) } - pub fn setx(self: &mut Self, k: String, v: String, expire_ms: u128) { - self.set.insert(k, (v, Some(expire_ms + now_in_millis()))); + pub fn hexists(&self, key: &str, field: &str) -> Result { + let read_txn = self.db.begin_read()?; + + let types_table = read_txn.open_table(TYPES_TABLE)?; + match types_table.get(key)? { + Some(type_val) if type_val.value() == "hash" => { + let hashes_table = read_txn.open_table(HASHES_TABLE)?; + Ok(hashes_table.get((key, field))?.is_some()) + } + Some(_) => Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())), + None => Ok(false), + } } - pub fn del(self: &mut Self, k: String) { - self.set.remove(&k); + pub fn hkeys(&self, key: &str) -> Result, DBError> { + let read_txn = self.db.begin_read()?; + + let types_table = read_txn.open_table(TYPES_TABLE)?; + match types_table.get(key)? { + Some(type_val) if type_val.value() == "hash" => { + let hashes_table = read_txn.open_table(HASHES_TABLE)?; + let mut result = Vec::new(); + + let mut iter = hashes_table.iter()?; + while let Some(entry) = iter.next() { + let entry = entry?; + let (hash_key, field) = entry.0.value(); + if hash_key == key { + result.push(field.to_string()); + } + } + + Ok(result) + } + Some(_) => Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())), + None => Ok(Vec::new()), + } } - pub fn keys(self: &Self) -> Vec { - self.set.keys().map(|x| x.clone()).collect() + pub fn hvals(&self, key: &str) -> Result, DBError> { + let read_txn = self.db.begin_read()?; + + let types_table = read_txn.open_table(TYPES_TABLE)?; + match types_table.get(key)? { + Some(type_val) if type_val.value() == "hash" => { + let hashes_table = read_txn.open_table(HASHES_TABLE)?; + let mut result = Vec::new(); + + let mut iter = hashes_table.iter()?; + while let Some(entry) = iter.next() { + let entry = entry?; + let (hash_key, _) = entry.0.value(); + let value = entry.1.value(); + if hash_key == key { + result.push(value.to_string()); + } + } + + Ok(result) + } + Some(_) => Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())), + None => Ok(Vec::new()), + } + } + + pub fn hlen(&self, key: &str) -> Result { + let read_txn = self.db.begin_read()?; + + let types_table = read_txn.open_table(TYPES_TABLE)?; + match types_table.get(key)? { + Some(type_val) if type_val.value() == "hash" => { + let hashes_table = read_txn.open_table(HASHES_TABLE)?; + let mut count = 0u64; + + let mut iter = hashes_table.iter()?; + while let Some(entry) = iter.next() { + let entry = entry?; + let (hash_key, _) = entry.0.value(); + if hash_key == key { + count += 1; + } + } + + Ok(count) + } + Some(_) => Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())), + None => Ok(0), + } + } + + pub fn hmget(&self, key: &str, fields: &[String]) -> Result>, DBError> { + let read_txn = self.db.begin_read()?; + + let types_table = read_txn.open_table(TYPES_TABLE)?; + match types_table.get(key)? { + Some(type_val) if type_val.value() == "hash" => { + let hashes_table = read_txn.open_table(HASHES_TABLE)?; + let mut result = Vec::new(); + + for field in fields { + match hashes_table.get((key, field.as_str()))? { + Some(value) => result.push(Some(value.value().to_string())), + None => result.push(None), + } + } + + Ok(result) + } + Some(_) => Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())), + None => Ok(fields.iter().map(|_| None).collect()), + } + } + + pub fn hsetnx(&self, key: &str, field: &str, value: &str) -> Result { + let write_txn = self.db.begin_write()?; + let mut result = false; + + { + let mut types_table = write_txn.open_table(TYPES_TABLE)?; + let mut hashes_table = write_txn.open_table(HASHES_TABLE)?; + + // Check if key exists and is of correct type + let existing_type = match types_table.get(key)? { + Some(type_val) => Some(type_val.value().to_string()), + None => None, + }; + + match existing_type { + Some(ref type_str) if type_str != "hash" => { + return Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())); + } + None => { + // Set type to hash + types_table.insert(key, "hash")?; + } + _ => {} + } + + // Check if field already exists + if hashes_table.get((key, field))?.is_none() { + hashes_table.insert((key, field), value)?; + result = true; + } + } + + write_txn.commit()?; + Ok(result) } } diff --git a/test_herodb.sh b/test_herodb.sh new file mode 100755 index 0000000..fb96a6c --- /dev/null +++ b/test_herodb.sh @@ -0,0 +1,302 @@ +#!/bin/bash + +# Test script for HeroDB - Redis-compatible database with redb backend +# This script starts the server and runs comprehensive tests + +set -e + +# Colors for output +RED='\033[0;31m' +GREEN='\033[0;32m' +YELLOW='\033[1;33m' +BLUE='\033[0;34m' +NC='\033[0m' # No Color + +# Configuration +DB_DIR="./test_db" +PORT=6379 +SERVER_PID="" + +# Function to print colored output +print_status() { + echo -e "${BLUE}[INFO]${NC} $1" +} + +print_success() { + echo -e "${GREEN}[SUCCESS]${NC} $1" +} + +print_error() { + echo -e "${RED}[ERROR]${NC} $1" +} + +print_warning() { + echo -e "${YELLOW}[WARNING]${NC} $1" +} + +# Function to cleanup on exit +cleanup() { + if [ ! -z "$SERVER_PID" ]; then + print_status "Stopping HeroDB server (PID: $SERVER_PID)..." + kill $SERVER_PID 2>/dev/null || true + wait $SERVER_PID 2>/dev/null || true + fi + + # Clean up test database + if [ -d "$DB_DIR" ]; then + print_status "Cleaning up test database directory..." + rm -rf "$DB_DIR" + fi +} + +# Set trap to cleanup on script exit +trap cleanup EXIT + +# Function to wait for server to start +wait_for_server() { + local max_attempts=30 + local attempt=1 + + print_status "Waiting for server to start on port $PORT..." + + while [ $attempt -le $max_attempts ]; do + if nc -z localhost $PORT 2>/dev/null; then + print_success "Server is ready!" + return 0 + fi + + echo -n "." + sleep 1 + attempt=$((attempt + 1)) + done + + print_error "Server failed to start within $max_attempts seconds" + return 1 +} + +# Function to send Redis command and get response +redis_cmd() { + local cmd="$1" + local expected="$2" + + print_status "Testing: $cmd" + + local result=$(echo "$cmd" | redis-cli -p $PORT --raw 2>/dev/null || echo "ERROR") + + if [ "$expected" != "" ] && [ "$result" != "$expected" ]; then + print_error "Expected: '$expected', Got: '$result'" + return 1 + else + print_success "✓ $cmd -> $result" + return 0 + fi +} + +# Function to test basic string operations +test_string_operations() { + print_status "=== Testing String Operations ===" + + redis_cmd "PING" "PONG" + redis_cmd "SET mykey hello" "OK" + redis_cmd "GET mykey" "hello" + redis_cmd "SET counter 1" "OK" + redis_cmd "INCR counter" "2" + redis_cmd "INCR counter" "3" + redis_cmd "GET counter" "3" + redis_cmd "DEL mykey" "1" + redis_cmd "GET mykey" "" + redis_cmd "TYPE counter" "string" + redis_cmd "TYPE nonexistent" "none" +} + +# Function to test hash operations +test_hash_operations() { + print_status "=== Testing Hash Operations ===" + + # HSET and HGET + redis_cmd "HSET user:1 name John" "1" + redis_cmd "HSET user:1 age 30 city NYC" "2" + redis_cmd "HGET user:1 name" "John" + redis_cmd "HGET user:1 age" "30" + redis_cmd "HGET user:1 nonexistent" "" + + # HGETALL + print_status "Testing HGETALL user:1" + redis_cmd "HGETALL user:1" "" + + # HEXISTS + redis_cmd "HEXISTS user:1 name" "1" + redis_cmd "HEXISTS user:1 nonexistent" "0" + + # HKEYS + print_status "Testing HKEYS user:1" + redis_cmd "HKEYS user:1" "" + + # HVALS + print_status "Testing HVALS user:1" + redis_cmd "HVALS user:1" "" + + # HLEN + redis_cmd "HLEN user:1" "3" + + # HMGET + print_status "Testing HMGET user:1 name age" + redis_cmd "HMGET user:1 name age" "" + + # HSETNX + redis_cmd "HSETNX user:1 name Jane" "0" # Should not set, field exists + redis_cmd "HSETNX user:1 email john@example.com" "1" # Should set, new field + redis_cmd "HGET user:1 email" "john@example.com" + + # HDEL + redis_cmd "HDEL user:1 age city" "2" + redis_cmd "HLEN user:1" "2" + redis_cmd "HEXISTS user:1 age" "0" + + # Test type checking + redis_cmd "SET stringkey value" "OK" + print_status "Testing WRONGTYPE error on string key" + redis_cmd "HGET stringkey field" "" # Should return WRONGTYPE error +} + +# Function to test configuration commands +test_config_operations() { + print_status "=== Testing Configuration Operations ===" + + print_status "Testing CONFIG GET dir" + redis_cmd "CONFIG GET dir" "" + + print_status "Testing CONFIG GET dbfilename" + redis_cmd "CONFIG GET dbfilename" "" +} + +# Function to test transaction operations +test_transaction_operations() { + print_status "=== Testing Transaction Operations ===" + + redis_cmd "MULTI" "OK" + redis_cmd "SET tx_key1 value1" "QUEUED" + redis_cmd "SET tx_key2 value2" "QUEUED" + redis_cmd "INCR counter" "QUEUED" + print_status "Testing EXEC" + redis_cmd "EXEC" "" + + redis_cmd "GET tx_key1" "value1" + redis_cmd "GET tx_key2" "value2" + + # Test DISCARD + redis_cmd "MULTI" "OK" + redis_cmd "SET discard_key value" "QUEUED" + redis_cmd "DISCARD" "OK" + redis_cmd "GET discard_key" "" +} + +# Function to test keys operations +test_keys_operations() { + print_status "=== Testing Keys Operations ===" + + print_status "Testing KEYS *" + redis_cmd "KEYS *" "" +} + +# Function to test info operations +test_info_operations() { + print_status "=== Testing Info Operations ===" + + print_status "Testing INFO" + redis_cmd "INFO" "" + + print_status "Testing INFO replication" + redis_cmd "INFO replication" "" +} + +# Function to test expiration +test_expiration() { + print_status "=== Testing Expiration ===" + + redis_cmd "SET expire_key value" "OK" + redis_cmd "SET expire_px_key value PX 1000" "OK" # 1 second + redis_cmd "SET expire_ex_key value EX 1" "OK" # 1 second + + redis_cmd "GET expire_key" "value" + redis_cmd "GET expire_px_key" "value" + redis_cmd "GET expire_ex_key" "value" + + print_status "Waiting 2 seconds for expiration..." + sleep 2 + + redis_cmd "GET expire_key" "value" # Should still exist + redis_cmd "GET expire_px_key" "" # Should be expired + redis_cmd "GET expire_ex_key" "" # Should be expired +} + +# Main execution +main() { + print_status "Starting HeroDB comprehensive test suite..." + + # Build the project + print_status "Building HeroDB..." + if ! cargo build --release; then + print_error "Failed to build HeroDB" + exit 1 + fi + + # Create test database directory + mkdir -p "$DB_DIR" + + # Start the server + print_status "Starting HeroDB server..." + ./target/release/redis-rs --dir "$DB_DIR" --port $PORT & + SERVER_PID=$! + + # Wait for server to start + if ! wait_for_server; then + print_error "Failed to start server" + exit 1 + fi + + # Run tests + local failed_tests=0 + + test_string_operations || failed_tests=$((failed_tests + 1)) + test_hash_operations || failed_tests=$((failed_tests + 1)) + test_config_operations || failed_tests=$((failed_tests + 1)) + test_transaction_operations || failed_tests=$((failed_tests + 1)) + test_keys_operations || failed_tests=$((failed_tests + 1)) + test_info_operations || failed_tests=$((failed_tests + 1)) + test_expiration || failed_tests=$((failed_tests + 1)) + + # Summary + echo + print_status "=== Test Summary ===" + if [ $failed_tests -eq 0 ]; then + print_success "All tests completed! Some may have warnings due to protocol differences." + print_success "HeroDB is working with persistent redb storage!" + else + print_warning "$failed_tests test categories had issues" + print_warning "Check the output above for details" + fi + + print_status "Database file created at: $DB_DIR/herodb.redb" + print_status "Server logs and any errors are shown above" +} + +# Check dependencies +check_dependencies() { + if ! command -v cargo &> /dev/null; then + print_error "cargo is required but not installed" + exit 1 + fi + + if ! command -v nc &> /dev/null; then + print_warning "netcat (nc) not found - some tests may not work properly" + fi + + if ! command -v redis-cli &> /dev/null; then + print_warning "redis-cli not found - using netcat fallback" + fi +} + +# Run dependency check and main function +check_dependencies +main "$@" \ No newline at end of file