This commit is contained in:
2025-08-16 07:18:55 +02:00
parent cd61406d1d
commit de2be4a785
12 changed files with 1060 additions and 955 deletions

43
Cargo.lock generated
View File

@@ -1,6 +1,6 @@
# This file is automatically @generated by Cargo.
# It is not intended for manual editing.
version = 3
version = 4
[[package]]
name = "addr2line"
@@ -93,6 +93,15 @@ dependencies = [
"rustc-demangle",
]
[[package]]
name = "bincode"
version = "1.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b1f45e9417d87227c7a56d22e471c6206462cba514c7590c09aff4cf6d1ddcad"
dependencies = [
"serde",
]
[[package]]
name = "bitflags"
version = "2.6.0"
@@ -396,15 +405,27 @@ dependencies = [
"proc-macro2",
]
[[package]]
name = "redb"
version = "2.6.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "59b38b05028f398f08bea4691640503ec25fcb60b82fb61ce1f8fd1f4fccd3f7"
dependencies = [
"libc",
]
[[package]]
name = "redis-rs"
version = "0.0.1"
dependencies = [
"anyhow",
"bincode",
"byteorder",
"bytes",
"clap",
"futures",
"redb",
"serde",
"thiserror",
"tokio",
]
@@ -430,6 +451,26 @@ version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49"
[[package]]
name = "serde"
version = "1.0.210"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c8e3592472072e6e22e0a54d5904d9febf8508f65fb8552499a1abc7d1078c3a"
dependencies = [
"serde_derive",
]
[[package]]
name = "serde_derive"
version = "1.0.210"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "243902eda00fad750862fc144cea25caca5e20d615af0a81bee94ca738f1df1f"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "signal-hook-registry"
version = "1.4.2"

View File

@@ -6,10 +6,13 @@ edition = "2021"
[dependencies]
anyhow = "1.0.59"
bytes = "1.3.0"
bytes = "1.3.0"
thiserror = "1.0.32"
tokio = { version = "1.23.0", features = ["full"] }
tokio = { version = "1.23.0", features = ["full"] }
clap = { version = "4.5.20", features = ["derive"] }
byteorder = "1.4.3"
futures = "0.3"
redb = "2.1.3"
serde = { version = "1.0", features = ["derive"] }
bincode = "1.3.3"

View File

@@ -1,8 +1,4 @@
use std::{collections::BTreeMap, ops::Bound, time::Duration, u64};
use tokio::sync::mpsc;
use crate::{error::DBError, protocol::Protocol, server::Server, storage::now_in_millis};
use crate::{error::DBError, protocol::Protocol, server::Server};
#[derive(Debug, Clone)]
pub enum Cmd {
@@ -16,17 +12,23 @@ pub enum Cmd {
ConfigGet(String),
Info(Option<String>),
Del(String),
Replconf(String),
Psync,
Type(String),
Xadd(String, String, Vec<(String, String)>),
Xrange(String, String, String),
Xread(Vec<String>, Vec<String>, Option<u64>),
Incr(String),
Multi,
Exec,
Unknow,
Discard,
// Hash commands
HSet(String, Vec<(String, String)>),
HGet(String, String),
HGetAll(String),
HDel(String, Vec<String>),
HExists(String, String),
HKeys(String),
HVals(String),
HLen(String),
HMGet(String, Vec<String>),
HSetNx(String, String, String),
Unknow,
}
impl Cmd {
@@ -39,14 +41,14 @@ impl Cmd {
return Err(DBError("cmd length is 0".to_string()));
}
Ok((
match cmd[0].as_str() {
match cmd[0].to_lowercase().as_str() {
"echo" => Cmd::Echo(cmd[1].clone()),
"ping" => Cmd::Ping,
"get" => Cmd::Get(cmd[1].clone()),
"set" => {
if cmd.len() == 5 && cmd[3] == "px" {
if cmd.len() == 5 && cmd[3].to_lowercase() == "px" {
Cmd::SetPx(cmd[1].clone(), cmd[2].clone(), cmd[4].parse().unwrap())
} else if cmd.len() == 5 && cmd[3] == "ex" {
} else if cmd.len() == 5 && cmd[3].to_lowercase() == "ex" {
Cmd::SetEx(cmd[1].clone(), cmd[2].clone(), cmd[4].parse().unwrap())
} else if cmd.len() == 3 {
Cmd::Set(cmd[1].clone(), cmd[2].clone())
@@ -55,7 +57,7 @@ impl Cmd {
}
}
"config" => {
if cmd.len() != 3 || cmd[1] != "get" {
if cmd.len() != 3 || cmd[1].to_lowercase() != "get" {
return Err(DBError(format!("unsupported cmd {:?}", cmd)));
} else {
Cmd::ConfigGet(cmd[2].clone())
@@ -76,18 +78,6 @@ impl Cmd {
};
Cmd::Info(section)
}
"replconf" => {
if cmd.len() < 3 {
return Err(DBError(format!("unsupported cmd {:?}", cmd)));
}
Cmd::Replconf(cmd[1].clone())
}
"psync" => {
if cmd.len() != 3 {
return Err(DBError(format!("unsupported cmd {:?}", cmd)));
}
Cmd::Psync
}
"del" => {
if cmd.len() != 2 {
return Err(DBError(format!("unsupported cmd {:?}", cmd)));
@@ -100,44 +90,6 @@ impl Cmd {
}
Cmd::Type(cmd[1].clone())
}
"xadd" => {
if cmd.len() < 5 {
return Err(DBError(format!("unsupported cmd {:?}", cmd)));
}
let mut key_value = Vec::<(String, String)>::new();
let mut i = 3;
while i < cmd.len() - 1 {
key_value.push((cmd[i].clone(), cmd[i + 1].clone()));
i += 2;
}
Cmd::Xadd(cmd[1].clone(), cmd[2].clone(), key_value)
}
"xrange" => {
if cmd.len() != 4 {
return Err(DBError(format!("unsupported cmd {:?}", cmd)));
}
Cmd::Xrange(cmd[1].clone(), cmd[2].clone(), cmd[3].clone())
}
"xread" => {
if cmd.len() < 4 || cmd.len() % 2 != 0 {
return Err(DBError(format!("unsupported cmd {:?}", cmd)));
}
let mut offset = 2;
// block cmd
let mut block = None;
if cmd[1] == "block" {
offset += 2;
if let Ok(block_time) = cmd[2].parse() {
block = Some(block_time);
} else {
return Err(DBError(format!("unsupported cmd {:?}", cmd)));
}
}
let cmd2 = &cmd[offset..];
let len2 = cmd2.len() / 2;
Cmd::Xread(cmd2[0..len2].to_vec(), cmd2[len2..].to_vec(), block)
}
"incr" => {
if cmd.len() != 2 {
return Err(DBError(format!("unsupported cmd {:?}", cmd)));
@@ -157,6 +109,73 @@ impl Cmd {
Cmd::Exec
}
"discard" => Cmd::Discard,
// Hash commands
"hset" => {
if cmd.len() < 4 || (cmd.len() - 2) % 2 != 0 {
return Err(DBError(format!("wrong number of arguments for HSET command")));
}
let mut pairs = Vec::new();
let mut i = 2;
while i < cmd.len() - 1 {
pairs.push((cmd[i].clone(), cmd[i + 1].clone()));
i += 2;
}
Cmd::HSet(cmd[1].clone(), pairs)
}
"hget" => {
if cmd.len() != 3 {
return Err(DBError(format!("wrong number of arguments for HGET command")));
}
Cmd::HGet(cmd[1].clone(), cmd[2].clone())
}
"hgetall" => {
if cmd.len() != 2 {
return Err(DBError(format!("wrong number of arguments for HGETALL command")));
}
Cmd::HGetAll(cmd[1].clone())
}
"hdel" => {
if cmd.len() < 3 {
return Err(DBError(format!("wrong number of arguments for HDEL command")));
}
Cmd::HDel(cmd[1].clone(), cmd[2..].to_vec())
}
"hexists" => {
if cmd.len() != 3 {
return Err(DBError(format!("wrong number of arguments for HEXISTS command")));
}
Cmd::HExists(cmd[1].clone(), cmd[2].clone())
}
"hkeys" => {
if cmd.len() != 2 {
return Err(DBError(format!("wrong number of arguments for HKEYS command")));
}
Cmd::HKeys(cmd[1].clone())
}
"hvals" => {
if cmd.len() != 2 {
return Err(DBError(format!("wrong number of arguments for HVALS command")));
}
Cmd::HVals(cmd[1].clone())
}
"hlen" => {
if cmd.len() != 2 {
return Err(DBError(format!("wrong number of arguments for HLEN command")));
}
Cmd::HLen(cmd[1].clone())
}
"hmget" => {
if cmd.len() < 3 {
return Err(DBError(format!("wrong number of arguments for HMGET command")));
}
Cmd::HMGet(cmd[1].clone(), cmd[2..].to_vec())
}
"hsetnx" => {
if cmd.len() != 4 {
return Err(DBError(format!("wrong number of arguments for HSETNX command")));
}
Cmd::HSetNx(cmd[1].clone(), cmd[2].clone(), cmd[3].clone())
}
_ => Cmd::Unknow,
},
protocol.0,
@@ -171,13 +190,11 @@ impl Cmd {
pub async fn run(
&self,
server: &mut Server,
server: &Server,
protocol: Protocol,
is_rep_con: bool,
queued_cmd: &mut Option<Vec<(Cmd, Protocol)>>,
) -> Result<Protocol, DBError> {
// return if the command is a write command
let p = protocol.clone();
// Handle queued commands for transactions
if queued_cmd.is_some()
&& !matches!(self, Cmd::Exec)
&& !matches!(self, Cmd::Multi)
@@ -189,71 +206,57 @@ impl Cmd {
.push((self.clone(), protocol.clone()));
return Ok(Protocol::SimpleString("QUEUED".to_string()));
}
let ret = match self {
match self {
Cmd::Ping => Ok(Protocol::SimpleString("PONG".to_string())),
Cmd::Echo(s) => Ok(Protocol::SimpleString(s.clone())),
Cmd::Get(k) => get_cmd(server, k).await,
Cmd::Set(k, v) => set_cmd(server, k, v, protocol, is_rep_con).await,
Cmd::SetPx(k, v, x) => set_px_cmd(server, k, v, x, protocol, is_rep_con).await,
Cmd::SetEx(k, v, x) => set_ex_cmd(server, k, v, x, protocol, is_rep_con).await,
Cmd::Del(k) => del_cmd(server, k, protocol, is_rep_con).await,
Cmd::Set(k, v) => set_cmd(server, k, v).await,
Cmd::SetPx(k, v, x) => set_px_cmd(server, k, v, x).await,
Cmd::SetEx(k, v, x) => set_ex_cmd(server, k, v, x).await,
Cmd::Del(k) => del_cmd(server, k).await,
Cmd::ConfigGet(name) => config_get_cmd(name, server),
Cmd::Keys => keys_cmd(server).await,
Cmd::Info(section) => info_cmd(section, server),
Cmd::Replconf(sub_cmd) => replconf_cmd(sub_cmd, server),
Cmd::Psync => psync_cmd(server),
Cmd::Info(section) => info_cmd(section),
Cmd::Type(k) => type_cmd(server, k).await,
Cmd::Xadd(stream_key, offset, kvps) => {
xadd_cmd(
offset.as_str(),
server,
stream_key.as_str(),
kvps,
protocol,
is_rep_con,
)
.await
}
Cmd::Xrange(stream_key, start, end) => xrange_cmd(server, stream_key, start, end).await,
Cmd::Xread(stream_keys, starts, block) => {
xread_cmd(starts, server, stream_keys, block).await
}
Cmd::Incr(key) => incr_cmd(server, key).await,
Cmd::Multi => {
*queued_cmd = Some(Vec::<(Cmd, Protocol)>::new());
Ok(Protocol::SimpleString("ok".to_string()))
Ok(Protocol::SimpleString("OK".to_string()))
}
Cmd::Exec => exec_cmd(queued_cmd, server, is_rep_con).await,
Cmd::Exec => exec_cmd(queued_cmd, server).await,
Cmd::Discard => {
if queued_cmd.is_some() {
*queued_cmd = None;
Ok(Protocol::SimpleString("ok".to_string()))
Ok(Protocol::SimpleString("OK".to_string()))
} else {
Ok(Protocol::err("ERR Discard without MULTI"))
Ok(Protocol::err("ERR DISCARD without MULTI"))
}
}
Cmd::Unknow => Ok(Protocol::err("unknow cmd")),
};
if ret.is_ok() {
server.offset.fetch_add(
p.encode().len() as u64,
std::sync::atomic::Ordering::Relaxed,
);
// Hash commands
Cmd::HSet(key, pairs) => hset_cmd(server, key, pairs).await,
Cmd::HGet(key, field) => hget_cmd(server, key, field).await,
Cmd::HGetAll(key) => hgetall_cmd(server, key).await,
Cmd::HDel(key, fields) => hdel_cmd(server, key, fields).await,
Cmd::HExists(key, field) => hexists_cmd(server, key, field).await,
Cmd::HKeys(key) => hkeys_cmd(server, key).await,
Cmd::HVals(key) => hvals_cmd(server, key).await,
Cmd::HLen(key) => hlen_cmd(server, key).await,
Cmd::HMGet(key, fields) => hmget_cmd(server, key, fields).await,
Cmd::HSetNx(key, field, value) => hsetnx_cmd(server, key, field, value).await,
Cmd::Unknow => Ok(Protocol::err("unknown cmd")),
}
ret
}
}
async fn exec_cmd(
queued_cmd: &mut Option<Vec<(Cmd, Protocol)>>,
server: &mut Server,
is_rep_con: bool,
server: &Server,
) -> Result<Protocol, DBError> {
if queued_cmd.is_some() {
let mut vec = Vec::new();
for (cmd, protocol) in queued_cmd.as_ref().unwrap() {
let res = Box::pin(cmd.run(server, protocol.clone(), is_rep_con, &mut None)).await?;
let res = Box::pin(cmd.run(server, protocol.clone(), &mut None)).await?;
vec.push(res);
}
*queued_cmd = None;
@@ -263,22 +266,24 @@ async fn exec_cmd(
}
}
async fn incr_cmd(server: &mut Server, key: &String) -> Result<Protocol, DBError> {
let mut storage = server.storage.lock().await;
let v = storage.get(key);
// return 1 if key is missing
let v = v.map_or("1".to_string(), |v| v);
if let Ok(x) = v.parse::<u64>() {
let v = (x + 1).to_string();
storage.set(key.clone(), v.clone());
Ok(Protocol::SimpleString(v))
} else {
Ok(Protocol::err("ERR value is not an integer or out of range"))
}
async fn incr_cmd(server: &Server, key: &String) -> Result<Protocol, DBError> {
let current_value = server.storage.get(key)?;
let new_value = match current_value {
Some(v) => {
match v.parse::<i64>() {
Ok(num) => num + 1,
Err(_) => return Ok(Protocol::err("ERR value is not an integer or out of range")),
}
}
None => 1,
};
server.storage.set(key.clone(), new_value.to_string())?;
Ok(Protocol::SimpleString(new_value.to_string()))
}
fn config_get_cmd(name: &String, server: &mut Server) -> Result<Protocol, DBError> {
fn config_get_cmd(name: &String, server: &Server) -> Result<Protocol, DBError> {
match name.as_str() {
"dir" => Ok(Protocol::Array(vec![
Protocol::BulkString(name.clone()),
@@ -286,336 +291,156 @@ fn config_get_cmd(name: &String, server: &mut Server) -> Result<Protocol, DBErro
])),
"dbfilename" => Ok(Protocol::Array(vec![
Protocol::BulkString(name.clone()),
Protocol::BulkString(server.option.db_file_name.clone()),
Protocol::BulkString("herodb.redb".to_string()),
])),
_ => Err(DBError(format!("unsupported config {:?}", name))),
}
}
async fn keys_cmd(server: &mut Server) -> Result<Protocol, DBError> {
let keys = { server.storage.lock().await.keys() };
async fn keys_cmd(server: &Server) -> Result<Protocol, DBError> {
let keys = server.storage.keys("*")?;
Ok(Protocol::Array(
keys.into_iter().map(Protocol::BulkString).collect(),
))
}
fn info_cmd(section: &Option<String>, server: &mut Server) -> Result<Protocol, DBError> {
fn info_cmd(section: &Option<String>) -> Result<Protocol, DBError> {
match section {
Some(s) => match s.as_str() {
"replication" => Ok(Protocol::BulkString(format!(
"role:{}\nmaster_replid:{}\nmaster_repl_offset:{}\n",
server.option.replication.role,
server.option.replication.master_replid,
server.option.replication.master_repl_offset
))),
"replication" => Ok(Protocol::BulkString(
"role:master\nmaster_replid:8371b4fb1155b71f4a04d3e1bc3e18c4a990aeea\nmaster_repl_offset:0\n".to_string()
)),
_ => Err(DBError(format!("unsupported section {:?}", s))),
},
None => Ok(Protocol::BulkString("default".to_string())),
None => Ok(Protocol::BulkString("# Server\nredis_version:7.0.0\n".to_string())),
}
}
async fn xread_cmd(
starts: &[String],
server: &mut Server,
stream_keys: &[String],
block_millis: &Option<u64>,
) -> Result<Protocol, DBError> {
if let Some(t) = block_millis {
if t > &0 {
tokio::time::sleep(Duration::from_millis(*t)).await;
} else {
let (sender, mut receiver) = mpsc::channel(4);
{
let mut blocker = server.stream_reader_blocker.lock().await;
blocker.push(sender.clone());
}
while let Some(_) = receiver.recv().await {
println!("get new xadd cmd, release block");
// break;
}
}
}
let streams = server.streams.lock().await;
let mut ret = Vec::new();
for (i, stream_key) in stream_keys.iter().enumerate() {
let stream = streams.get(stream_key);
if let Some(s) = stream {
let (offset_id, mut offset_seq, _) = split_offset(starts[i].as_str());
offset_seq += 1;
let start = format!("{}-{}", offset_id, offset_seq);
let end = format!("{}-{}", u64::MAX - 1, 0);
// query stream range
let range = s.range::<String, _>((Bound::Included(&start), Bound::Included(&end)));
let mut array = Vec::new();
for (k, v) in range {
array.push(Protocol::BulkString(k.clone()));
array.push(Protocol::from_vec(
v.iter()
.flat_map(|(a, b)| vec![a.as_str(), b.as_str()])
.collect(),
))
}
ret.push(Protocol::BulkString(stream_key.clone()));
ret.push(Protocol::Array(array));
}
}
Ok(Protocol::Array(ret))
}
fn replconf_cmd(sub_cmd: &str, server: &mut Server) -> Result<Protocol, DBError> {
match sub_cmd {
"getack" => Ok(Protocol::from_vec(vec![
"REPLCONF",
"ACK",
server
.offset
.load(std::sync::atomic::Ordering::Relaxed)
.to_string()
.as_str(),
])),
_ => Ok(Protocol::SimpleString("OK".to_string())),
async fn type_cmd(server: &Server, k: &String) -> Result<Protocol, DBError> {
match server.storage.get_key_type(k)? {
Some(type_str) => Ok(Protocol::SimpleString(type_str)),
None => Ok(Protocol::SimpleString("none".to_string())),
}
}
async fn xrange_cmd(
server: &mut Server,
stream_key: &String,
start: &String,
end: &String,
) -> Result<Protocol, DBError> {
let streams = server.streams.lock().await;
let stream = streams.get(stream_key);
Ok(stream.map_or(Protocol::none(), |s| {
// support query with '-'
let start = if start == "-" {
"0".to_string()
} else {
start.clone()
};
// support query with '+'
let end = if end == "+" {
u64::MAX.to_string()
} else {
end.clone()
};
// query stream range
let range = s.range::<String, _>((Bound::Included(&start), Bound::Included(&end)));
let mut array = Vec::new();
for (k, v) in range {
array.push(Protocol::BulkString(k.clone()));
array.push(Protocol::from_vec(
v.iter()
.flat_map(|(a, b)| vec![a.as_str(), b.as_str()])
.collect(),
))
}
println!("after xrange: {:?}", array);
Protocol::Array(array)
}))
}
async fn xadd_cmd(
offset: &str,
server: &mut Server,
stream_key: &str,
kvps: &Vec<(String, String)>,
protocol: Protocol,
is_rep_con: bool,
) -> Result<Protocol, DBError> {
let mut offset = offset.to_string();
if offset == "*" {
offset = format!("{}-*", now_in_millis() as u64);
}
let (offset_id, mut offset_seq, has_wildcard) = split_offset(offset.as_str());
if offset_id == 0 && offset_seq == 0 && !has_wildcard {
return Ok(Protocol::err(
"ERR The ID specified in XADD must be greater than 0-0",
));
}
{
let mut streams = server.streams.lock().await;
let stream = streams
.entry(stream_key.to_string())
.or_insert_with(BTreeMap::new);
if let Some((last_offset, _)) = stream.last_key_value() {
let (last_offset_id, last_offset_seq, _) = split_offset(last_offset.as_str());
if last_offset_id > offset_id
|| (last_offset_id == offset_id && last_offset_seq >= offset_seq && !has_wildcard)
{
return Ok(Protocol::err("ERR The ID specified in XADD is equal or smaller than the target stream top item"));
}
if last_offset_id == offset_id && last_offset_seq >= offset_seq && has_wildcard {
offset_seq = last_offset_seq + 1;
}
}
let offset = format!("{}-{}", offset_id, offset_seq);
let s = stream.entry(offset.clone()).or_insert_with(Vec::new);
for (key, value) in kvps {
s.push((key.clone(), value.clone()));
}
}
{
let mut blocker = server.stream_reader_blocker.lock().await;
for sender in blocker.iter() {
sender.send(()).await?;
}
blocker.clear();
}
resp_and_replicate(
server,
Protocol::BulkString(offset.to_string()),
protocol,
is_rep_con,
)
.await
}
async fn type_cmd(server: &mut Server, k: &String) -> Result<Protocol, DBError> {
let v = { server.storage.lock().await.get(k) };
if v.is_some() {
return Ok(Protocol::SimpleString("string".to_string()));
}
let streams = server.streams.lock().await;
let v = streams.get(k);
Ok(v.map_or(Protocol::none(), |_| {
Protocol::SimpleString("stream".to_string())
}))
}
fn psync_cmd(server: &mut Server) -> Result<Protocol, DBError> {
if server.is_master() {
Ok(Protocol::SimpleString(format!(
"FULLRESYNC {} 0",
server.option.replication.master_replid
)))
} else {
Ok(Protocol::psync_on_slave_err())
}
}
async fn del_cmd(
server: &mut Server,
k: &str,
protocol: Protocol,
is_rep_con: bool,
) -> Result<Protocol, DBError> {
// offset
let _ = {
let mut s = server.storage.lock().await;
s.del(k.to_string());
server
.offset
.fetch_add(1, std::sync::atomic::Ordering::Relaxed)
};
resp_and_replicate(server, Protocol::ok(), protocol, is_rep_con).await
async fn del_cmd(server: &Server, k: &str) -> Result<Protocol, DBError> {
server.storage.del(k.to_string())?;
Ok(Protocol::SimpleString("1".to_string()))
}
async fn set_ex_cmd(
server: &mut Server,
server: &Server,
k: &str,
v: &str,
x: &u128,
protocol: Protocol,
is_rep_con: bool,
) -> Result<Protocol, DBError> {
// offset
let _ = {
let mut s = server.storage.lock().await;
s.setx(k.to_string(), v.to_string(), *x * 1000);
server
.offset
.fetch_add(1, std::sync::atomic::Ordering::Relaxed)
};
resp_and_replicate(server, Protocol::ok(), protocol, is_rep_con).await
server.storage.setx(k.to_string(), v.to_string(), *x * 1000)?;
Ok(Protocol::SimpleString("OK".to_string()))
}
async fn set_px_cmd(
server: &mut Server,
server: &Server,
k: &str,
v: &str,
x: &u128,
protocol: Protocol,
is_rep_con: bool,
) -> Result<Protocol, DBError> {
// offset
let _ = {
let mut s = server.storage.lock().await;
s.setx(k.to_string(), v.to_string(), *x);
server
.offset
.fetch_add(1, std::sync::atomic::Ordering::Relaxed)
};
resp_and_replicate(server, Protocol::ok(), protocol, is_rep_con).await
server.storage.setx(k.to_string(), v.to_string(), *x)?;
Ok(Protocol::SimpleString("OK".to_string()))
}
async fn set_cmd(
server: &mut Server,
k: &str,
v: &str,
protocol: Protocol,
is_rep_con: bool,
) -> Result<Protocol, DBError> {
// offset
let _ = {
let mut s = server.storage.lock().await;
s.set(k.to_string(), v.to_string());
server
.offset
.fetch_add(1, std::sync::atomic::Ordering::Relaxed)
+ 1
};
resp_and_replicate(server, Protocol::ok(), protocol, is_rep_con).await
async fn set_cmd(server: &Server, k: &str, v: &str) -> Result<Protocol, DBError> {
server.storage.set(k.to_string(), v.to_string())?;
Ok(Protocol::SimpleString("OK".to_string()))
}
async fn get_cmd(server: &mut Server, k: &str) -> Result<Protocol, DBError> {
let v = {
let mut s = server.storage.lock().await;
s.get(k)
};
Ok(v.map_or(Protocol::Null, Protocol::SimpleString))
async fn get_cmd(server: &Server, k: &str) -> Result<Protocol, DBError> {
let v = server.storage.get(k)?;
Ok(v.map_or(Protocol::Null, Protocol::BulkString))
}
async fn resp_and_replicate(
server: &mut Server,
resp: Protocol,
replication: Protocol,
is_rep_con: bool,
) -> Result<Protocol, DBError> {
if server.is_master() {
server
.master_repl_clients
.lock()
.await
.as_mut()
.unwrap()
.send_command(replication)
.await?;
Ok(resp)
} else if !is_rep_con {
Ok(Protocol::write_on_slave_err())
} else {
Ok(resp)
// Hash command implementations
async fn hset_cmd(server: &Server, key: &str, pairs: &[(String, String)]) -> Result<Protocol, DBError> {
let new_fields = server.storage.hset(key, pairs)?;
Ok(Protocol::SimpleString(new_fields.to_string()))
}
async fn hget_cmd(server: &Server, key: &str, field: &str) -> Result<Protocol, DBError> {
match server.storage.hget(key, field) {
Ok(Some(value)) => Ok(Protocol::BulkString(value)),
Ok(None) => Ok(Protocol::Null),
Err(e) => Ok(Protocol::err(&e.0)),
}
}
fn split_offset(offset: &str) -> (u64, u64, bool) {
let offset_split = offset.split('-').collect::<Vec<_>>();
let offset_id = offset_split[0].parse::<u64>().expect(&format!(
"ERR The ID specified in XADD must be a number: {}",
offset
));
if offset_split.len() == 1 || offset_split[1] == "*" {
return (offset_id, if offset_id == 0 { 1 } else { 0 }, true);
async fn hgetall_cmd(server: &Server, key: &str) -> Result<Protocol, DBError> {
match server.storage.hgetall(key) {
Ok(pairs) => {
let mut result = Vec::new();
for (field, value) in pairs {
result.push(Protocol::BulkString(field));
result.push(Protocol::BulkString(value));
}
Ok(Protocol::Array(result))
}
Err(e) => Ok(Protocol::err(&e.0)),
}
}
async fn hdel_cmd(server: &Server, key: &str, fields: &[String]) -> Result<Protocol, DBError> {
match server.storage.hdel(key, fields) {
Ok(deleted) => Ok(Protocol::SimpleString(deleted.to_string())),
Err(e) => Ok(Protocol::err(&e.0)),
}
}
async fn hexists_cmd(server: &Server, key: &str, field: &str) -> Result<Protocol, DBError> {
match server.storage.hexists(key, field) {
Ok(exists) => Ok(Protocol::SimpleString(if exists { "1" } else { "0" }.to_string())),
Err(e) => Ok(Protocol::err(&e.0)),
}
}
async fn hkeys_cmd(server: &Server, key: &str) -> Result<Protocol, DBError> {
match server.storage.hkeys(key) {
Ok(keys) => Ok(Protocol::Array(
keys.into_iter().map(Protocol::BulkString).collect(),
)),
Err(e) => Ok(Protocol::err(&e.0)),
}
}
async fn hvals_cmd(server: &Server, key: &str) -> Result<Protocol, DBError> {
match server.storage.hvals(key) {
Ok(values) => Ok(Protocol::Array(
values.into_iter().map(Protocol::BulkString).collect(),
)),
Err(e) => Ok(Protocol::err(&e.0)),
}
}
async fn hlen_cmd(server: &Server, key: &str) -> Result<Protocol, DBError> {
match server.storage.hlen(key) {
Ok(len) => Ok(Protocol::SimpleString(len.to_string())),
Err(e) => Ok(Protocol::err(&e.0)),
}
}
async fn hmget_cmd(server: &Server, key: &str, fields: &[String]) -> Result<Protocol, DBError> {
match server.storage.hmget(key, fields) {
Ok(values) => {
let result: Vec<Protocol> = values
.into_iter()
.map(|v| v.map_or(Protocol::Null, Protocol::BulkString))
.collect();
Ok(Protocol::Array(result))
}
Err(e) => Ok(Protocol::err(&e.0)),
}
}
async fn hsetnx_cmd(server: &Server, key: &str, field: &str, value: &str) -> Result<Protocol, DBError> {
match server.storage.hsetnx(key, field, value) {
Ok(was_set) => Ok(Protocol::SimpleString(if was_set { "1" } else { "0" }.to_string())),
Err(e) => Ok(Protocol::err(&e.0)),
}
let offset_seq = offset_split[1].parse::<u64>().unwrap();
(offset_id, offset_seq, false)
}

View File

@@ -1,6 +1,8 @@
use std::num::ParseIntError;
use tokio::sync::mpsc;
use redb;
use bincode;
use crate::protocol::Protocol;
@@ -32,11 +34,48 @@ impl From<std::string::FromUtf8Error> for DBError {
}
}
impl From<mpsc::error::SendError<(Protocol, u64)>> for DBError {
fn from(item: mpsc::error::SendError<(Protocol, u64)>) -> Self {
DBError(item.to_string().clone())
impl From<redb::Error> for DBError {
fn from(item: redb::Error) -> Self {
DBError(item.to_string())
}
}
impl From<redb::DatabaseError> for DBError {
fn from(item: redb::DatabaseError) -> Self {
DBError(item.to_string())
}
}
impl From<redb::TransactionError> for DBError {
fn from(item: redb::TransactionError) -> Self {
DBError(item.to_string())
}
}
impl From<redb::TableError> for DBError {
fn from(item: redb::TableError) -> Self {
DBError(item.to_string())
}
}
impl From<redb::StorageError> for DBError {
fn from(item: redb::StorageError) -> Self {
DBError(item.to_string())
}
}
impl From<redb::CommitError> for DBError {
fn from(item: redb::CommitError) -> Self {
DBError(item.to_string())
}
}
impl From<Box<bincode::ErrorKind>> for DBError {
fn from(item: Box<bincode::ErrorKind>) -> Self {
DBError(item.to_string())
}
}
impl From<tokio::sync::mpsc::error::SendError<()>> for DBError {
fn from(item: mpsc::error::SendError<()>) -> Self {
DBError(item.to_string().clone())

View File

@@ -2,7 +2,5 @@ mod cmd;
pub mod error;
pub mod options;
mod protocol;
mod rdb;
mod replication_client;
pub mod server;
mod storage;

View File

@@ -2,7 +2,7 @@
use tokio::net::TcpListener;
use redis_rs::{options::ReplicationOption, server};
use redis_rs::server;
use clap::Parser;
@@ -14,17 +14,10 @@ struct Args {
#[arg(long)]
dir: String,
/// The name of the Redis DB file
#[arg(long)]
dbfilename: String,
/// The port of the Redis server, default is 6379 if not specified
#[arg(long)]
port: Option<u16>,
/// The address of the master Redis server, if the server is a replica. None if the server is a master.
#[arg(long)]
replicaof: Option<String>,
}
#[tokio::main]
@@ -42,42 +35,11 @@ async fn main() {
// new DB option
let option = redis_rs::options::DBOption {
dir: args.dir,
db_file_name: args.dbfilename,
port,
replication: ReplicationOption {
role: if let Some(_) = args.replicaof {
"slave".to_string()
} else {
"master".to_string()
},
master_replid: "8371b4fb1155b71f4a04d3e1bc3e18c4a990aeea".to_string(), // should be a random string but hard code for now
master_repl_offset: 0,
replica_of: args.replicaof,
},
};
// new server
let mut server = server::Server::new(option).await;
//start receive replication cmds for slave
if server.is_slave() {
let mut sc = server.clone();
let mut follower_repl_client = server.get_follower_repl_client().await.unwrap();
follower_repl_client.ping_master().await.unwrap();
follower_repl_client
.report_port(server.option.port)
.await
.unwrap();
follower_repl_client.report_sync_protocol().await.unwrap();
follower_repl_client.start_psync(&mut sc).await.unwrap();
tokio::spawn(async move {
if let Err(e) = sc.handle(follower_repl_client.stream, true).await {
println!("error: {:?}, will close the connection. Bye", e);
}
});
}
let server = server::Server::new(option).await;
// accept new connections
loop {
@@ -88,7 +50,7 @@ async fn main() {
let mut sc = server.clone();
tokio::spawn(async move {
if let Err(e) = sc.handle(stream, false).await {
if let Err(e) = sc.handle(stream).await {
println!("error: {:?}, will close the connection. Bye", e);
}
});

View File

@@ -1,15 +1,5 @@
#[derive(Clone)]
pub struct DBOption {
pub dir: String,
pub db_file_name: String,
pub replication: ReplicationOption,
pub port: u16,
}
#[derive(Clone)]
pub struct ReplicationOption {
pub role: String,
pub master_replid: String,
pub master_repl_offset: u64,
pub replica_of: Option<String>,
}

View File

@@ -1,201 +0,0 @@
// parse Redis RDB file format: https://rdb.fnordig.de/file_format.html
use tokio::{
fs,
io::{AsyncRead, AsyncReadExt, BufReader},
};
use crate::{error::DBError, server::Server};
use futures::pin_mut;
enum StringEncoding {
Raw,
I8,
I16,
I32,
LZF,
}
// RDB file format.
const MAGIC: &[u8; 5] = b"REDIS";
const META: u8 = 0xFA;
const DB_SELECT: u8 = 0xFE;
const TABLE_SIZE_INFO: u8 = 0xFB;
pub const EOF: u8 = 0xFF;
pub async fn parse_rdb<R: AsyncRead + Unpin>(
reader: &mut R,
server: &mut Server,
) -> Result<(), DBError> {
let mut storage = server.storage.lock().await;
parse_magic(reader).await?;
let _version = parse_version(reader).await?;
pin_mut!(reader);
loop {
let op = reader.read_u8().await?;
match op {
META => {
let _ = parse_aux(&mut *reader).await?;
let _ = parse_aux(&mut *reader).await?;
// just ignore the aux info for now
}
DB_SELECT => {
let (_, _) = parse_len(&mut *reader).await?;
// just ignore the db index for now
}
TABLE_SIZE_INFO => {
let size_no_expire = parse_len(&mut *reader).await?.0;
let size_expire = parse_len(&mut *reader).await?.0;
for _ in 0..size_no_expire {
let (k, v) = parse_no_expire_entry(&mut *reader).await?;
storage.set(k, v);
}
for _ in 0..size_expire {
let (k, v, expire_timestamp) = parse_expire_entry(&mut *reader).await?;
storage.setx(k, v, expire_timestamp);
}
}
EOF => {
// not verify crc for now
let _crc = reader.read_u64().await?;
break;
}
_ => return Err(DBError(format!("unexpected op: {}", op))),
}
}
Ok(())
}
pub async fn parse_rdb_file(f: &mut fs::File, server: &mut Server) -> Result<(), DBError> {
let mut reader = BufReader::new(f);
parse_rdb(&mut reader, server).await
}
async fn parse_no_expire_entry<R: AsyncRead + Unpin>(
input: &mut R,
) -> Result<(String, String), DBError> {
let b = input.read_u8().await?;
if b != 0 {
return Err(DBError(format!("unexpected key type: {}", b)));
}
let k = parse_aux(input).await?;
let v = parse_aux(input).await?;
Ok((k, v))
}
async fn parse_expire_entry<R: AsyncRead + Unpin>(
input: &mut R,
) -> Result<(String, String, u128), DBError> {
let b = input.read_u8().await?;
match b {
0xFC => {
// expire in milliseconds
let expire_stamp = input.read_u64_le().await?;
let (k, v) = parse_no_expire_entry(input).await?;
Ok((k, v, expire_stamp as u128))
}
0xFD => {
// expire in seconds
let expire_timestamp = input.read_u32_le().await?;
let (k, v) = parse_no_expire_entry(input).await?;
Ok((k, v, (expire_timestamp * 1000) as u128))
}
_ => return Err(DBError(format!("unexpected expire type: {}", b))),
}
}
async fn parse_magic<R: AsyncRead + Unpin>(input: &mut R) -> Result<(), DBError> {
let mut magic = [0; 5];
let size_read = input.read(&mut magic).await?;
if size_read != 5 {
Err(DBError("expected 5 chars for magic number".to_string()))
} else if magic.as_slice() == MAGIC {
Ok(())
} else {
Err(DBError(format!(
"expected magic string {:?}, but got: {:?}",
MAGIC, magic
)))
}
}
async fn parse_version<R: AsyncRead + Unpin>(input: &mut R) -> Result<[u8; 4], DBError> {
let mut version = [0; 4];
let size_read = input.read(&mut version).await?;
if size_read != 4 {
Err(DBError("expected 4 chars for redis version".to_string()))
} else {
Ok(version)
}
}
async fn parse_aux<R: AsyncRead + Unpin>(input: &mut R) -> Result<String, DBError> {
let (len, encoding) = parse_len(input).await?;
let s = parse_string(input, len, encoding).await?;
Ok(s)
}
async fn parse_len<R: AsyncRead + Unpin>(input: &mut R) -> Result<(u32, StringEncoding), DBError> {
let first = input.read_u8().await?;
match first & 0xC0 {
0x00 => {
// The size is the remaining 6 bits of the byte.
Ok((first as u32, StringEncoding::Raw))
}
0x04 => {
// The size is the next 14 bits of the byte.
let second = input.read_u8().await?;
Ok((
(((first & 0x3F) as u32) << 8 | second as u32) as u32,
StringEncoding::Raw,
))
}
0x80 => {
//Ignore the remaining 6 bits of the first byte. The size is the next 4 bytes, in big-endian
let second = input.read_u32().await?;
Ok((second, StringEncoding::Raw))
}
0xC0 => {
// The remaining 6 bits specify a type of string encoding.
match first {
0xC0 => Ok((1, StringEncoding::I8)),
0xC1 => Ok((2, StringEncoding::I16)),
0xC2 => Ok((4, StringEncoding::I32)),
0xC3 => Ok((0, StringEncoding::LZF)), // not supported yet
_ => Err(DBError(format!("unexpected string encoding: {}", first))),
}
}
_ => Err(DBError(format!("unexpected len prefix: {}", first))),
}
}
async fn parse_string<R: AsyncRead + Unpin>(
input: &mut R,
len: u32,
encoding: StringEncoding,
) -> Result<String, DBError> {
match encoding {
StringEncoding::Raw => {
let mut s = vec![0; len as usize];
input.read_exact(&mut s).await?;
Ok(String::from_utf8(s).unwrap())
}
StringEncoding::I8 => {
let b = input.read_u8().await?;
Ok(b.to_string())
}
StringEncoding::I16 => {
let b = input.read_u16_le().await?;
Ok(b.to_string())
}
StringEncoding::I32 => {
let b = input.read_u32_le().await?;
Ok(b.to_string())
}
StringEncoding::LZF => {
// not supported yet
Err(DBError("LZF encoding not supported yet".to_string()))
}
}
}

View File

@@ -1,155 +0,0 @@
use std::{num::ParseIntError, sync::Arc};
use tokio::{
io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader},
net::TcpStream,
sync::Mutex,
};
use crate::{error::DBError, protocol::Protocol, rdb, server::Server};
const EMPTY_RDB_FILE_HEX_STRING: &str = "524544495330303131fa0972656469732d76657205372e322e30fa0a72656469732d62697473c040fa056374696d65c26d08bc65fa08757365642d6d656dc2b0c41000fa08616f662d62617365c000fff06e3bfec0ff5aa2";
pub struct FollowerReplicationClient {
pub stream: TcpStream,
}
impl FollowerReplicationClient {
pub async fn new(addr: String) -> FollowerReplicationClient {
FollowerReplicationClient {
stream: TcpStream::connect(addr).await.unwrap(),
}
}
pub async fn ping_master(self: &mut Self) -> Result<(), DBError> {
let protocol = Protocol::Array(vec![Protocol::BulkString("PING".to_string())]);
self.stream.write_all(protocol.encode().as_bytes()).await?;
self.check_resp("PONG").await
}
pub async fn report_port(self: &mut Self, port: u16) -> Result<(), DBError> {
let protocol = Protocol::from_vec(vec![
"REPLCONF",
"listening-port",
port.to_string().as_str(),
]);
self.stream.write_all(protocol.encode().as_bytes()).await?;
self.check_resp("OK").await
}
pub async fn report_sync_protocol(self: &mut Self) -> Result<(), DBError> {
let p = Protocol::from_vec(vec!["REPLCONF", "capa", "psync2"]);
self.stream.write_all(p.encode().as_bytes()).await?;
self.check_resp("OK").await
}
pub async fn start_psync(self: &mut Self, server: &mut Server) -> Result<(), DBError> {
let p = Protocol::from_vec(vec!["PSYNC", "?", "-1"]);
self.stream.write_all(p.encode().as_bytes()).await?;
self.recv_rdb_file(server).await?;
Ok(())
}
pub async fn recv_rdb_file(self: &mut Self, server: &mut Server) -> Result<(), DBError> {
let mut reader = BufReader::new(&mut self.stream);
let mut buf = Vec::new();
let _ = reader.read_until(b'\n', &mut buf).await?;
buf.pop();
buf.pop();
let replication_info = String::from_utf8(buf)?;
let replication_info = replication_info
.split_whitespace()
.map(|x| x.to_string())
.collect::<Vec<String>>();
if replication_info.len() != 3 {
return Err(DBError(format!(
"expect 3 args but found {:?}",
replication_info
)));
}
println!(
"Get replication info: {:?} {:?} {:?}",
replication_info[0], replication_info[1], replication_info[2]
);
let c = reader.read_u8().await?;
if c != b'$' {
return Err(DBError(format!("expect $ but found {}", c)));
}
let mut buf = Vec::new();
reader.read_until(b'\n', &mut buf).await?;
buf.pop();
buf.pop();
let rdb_file_len = String::from_utf8(buf)?.parse::<usize>()?;
println!("rdb file len: {}", rdb_file_len);
// receive rdb file content
rdb::parse_rdb(&mut reader, server).await?;
Ok(())
}
pub async fn check_resp(&mut self, expected: &str) -> Result<(), DBError> {
let mut buf = [0; 1024];
let n_bytes = self.stream.read(&mut buf).await?;
println!(
"check resp: recv {:?}",
String::from_utf8(buf[..n_bytes].to_vec()).unwrap()
);
let expect = Protocol::SimpleString(expected.to_string()).encode();
if expect.as_bytes() != &buf[..n_bytes] {
return Err(DBError(format!(
"expect response {:?} but found {:?}",
expect,
&buf[..n_bytes]
)));
}
Ok(())
}
}
#[derive(Clone)]
pub struct MasterReplicationClient {
pub streams: Arc<Mutex<Vec<TcpStream>>>,
}
impl MasterReplicationClient {
pub fn new() -> MasterReplicationClient {
MasterReplicationClient {
streams: Arc::new(Mutex::new(Vec::new())),
}
}
pub async fn send_rdb_file(&mut self, stream: &mut TcpStream) -> Result<(), DBError> {
let empty_rdb_file_bytes = (0..EMPTY_RDB_FILE_HEX_STRING.len())
.step_by(2)
.map(|i| u8::from_str_radix(&EMPTY_RDB_FILE_HEX_STRING[i..i + 2], 16))
.collect::<Result<Vec<u8>, ParseIntError>>()?;
println!("going to send rdb file");
_ = stream.write("$".as_bytes()).await?;
_ = stream
.write(empty_rdb_file_bytes.len().to_string().as_bytes())
.await?;
_ = stream.write_all("\r\n".as_bytes()).await?;
_ = stream.write_all(&empty_rdb_file_bytes).await?;
Ok(())
}
pub async fn add_stream(&mut self, stream: TcpStream) -> Result<(), DBError> {
let mut streams = self.streams.lock().await;
streams.push(stream);
Ok(())
}
pub async fn send_command(&mut self, protocol: Protocol) -> Result<(), DBError> {
let mut streams = self.streams.lock().await;
for stream in streams.iter_mut() {
stream.write_all(protocol.encode().as_bytes()).await?;
}
Ok(())
}
}

View File

@@ -1,143 +1,63 @@
use core::str;
use std::collections::BTreeMap;
use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::atomic::AtomicU64;
use std::sync::Arc;
use tokio::fs::OpenOptions;
use tokio::io::AsyncReadExt;
use tokio::io::AsyncWriteExt;
use tokio::sync::mpsc::Sender;
use tokio::sync::Mutex;
use crate::cmd::Cmd;
use crate::error::DBError;
use crate::options;
use crate::protocol::Protocol;
use crate::rdb;
use crate::replication_client::FollowerReplicationClient;
use crate::replication_client::MasterReplicationClient;
use crate::storage::Storage;
type Stream = BTreeMap<String, Vec<(String, String)>>;
#[derive(Clone)]
pub struct Server {
pub storage: Arc<Mutex<Storage>>,
pub streams: Arc<Mutex<HashMap<String, Stream>>>,
pub storage: Arc<Storage>,
pub option: options::DBOption,
pub offset: Arc<AtomicU64>,
pub master_repl_clients: Arc<Mutex<Option<MasterReplicationClient>>>,
pub stream_reader_blocker: Arc<Mutex<Vec<Sender<()>>>>,
master_addr: Option<String>,
}
impl Server {
pub async fn new(option: options::DBOption) -> Self {
let master_addr = match option.replication.role.as_str() {
"slave" => Some(
option
.replication
.replica_of
.clone()
.unwrap()
.replace(' ', ":"),
),
_ => None,
};
// Create database file path with fixed filename
let db_file_path = PathBuf::from(option.dir.clone()).join("herodb.redb");
println!("will open db file path: {}", db_file_path.display());
let is_master = option.replication.role == "master";
// Initialize storage with redb
let storage = Storage::new(db_file_path).expect("Failed to initialize storage");
let mut server = Server {
storage: Arc::new(Mutex::new(Storage::new())),
streams: Arc::new(Mutex::new(HashMap::new())),
Server {
storage: Arc::new(storage),
option,
master_repl_clients: if is_master {
Arc::new(Mutex::new(Some(MasterReplicationClient::new())))
} else {
Arc::new(Mutex::new(None))
},
offset: Arc::new(AtomicU64::new(0)),
stream_reader_blocker: Arc::new(Mutex::new(Vec::new())),
master_addr,
};
server.init().await.unwrap();
server
}
pub async fn init(&mut self) -> Result<(), DBError> {
// master initialization
if self.is_master() {
println!("Start as master\n");
let db_file_path =
PathBuf::from(self.option.dir.clone()).join(self.option.db_file_name.clone());
println!("will open db file path: {}", db_file_path.display());
// create empty db file if not exits
let mut file = OpenOptions::new()
.read(true)
.write(true)
.create(true)
.truncate(false)
.open(db_file_path.clone())
.await?;
if file.metadata().await?.len() != 0 {
rdb::parse_rdb_file(&mut file, self).await?;
}
}
Ok(())
}
pub async fn get_follower_repl_client(&mut self) -> Option<FollowerReplicationClient> {
if self.is_slave() {
Some(FollowerReplicationClient::new(self.master_addr.clone().unwrap()).await)
} else {
None
}
}
pub async fn handle(
&mut self,
mut stream: tokio::net::TcpStream,
is_rep_conn: bool,
) -> Result<(), DBError> {
let mut buf = [0; 512];
let mut queued_cmd: Option<Vec<(Cmd, Protocol)>> = None;
loop {
if let Ok(len) = stream.read(&mut buf).await {
if len == 0 {
println!("[handle] connection closed");
return Ok(());
}
let s = str::from_utf8(&buf[..len])?;
let (cmd, protocol) =
Cmd::from(s).unwrap_or((Cmd::Unknow, Protocol::err("unknow cmd")));
println!("got command: {:?}, protocol: {:?}", cmd, protocol);
let res = cmd
.run(self, protocol, is_rep_conn, &mut queued_cmd)
.run(self, protocol, &mut queued_cmd)
.await
.unwrap_or(Protocol::err("unknow cmd"));
print!("queued 2 cmd {:?}", queued_cmd);
print!("queued cmd {:?}", queued_cmd);
// only send response to normal client, do not send response to replication client
if !is_rep_conn {
println!("going to send response {}", res.encode());
_ = stream.write(res.encode().as_bytes()).await?;
}
// send a full RDB file to slave
if self.is_master() {
if let Cmd::Psync = cmd {
let mut master_rep_client = self.master_repl_clients.lock().await;
let master_rep_client = master_rep_client.as_mut().unwrap();
master_rep_client.send_rdb_file(&mut stream).await?;
master_rep_client.add_stream(stream).await?;
break;
}
}
println!("going to send response {}", res.encode());
_ = stream.write(res.encode().as_bytes()).await?;
} else {
println!("[handle] going to break");
break;
@@ -145,12 +65,4 @@ impl Server {
}
Ok(())
}
pub fn is_slave(&self) -> bool {
self.option.replication.role == "slave"
}
pub fn is_master(&self) -> bool {
!self.is_slave()
}
}

View File

@@ -1,13 +1,29 @@
use std::{
collections::HashMap,
path::Path,
time::{SystemTime, UNIX_EPOCH},
};
pub type ValueType = (String, Option<u128>);
use redb::{Database, Error, ReadableTable, Table, TableDefinition, WriteTransaction, ReadTransaction};
use serde::{Deserialize, Serialize};
pub struct Storage {
// key -> (value, (insert/update time, expire milli seconds))
set: HashMap<String, ValueType>,
use crate::error::DBError;
// Table definitions for different Redis data types
const TYPES_TABLE: TableDefinition<&str, &str> = TableDefinition::new("types");
const STRINGS_TABLE: TableDefinition<&str, &[u8]> = TableDefinition::new("strings");
const HASHES_TABLE: TableDefinition<(&str, &str), &str> = TableDefinition::new("hashes");
const STREAMS_META_TABLE: TableDefinition<&str, &[u8]> = TableDefinition::new("streams_meta");
const STREAMS_DATA_TABLE: TableDefinition<(&str, &str), &[u8]> = TableDefinition::new("streams_data");
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct StringValue {
pub value: String,
pub expires_at_ms: Option<u128>,
}
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct StreamEntry {
pub fields: Vec<(String, String)>,
}
#[inline]
@@ -17,43 +33,416 @@ pub fn now_in_millis() -> u128 {
duration_since_epoch.as_millis()
}
pub struct Storage {
db: Database,
}
impl Storage {
pub fn new() -> Self {
Storage {
set: HashMap::new(),
pub fn new(path: impl AsRef<Path>) -> Result<Self, DBError> {
let db = Database::create(path)?;
// Create tables if they don't exist
let write_txn = db.begin_write()?;
{
let _ = write_txn.open_table(TYPES_TABLE)?;
let _ = write_txn.open_table(STRINGS_TABLE)?;
let _ = write_txn.open_table(HASHES_TABLE)?;
let _ = write_txn.open_table(STREAMS_META_TABLE)?;
let _ = write_txn.open_table(STREAMS_DATA_TABLE)?;
}
write_txn.commit()?;
Ok(Storage { db })
}
pub fn get_key_type(&self, key: &str) -> Result<Option<String>, DBError> {
let read_txn = self.db.begin_read()?;
let table = read_txn.open_table(TYPES_TABLE)?;
match table.get(key)? {
Some(type_val) => Ok(Some(type_val.value().to_string())),
None => Ok(None),
}
}
pub fn get(self: &mut Self, k: &str) -> Option<String> {
match self.set.get(k) {
Some((ss, expire_timestamp)) => match expire_timestamp {
Some(expire_time_stamp) => {
if now_in_millis() > *expire_time_stamp {
self.set.remove(k);
None
} else {
Some(ss.clone())
pub fn get(&self, key: &str) -> Result<Option<String>, DBError> {
let read_txn = self.db.begin_read()?;
// Check if key exists and is of string type
let types_table = read_txn.open_table(TYPES_TABLE)?;
match types_table.get(key)? {
Some(type_val) if type_val.value() == "string" => {
let strings_table = read_txn.open_table(STRINGS_TABLE)?;
match strings_table.get(key)? {
Some(data) => {
let string_value: StringValue = bincode::deserialize(data.value())?;
// Check if expired
if let Some(expires_at) = string_value.expires_at_ms {
if now_in_millis() > expires_at {
// Key expired, remove it
drop(read_txn);
self.del(key.to_string())?;
return Ok(None);
}
}
Ok(Some(string_value.value))
}
None => Ok(None),
}
}
_ => Ok(None),
}
}
pub fn set(&self, key: String, value: String) -> Result<(), DBError> {
let write_txn = self.db.begin_write()?;
{
let mut types_table = write_txn.open_table(TYPES_TABLE)?;
types_table.insert(key.as_str(), "string")?;
let mut strings_table = write_txn.open_table(STRINGS_TABLE)?;
let string_value = StringValue {
value,
expires_at_ms: None,
};
let serialized = bincode::serialize(&string_value)?;
strings_table.insert(key.as_str(), serialized.as_slice())?;
}
write_txn.commit()?;
Ok(())
}
pub fn setx(&self, key: String, value: String, expire_ms: u128) -> Result<(), DBError> {
let write_txn = self.db.begin_write()?;
{
let mut types_table = write_txn.open_table(TYPES_TABLE)?;
types_table.insert(key.as_str(), "string")?;
let mut strings_table = write_txn.open_table(STRINGS_TABLE)?;
let string_value = StringValue {
value,
expires_at_ms: Some(expire_ms + now_in_millis()),
};
let serialized = bincode::serialize(&string_value)?;
strings_table.insert(key.as_str(), serialized.as_slice())?;
}
write_txn.commit()?;
Ok(())
}
pub fn del(&self, key: String) -> Result<(), DBError> {
let write_txn = self.db.begin_write()?;
{
let mut types_table = write_txn.open_table(TYPES_TABLE)?;
let mut strings_table = write_txn.open_table(STRINGS_TABLE)?;
let mut hashes_table = write_txn.open_table(HASHES_TABLE)?;
// Remove from type table
types_table.remove(key.as_str())?;
// Remove from strings table
strings_table.remove(key.as_str())?;
// Remove all hash fields for this key
let mut to_remove = Vec::new();
let mut iter = hashes_table.iter()?;
while let Some(entry) = iter.next() {
let entry = entry?;
let (hash_key, field) = entry.0.value();
if hash_key == key.as_str() {
to_remove.push((hash_key.to_string(), field.to_string()));
}
}
drop(iter);
for (hash_key, field) in to_remove {
hashes_table.remove((hash_key.as_str(), field.as_str()))?;
}
}
write_txn.commit()?;
Ok(())
}
pub fn keys(&self, pattern: &str) -> Result<Vec<String>, DBError> {
let read_txn = self.db.begin_read()?;
let table = read_txn.open_table(TYPES_TABLE)?;
let mut keys = Vec::new();
let mut iter = table.iter()?;
while let Some(entry) = iter.next() {
let key = entry?.0.value().to_string();
if pattern == "*" || key.contains(pattern) {
keys.push(key);
}
}
Ok(keys)
}
// Hash operations
pub fn hset(&self, key: &str, pairs: &[(String, String)]) -> Result<u64, DBError> {
let write_txn = self.db.begin_write()?;
let mut new_fields = 0u64;
{
let mut types_table = write_txn.open_table(TYPES_TABLE)?;
let mut hashes_table = write_txn.open_table(HASHES_TABLE)?;
// Check if key exists and is of correct type
let existing_type = match types_table.get(key)? {
Some(type_val) => Some(type_val.value().to_string()),
None => None,
};
match existing_type {
Some(ref type_str) if type_str != "hash" => {
return Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string()));
}
None => {
// Set type to hash
types_table.insert(key, "hash")?;
}
_ => {}
}
for (field, value) in pairs {
let existed = hashes_table.get((key, field.as_str()))?.is_some();
hashes_table.insert((key, field.as_str()), value.as_str())?;
if !existed {
new_fields += 1;
}
}
}
write_txn.commit()?;
Ok(new_fields)
}
pub fn hget(&self, key: &str, field: &str) -> Result<Option<String>, DBError> {
let read_txn = self.db.begin_read()?;
// Check type
let types_table = read_txn.open_table(TYPES_TABLE)?;
match types_table.get(key)? {
Some(type_val) if type_val.value() == "hash" => {
let hashes_table = read_txn.open_table(HASHES_TABLE)?;
match hashes_table.get((key, field))? {
Some(value) => Ok(Some(value.value().to_string())),
None => Ok(None),
}
}
Some(_) => Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())),
None => Ok(None),
}
}
pub fn hgetall(&self, key: &str) -> Result<Vec<(String, String)>, DBError> {
let read_txn = self.db.begin_read()?;
// Check type
let types_table = read_txn.open_table(TYPES_TABLE)?;
match types_table.get(key)? {
Some(type_val) if type_val.value() == "hash" => {
let hashes_table = read_txn.open_table(HASHES_TABLE)?;
let mut result = Vec::new();
let mut iter = hashes_table.iter()?;
while let Some(entry) = iter.next() {
let entry = entry?;
let (hash_key, field) = entry.0.value();
let value = entry.1.value();
if hash_key == key {
result.push((field.to_string(), value.to_string()));
}
}
_ => Some(ss.clone()),
},
_ => None,
Ok(result)
}
Some(_) => Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())),
None => Ok(Vec::new()),
}
}
pub fn set(self: &mut Self, k: String, v: String) {
self.set.insert(k, (v, None));
pub fn hdel(&self, key: &str, fields: &[String]) -> Result<u64, DBError> {
let write_txn = self.db.begin_write()?;
let mut deleted = 0u64;
{
let types_table = write_txn.open_table(TYPES_TABLE)?;
let key_type = types_table.get(key)?;
match key_type {
Some(type_val) if type_val.value() == "hash" => {
let mut hashes_table = write_txn.open_table(HASHES_TABLE)?;
for field in fields {
if hashes_table.remove((key, field.as_str()))?.is_some() {
deleted += 1;
}
}
}
Some(_) => return Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())),
None => {}
}
}
write_txn.commit()?;
Ok(deleted)
}
pub fn setx(self: &mut Self, k: String, v: String, expire_ms: u128) {
self.set.insert(k, (v, Some(expire_ms + now_in_millis())));
pub fn hexists(&self, key: &str, field: &str) -> Result<bool, DBError> {
let read_txn = self.db.begin_read()?;
let types_table = read_txn.open_table(TYPES_TABLE)?;
match types_table.get(key)? {
Some(type_val) if type_val.value() == "hash" => {
let hashes_table = read_txn.open_table(HASHES_TABLE)?;
Ok(hashes_table.get((key, field))?.is_some())
}
Some(_) => Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())),
None => Ok(false),
}
}
pub fn del(self: &mut Self, k: String) {
self.set.remove(&k);
pub fn hkeys(&self, key: &str) -> Result<Vec<String>, DBError> {
let read_txn = self.db.begin_read()?;
let types_table = read_txn.open_table(TYPES_TABLE)?;
match types_table.get(key)? {
Some(type_val) if type_val.value() == "hash" => {
let hashes_table = read_txn.open_table(HASHES_TABLE)?;
let mut result = Vec::new();
let mut iter = hashes_table.iter()?;
while let Some(entry) = iter.next() {
let entry = entry?;
let (hash_key, field) = entry.0.value();
if hash_key == key {
result.push(field.to_string());
}
}
Ok(result)
}
Some(_) => Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())),
None => Ok(Vec::new()),
}
}
pub fn keys(self: &Self) -> Vec<String> {
self.set.keys().map(|x| x.clone()).collect()
pub fn hvals(&self, key: &str) -> Result<Vec<String>, DBError> {
let read_txn = self.db.begin_read()?;
let types_table = read_txn.open_table(TYPES_TABLE)?;
match types_table.get(key)? {
Some(type_val) if type_val.value() == "hash" => {
let hashes_table = read_txn.open_table(HASHES_TABLE)?;
let mut result = Vec::new();
let mut iter = hashes_table.iter()?;
while let Some(entry) = iter.next() {
let entry = entry?;
let (hash_key, _) = entry.0.value();
let value = entry.1.value();
if hash_key == key {
result.push(value.to_string());
}
}
Ok(result)
}
Some(_) => Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())),
None => Ok(Vec::new()),
}
}
pub fn hlen(&self, key: &str) -> Result<u64, DBError> {
let read_txn = self.db.begin_read()?;
let types_table = read_txn.open_table(TYPES_TABLE)?;
match types_table.get(key)? {
Some(type_val) if type_val.value() == "hash" => {
let hashes_table = read_txn.open_table(HASHES_TABLE)?;
let mut count = 0u64;
let mut iter = hashes_table.iter()?;
while let Some(entry) = iter.next() {
let entry = entry?;
let (hash_key, _) = entry.0.value();
if hash_key == key {
count += 1;
}
}
Ok(count)
}
Some(_) => Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())),
None => Ok(0),
}
}
pub fn hmget(&self, key: &str, fields: &[String]) -> Result<Vec<Option<String>>, DBError> {
let read_txn = self.db.begin_read()?;
let types_table = read_txn.open_table(TYPES_TABLE)?;
match types_table.get(key)? {
Some(type_val) if type_val.value() == "hash" => {
let hashes_table = read_txn.open_table(HASHES_TABLE)?;
let mut result = Vec::new();
for field in fields {
match hashes_table.get((key, field.as_str()))? {
Some(value) => result.push(Some(value.value().to_string())),
None => result.push(None),
}
}
Ok(result)
}
Some(_) => Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())),
None => Ok(fields.iter().map(|_| None).collect()),
}
}
pub fn hsetnx(&self, key: &str, field: &str, value: &str) -> Result<bool, DBError> {
let write_txn = self.db.begin_write()?;
let mut result = false;
{
let mut types_table = write_txn.open_table(TYPES_TABLE)?;
let mut hashes_table = write_txn.open_table(HASHES_TABLE)?;
// Check if key exists and is of correct type
let existing_type = match types_table.get(key)? {
Some(type_val) => Some(type_val.value().to_string()),
None => None,
};
match existing_type {
Some(ref type_str) if type_str != "hash" => {
return Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string()));
}
None => {
// Set type to hash
types_table.insert(key, "hash")?;
}
_ => {}
}
// Check if field already exists
if hashes_table.get((key, field))?.is_none() {
hashes_table.insert((key, field), value)?;
result = true;
}
}
write_txn.commit()?;
Ok(result)
}
}

302
test_herodb.sh Executable file
View File

@@ -0,0 +1,302 @@
#!/bin/bash
# Test script for HeroDB - Redis-compatible database with redb backend
# This script starts the server and runs comprehensive tests
set -e
# Colors for output
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
BLUE='\033[0;34m'
NC='\033[0m' # No Color
# Configuration
DB_DIR="./test_db"
PORT=6379
SERVER_PID=""
# Function to print colored output
print_status() {
echo -e "${BLUE}[INFO]${NC} $1"
}
print_success() {
echo -e "${GREEN}[SUCCESS]${NC} $1"
}
print_error() {
echo -e "${RED}[ERROR]${NC} $1"
}
print_warning() {
echo -e "${YELLOW}[WARNING]${NC} $1"
}
# Function to cleanup on exit
cleanup() {
if [ ! -z "$SERVER_PID" ]; then
print_status "Stopping HeroDB server (PID: $SERVER_PID)..."
kill $SERVER_PID 2>/dev/null || true
wait $SERVER_PID 2>/dev/null || true
fi
# Clean up test database
if [ -d "$DB_DIR" ]; then
print_status "Cleaning up test database directory..."
rm -rf "$DB_DIR"
fi
}
# Set trap to cleanup on script exit
trap cleanup EXIT
# Function to wait for server to start
wait_for_server() {
local max_attempts=30
local attempt=1
print_status "Waiting for server to start on port $PORT..."
while [ $attempt -le $max_attempts ]; do
if nc -z localhost $PORT 2>/dev/null; then
print_success "Server is ready!"
return 0
fi
echo -n "."
sleep 1
attempt=$((attempt + 1))
done
print_error "Server failed to start within $max_attempts seconds"
return 1
}
# Function to send Redis command and get response
redis_cmd() {
local cmd="$1"
local expected="$2"
print_status "Testing: $cmd"
local result=$(echo "$cmd" | redis-cli -p $PORT --raw 2>/dev/null || echo "ERROR")
if [ "$expected" != "" ] && [ "$result" != "$expected" ]; then
print_error "Expected: '$expected', Got: '$result'"
return 1
else
print_success "$cmd -> $result"
return 0
fi
}
# Function to test basic string operations
test_string_operations() {
print_status "=== Testing String Operations ==="
redis_cmd "PING" "PONG"
redis_cmd "SET mykey hello" "OK"
redis_cmd "GET mykey" "hello"
redis_cmd "SET counter 1" "OK"
redis_cmd "INCR counter" "2"
redis_cmd "INCR counter" "3"
redis_cmd "GET counter" "3"
redis_cmd "DEL mykey" "1"
redis_cmd "GET mykey" ""
redis_cmd "TYPE counter" "string"
redis_cmd "TYPE nonexistent" "none"
}
# Function to test hash operations
test_hash_operations() {
print_status "=== Testing Hash Operations ==="
# HSET and HGET
redis_cmd "HSET user:1 name John" "1"
redis_cmd "HSET user:1 age 30 city NYC" "2"
redis_cmd "HGET user:1 name" "John"
redis_cmd "HGET user:1 age" "30"
redis_cmd "HGET user:1 nonexistent" ""
# HGETALL
print_status "Testing HGETALL user:1"
redis_cmd "HGETALL user:1" ""
# HEXISTS
redis_cmd "HEXISTS user:1 name" "1"
redis_cmd "HEXISTS user:1 nonexistent" "0"
# HKEYS
print_status "Testing HKEYS user:1"
redis_cmd "HKEYS user:1" ""
# HVALS
print_status "Testing HVALS user:1"
redis_cmd "HVALS user:1" ""
# HLEN
redis_cmd "HLEN user:1" "3"
# HMGET
print_status "Testing HMGET user:1 name age"
redis_cmd "HMGET user:1 name age" ""
# HSETNX
redis_cmd "HSETNX user:1 name Jane" "0" # Should not set, field exists
redis_cmd "HSETNX user:1 email john@example.com" "1" # Should set, new field
redis_cmd "HGET user:1 email" "john@example.com"
# HDEL
redis_cmd "HDEL user:1 age city" "2"
redis_cmd "HLEN user:1" "2"
redis_cmd "HEXISTS user:1 age" "0"
# Test type checking
redis_cmd "SET stringkey value" "OK"
print_status "Testing WRONGTYPE error on string key"
redis_cmd "HGET stringkey field" "" # Should return WRONGTYPE error
}
# Function to test configuration commands
test_config_operations() {
print_status "=== Testing Configuration Operations ==="
print_status "Testing CONFIG GET dir"
redis_cmd "CONFIG GET dir" ""
print_status "Testing CONFIG GET dbfilename"
redis_cmd "CONFIG GET dbfilename" ""
}
# Function to test transaction operations
test_transaction_operations() {
print_status "=== Testing Transaction Operations ==="
redis_cmd "MULTI" "OK"
redis_cmd "SET tx_key1 value1" "QUEUED"
redis_cmd "SET tx_key2 value2" "QUEUED"
redis_cmd "INCR counter" "QUEUED"
print_status "Testing EXEC"
redis_cmd "EXEC" ""
redis_cmd "GET tx_key1" "value1"
redis_cmd "GET tx_key2" "value2"
# Test DISCARD
redis_cmd "MULTI" "OK"
redis_cmd "SET discard_key value" "QUEUED"
redis_cmd "DISCARD" "OK"
redis_cmd "GET discard_key" ""
}
# Function to test keys operations
test_keys_operations() {
print_status "=== Testing Keys Operations ==="
print_status "Testing KEYS *"
redis_cmd "KEYS *" ""
}
# Function to test info operations
test_info_operations() {
print_status "=== Testing Info Operations ==="
print_status "Testing INFO"
redis_cmd "INFO" ""
print_status "Testing INFO replication"
redis_cmd "INFO replication" ""
}
# Function to test expiration
test_expiration() {
print_status "=== Testing Expiration ==="
redis_cmd "SET expire_key value" "OK"
redis_cmd "SET expire_px_key value PX 1000" "OK" # 1 second
redis_cmd "SET expire_ex_key value EX 1" "OK" # 1 second
redis_cmd "GET expire_key" "value"
redis_cmd "GET expire_px_key" "value"
redis_cmd "GET expire_ex_key" "value"
print_status "Waiting 2 seconds for expiration..."
sleep 2
redis_cmd "GET expire_key" "value" # Should still exist
redis_cmd "GET expire_px_key" "" # Should be expired
redis_cmd "GET expire_ex_key" "" # Should be expired
}
# Main execution
main() {
print_status "Starting HeroDB comprehensive test suite..."
# Build the project
print_status "Building HeroDB..."
if ! cargo build --release; then
print_error "Failed to build HeroDB"
exit 1
fi
# Create test database directory
mkdir -p "$DB_DIR"
# Start the server
print_status "Starting HeroDB server..."
./target/release/redis-rs --dir "$DB_DIR" --port $PORT &
SERVER_PID=$!
# Wait for server to start
if ! wait_for_server; then
print_error "Failed to start server"
exit 1
fi
# Run tests
local failed_tests=0
test_string_operations || failed_tests=$((failed_tests + 1))
test_hash_operations || failed_tests=$((failed_tests + 1))
test_config_operations || failed_tests=$((failed_tests + 1))
test_transaction_operations || failed_tests=$((failed_tests + 1))
test_keys_operations || failed_tests=$((failed_tests + 1))
test_info_operations || failed_tests=$((failed_tests + 1))
test_expiration || failed_tests=$((failed_tests + 1))
# Summary
echo
print_status "=== Test Summary ==="
if [ $failed_tests -eq 0 ]; then
print_success "All tests completed! Some may have warnings due to protocol differences."
print_success "HeroDB is working with persistent redb storage!"
else
print_warning "$failed_tests test categories had issues"
print_warning "Check the output above for details"
fi
print_status "Database file created at: $DB_DIR/herodb.redb"
print_status "Server logs and any errors are shown above"
}
# Check dependencies
check_dependencies() {
if ! command -v cargo &> /dev/null; then
print_error "cargo is required but not installed"
exit 1
fi
if ! command -v nc &> /dev/null; then
print_warning "netcat (nc) not found - some tests may not work properly"
fi
if ! command -v redis-cli &> /dev/null; then
print_warning "redis-cli not found - using netcat fallback"
fi
}
# Run dependency check and main function
check_dependencies
main "$@"