This commit is contained in:
2025-08-16 06:58:04 +02:00
parent 31c47d7998
commit cd61406d1d
20 changed files with 3241 additions and 1 deletions

621
src/cmd.rs Normal file
View File

@@ -0,0 +1,621 @@
use std::{collections::BTreeMap, ops::Bound, time::Duration, u64};
use tokio::sync::mpsc;
use crate::{error::DBError, protocol::Protocol, server::Server, storage::now_in_millis};
/// A client command after parsing, carrying its decoded arguments.
///
/// Values are produced by `Cmd::from` and executed by `Cmd::run`.
#[derive(Debug, Clone)]
pub enum Cmd {
/// `PING`; answered with `PONG`.
Ping,
/// `ECHO <message>`.
Echo(String),
/// `GET <key>`.
Get(String),
/// `SET <key> <value>` without an expiry.
Set(String, String),
/// `SET <key> <value> px <n>`; the third field is the parsed `px` argument.
SetPx(String, String, u128),
/// `SET <key> <value> ex <n>`; the third field is the parsed `ex` argument
/// (it is multiplied by 1000 before being handed to storage).
SetEx(String, String, u128),
/// `KEYS *`; the parser accepts no other pattern.
Keys,
/// `CONFIG GET <name>`.
ConfigGet(String),
/// `INFO [section]`.
Info(Option<String>),
/// `DEL <key>`.
Del(String),
/// `REPLCONF <sub-command> ...`; only the sub-command name is kept.
Replconf(String),
/// `PSYNC`; the parser requires two arguments but does not store them.
Psync,
/// `TYPE <key>`.
Type(String),
/// `XADD <stream key> <entry id> <field> <value> ...`.
Xadd(String, String, Vec<(String, String)>),
/// `XRANGE <stream key> <start> <end>`.
Xrange(String, String, String),
/// `XREAD`: the stream keys, the start id for each key, and the optional `block` argument.
Xread(Vec<String>, Vec<String>, Option<u64>),
/// `INCR <key>`.
Incr(String),
/// `MULTI`: starts queueing commands.
Multi,
/// `EXEC`: runs the queued commands.
Exec,
/// Any command name the parser does not recognise.
Unknow,
/// `DISCARD`: drops the queued commands.
Discard,
}
impl Cmd {
pub fn from(s: &str) -> Result<(Self, Protocol), DBError> {
let protocol = Protocol::from(s)?;
match protocol.clone().0 {
Protocol::Array(p) => {
let cmd = p.into_iter().map(|x| x.decode()).collect::<Vec<_>>();
if cmd.is_empty() {
return Err(DBError("cmd length is 0".to_string()));
}
Ok((
match cmd[0].as_str() {
"echo" => Cmd::Echo(cmd[1].clone()),
"ping" => Cmd::Ping,
"get" => Cmd::Get(cmd[1].clone()),
"set" => {
if cmd.len() == 5 && cmd[3] == "px" {
Cmd::SetPx(cmd[1].clone(), cmd[2].clone(), cmd[4].parse().unwrap())
} else if cmd.len() == 5 && cmd[3] == "ex" {
Cmd::SetEx(cmd[1].clone(), cmd[2].clone(), cmd[4].parse().unwrap())
} else if cmd.len() == 3 {
Cmd::Set(cmd[1].clone(), cmd[2].clone())
} else {
return Err(DBError(format!("unsupported cmd {:?}", cmd)));
}
}
"config" => {
if cmd.len() != 3 || cmd[1] != "get" {
return Err(DBError(format!("unsupported cmd {:?}", cmd)));
} else {
Cmd::ConfigGet(cmd[2].clone())
}
}
"keys" => {
if cmd.len() != 2 || cmd[1] != "*" {
return Err(DBError(format!("unsupported cmd {:?}", cmd)));
} else {
Cmd::Keys
}
}
"info" => {
let section = if cmd.len() == 2 {
Some(cmd[1].clone())
} else {
None
};
Cmd::Info(section)
}
"replconf" => {
if cmd.len() < 3 {
return Err(DBError(format!("unsupported cmd {:?}", cmd)));
}
Cmd::Replconf(cmd[1].clone())
}
"psync" => {
if cmd.len() != 3 {
return Err(DBError(format!("unsupported cmd {:?}", cmd)));
}
Cmd::Psync
}
"del" => {
if cmd.len() != 2 {
return Err(DBError(format!("unsupported cmd {:?}", cmd)));
}
Cmd::Del(cmd[1].clone())
}
"type" => {
if cmd.len() != 2 {
return Err(DBError(format!("unsupported cmd {:?}", cmd)));
}
Cmd::Type(cmd[1].clone())
}
"xadd" => {
if cmd.len() < 5 {
return Err(DBError(format!("unsupported cmd {:?}", cmd)));
}
let mut key_value = Vec::<(String, String)>::new();
let mut i = 3;
while i < cmd.len() - 1 {
key_value.push((cmd[i].clone(), cmd[i + 1].clone()));
i += 2;
}
Cmd::Xadd(cmd[1].clone(), cmd[2].clone(), key_value)
}
"xrange" => {
if cmd.len() != 4 {
return Err(DBError(format!("unsupported cmd {:?}", cmd)));
}
Cmd::Xrange(cmd[1].clone(), cmd[2].clone(), cmd[3].clone())
}
"xread" => {
if cmd.len() < 4 || cmd.len() % 2 != 0 {
return Err(DBError(format!("unsupported cmd {:?}", cmd)));
}
let mut offset = 2;
// block cmd
let mut block = None;
if cmd[1] == "block" {
offset += 2;
if let Ok(block_time) = cmd[2].parse() {
block = Some(block_time);
} else {
return Err(DBError(format!("unsupported cmd {:?}", cmd)));
}
}
let cmd2 = &cmd[offset..];
let len2 = cmd2.len() / 2;
Cmd::Xread(cmd2[0..len2].to_vec(), cmd2[len2..].to_vec(), block)
}
"incr" => {
if cmd.len() != 2 {
return Err(DBError(format!("unsupported cmd {:?}", cmd)));
}
Cmd::Incr(cmd[1].clone())
}
"multi" => {
if cmd.len() != 1 {
return Err(DBError(format!("unsupported cmd {:?}", cmd)));
}
Cmd::Multi
}
"exec" => {
if cmd.len() != 1 {
return Err(DBError(format!("unsupported cmd {:?}", cmd)));
}
Cmd::Exec
}
"discard" => Cmd::Discard,
_ => Cmd::Unknow,
},
protocol.0,
))
}
_ => Err(DBError(format!(
"fail to parse as cmd for {:?}",
protocol.0
))),
}
}
/// Executes this command against `server` and returns the reply frame.
///
/// * `protocol` - the request frame this command was parsed from; forwarded to the
///   write handlers and used to measure the request's encoded length.
/// * `is_rep_con` - forwarded unchanged to the write handlers.
/// * `queued_cmd` - `Some` while a MULTI transaction is open on this connection.
///
/// While a transaction is open, every command except MULTI/EXEC/DISCARD is appended
/// to the queue and answered with `QUEUED`. Whenever a handler returns `Ok`, the
/// encoded length of `protocol` is added to `server.offset`.
pub async fn run(
    &self,
    server: &mut Server,
    protocol: Protocol,
    is_rep_con: bool,
    queued_cmd: &mut Option<Vec<(Cmd, Protocol)>>,
) -> Result<Protocol, DBError> {
    if !matches!(self, Cmd::Exec | Cmd::Multi | Cmd::Discard) {
        if let Some(queue) = queued_cmd.as_mut() {
            queue.push((self.clone(), protocol));
            return Ok(Protocol::SimpleString("QUEUED".to_string()));
        }
    }

    // Measured up front because the frame itself is moved into the handlers below.
    let request_len = protocol.encode().len() as u64;

    let outcome = match self {
        Cmd::Ping => Ok(Protocol::SimpleString("PONG".to_string())),
        Cmd::Echo(message) => Ok(Protocol::SimpleString(message.clone())),
        Cmd::Get(key) => get_cmd(server, key).await,
        Cmd::Set(key, value) => set_cmd(server, key, value, protocol, is_rep_con).await,
        Cmd::SetPx(key, value, px) => {
            set_px_cmd(server, key, value, px, protocol, is_rep_con).await
        }
        Cmd::SetEx(key, value, ex) => {
            set_ex_cmd(server, key, value, ex, protocol, is_rep_con).await
        }
        Cmd::Del(key) => del_cmd(server, key, protocol, is_rep_con).await,
        Cmd::ConfigGet(name) => config_get_cmd(name, server),
        Cmd::Keys => keys_cmd(server).await,
        Cmd::Info(section) => info_cmd(section, server),
        Cmd::Replconf(sub_cmd) => replconf_cmd(sub_cmd, server),
        Cmd::Psync => psync_cmd(server),
        Cmd::Type(key) => type_cmd(server, key).await,
        Cmd::Xadd(stream_key, entry_id, kvps) => {
            xadd_cmd(
                entry_id.as_str(),
                server,
                stream_key.as_str(),
                kvps,
                protocol,
                is_rep_con,
            )
            .await
        }
        Cmd::Xrange(stream_key, start, end) => xrange_cmd(server, stream_key, start, end).await,
        Cmd::Xread(stream_keys, starts, block) => {
            xread_cmd(starts, server, stream_keys, block).await
        }
        Cmd::Incr(key) => incr_cmd(server, key).await,
        Cmd::Multi => {
            *queued_cmd = Some(Vec::new());
            Ok(Protocol::SimpleString("ok".to_string()))
        }
        Cmd::Exec => exec_cmd(queued_cmd, server, is_rep_con).await,
        Cmd::Discard => match queued_cmd.take() {
            Some(_) => Ok(Protocol::SimpleString("ok".to_string())),
            None => Ok(Protocol::err("ERR Discard without MULTI")),
        },
        Cmd::Unknow => Ok(Protocol::err("unknow cmd")),
    };

    if outcome.is_ok() {
        server
            .offset
            .fetch_add(request_len, std::sync::atomic::Ordering::Relaxed);
    }
    outcome
}
}
/// Runs every command queued since MULTI and returns their replies as one array.
///
/// Replies with `ERR EXEC without MULTI` when no transaction is open.
///
/// # Errors
/// Propagates the first `DBError` returned by a queued command. The transaction is
/// closed in that case too, so the connection does not stay stuck in queueing mode.
async fn exec_cmd(
    queued_cmd: &mut Option<Vec<(Cmd, Protocol)>>,
    server: &mut Server,
    is_rep_con: bool,
) -> Result<Protocol, DBError> {
    // `take` ends the transaction before anything runs, whatever the outcome.
    let Some(queue) = queued_cmd.take() else {
        return Ok(Protocol::err("ERR EXEC without MULTI"));
    };
    let mut replies = Vec::with_capacity(queue.len());
    for (cmd, protocol) in queue {
        // `run` calls back into `exec_cmd`, so the recursive future must be boxed.
        // A fresh `None` queue makes the command execute instead of being re-queued.
        let reply = Box::pin(cmd.run(server, protocol, is_rep_con, &mut None)).await?;
        replies.push(reply);
    }
    Ok(Protocol::Array(replies))
}
/// Increments the integer stored at `key` by one and replies with the new value.
///
/// A missing key is treated as `0`, so the first INCR on a new key stores and
/// returns `1`. If the stored value is not an unsigned integer, or incrementing it
/// would overflow `u64`, nothing is written and an error reply is returned.
async fn incr_cmd(server: &mut Server, key: &String) -> Result<Protocol, DBError> {
    let mut storage = server.storage.lock().await;
    let current = storage.get(key).unwrap_or_else(|| "0".to_string());
    // `checked_add` folds the overflow case into the same error reply as a parse failure.
    match current.parse::<u64>().ok().and_then(|n| n.checked_add(1)) {
        Some(next) => {
            let next = next.to_string();
            storage.set(key.clone(), next.clone());
            Ok(Protocol::SimpleString(next))
        }
        None => Ok(Protocol::err("ERR value is not an integer or out of range")),
    }
}
/// Answers `CONFIG GET <name>` with a two-element array: the name and its value.
///
/// Only `dir` and `dbfilename` are known; any other name yields a `DBError`.
fn config_get_cmd(name: &String, server: &mut Server) -> Result<Protocol, DBError> {
    let value = match name.as_str() {
        "dir" => &server.option.dir,
        "dbfilename" => &server.option.db_file_name,
        _ => return Err(DBError(format!("unsupported config {:?}", name))),
    };
    let reply = vec![
        Protocol::BulkString(name.clone()),
        Protocol::BulkString(value.clone()),
    ];
    Ok(Protocol::Array(reply))
}
/// Replies with every key reported by the storage, each as a bulk string.
async fn keys_cmd(server: &mut Server) -> Result<Protocol, DBError> {
    // The lock guard is a temporary of this statement, so it is released right away.
    let names = server.storage.lock().await.keys();
    let mut reply = Vec::new();
    for name in names {
        reply.push(Protocol::BulkString(name));
    }
    Ok(Protocol::Array(reply))
}
/// Answers `INFO [section]`.
///
/// Without a section the reply is the bulk string `default`. The only supported
/// section is `replication`, which reports role, replication id and offset from the
/// server options; any other section yields a `DBError`.
fn info_cmd(section: &Option<String>, server: &mut Server) -> Result<Protocol, DBError> {
    let Some(name) = section else {
        return Ok(Protocol::BulkString("default".to_string()));
    };
    if name.as_str() != "replication" {
        return Err(DBError(format!("unsupported section {:?}", name)));
    }
    let replication = &server.option.replication;
    let report = format!(
        "role:{}\nmaster_replid:{}\nmaster_repl_offset:{}\n",
        replication.role, replication.master_replid, replication.master_repl_offset
    );
    Ok(Protocol::BulkString(report))
}
/// Answers `XREAD`: for each stream key, returns the entries whose id is strictly
/// after the matching start id.
///
/// * `block_millis` - `Some(0)` waits until an XADD notifies this reader,
///   `Some(n)` sleeps `n` milliseconds before reading, `None` reads immediately.
///
/// Keys that have no stream are omitted from the reply.
///
/// # Panics
/// Panics if a start id is not of the form `<number>` or `<number>-<number>`
/// (see `split_offset`).
async fn xread_cmd(
    starts: &[String],
    server: &mut Server,
    stream_keys: &[String],
    block_millis: &Option<u64>,
) -> Result<Protocol, DBError> {
    match *block_millis {
        Some(0) => {
            let (sender, mut receiver) = mpsc::channel(4);
            // Move the only sender into the registry. Keeping a second handle alive
            // here would stop `recv` from ever observing a closed channel.
            server.stream_reader_blocker.lock().await.push(sender);
            // One wake-up is enough: either XADD sent `()` or it dropped the sender.
            let _ = receiver.recv().await;
            println!("get new xadd cmd, release block");
        }
        Some(millis) => tokio::time::sleep(Duration::from_millis(millis)).await,
        None => {}
    }

    let streams = server.streams.lock().await;
    let mut ret = Vec::new();
    // `zip` pairs each key with its start id without indexing into `starts`.
    for (stream_key, start_id) in stream_keys.iter().zip(starts) {
        let Some(s) = streams.get(stream_key) else {
            continue;
        };
        // Exclusive lower bound: start at the sequence number right after the given id.
        let (offset_id, offset_seq, _) = split_offset(start_id.as_str());
        let start = format!("{}-{}", offset_id, offset_seq + 1);
        let end = format!("{}-{}", u64::MAX - 1, 0);
        let range = s.range::<String, _>((Bound::Included(&start), Bound::Included(&end)));
        let mut array = Vec::new();
        for (k, v) in range {
            array.push(Protocol::BulkString(k.clone()));
            array.push(Protocol::from_vec(
                v.iter()
                    .flat_map(|(a, b)| [a.as_str(), b.as_str()])
                    .collect(),
            ));
        }
        ret.push(Protocol::BulkString(stream_key.clone()));
        ret.push(Protocol::Array(array));
    }
    Ok(Protocol::Array(ret))
}
/// Answers `REPLCONF`.
///
/// `getack` is answered with `REPLCONF ACK <offset>` using the current value of
/// `server.offset`; every other sub-command is acknowledged with `OK`.
fn replconf_cmd(sub_cmd: &str, server: &mut Server) -> Result<Protocol, DBError> {
    if sub_cmd != "getack" {
        return Ok(Protocol::SimpleString("OK".to_string()));
    }
    let acked = server
        .offset
        .load(std::sync::atomic::Ordering::Relaxed)
        .to_string();
    Ok(Protocol::from_vec(vec!["REPLCONF", "ACK", acked.as_str()]))
}
/// Answers `XRANGE <stream key> <start> <end>` with the entries in the inclusive range.
///
/// `-` as start is replaced by `"0"` and `+` as end by `u64::MAX` rendered as text.
/// A key without a stream is answered with `Protocol::none()`. Each entry contributes
/// its id followed by an array of its flattened field/value pairs.
async fn xrange_cmd(
    server: &mut Server,
    stream_key: &String,
    start: &String,
    end: &String,
) -> Result<Protocol, DBError> {
    let streams = server.streams.lock().await;
    let Some(entries) = streams.get(stream_key) else {
        return Ok(Protocol::none());
    };
    let lower = match start.as_str() {
        "-" => "0".to_string(),
        _ => start.clone(),
    };
    let upper = match end.as_str() {
        "+" => u64::MAX.to_string(),
        _ => end.clone(),
    };
    let mut array = Vec::new();
    for (id, fields) in
        entries.range::<String, _>((Bound::Included(&lower), Bound::Included(&upper)))
    {
        array.push(Protocol::BulkString(id.clone()));
        let mut flattened = Vec::new();
        for (field, value) in fields.iter() {
            flattened.push(field.as_str());
            flattened.push(value.as_str());
        }
        array.push(Protocol::from_vec(flattened));
    }
    println!("after xrange: {:?}", array);
    Ok(Protocol::Array(array))
}
/// Answers `XADD`: appends the field/value pairs to the stream under a new entry id
/// and replies with the id that was actually stored.
///
/// * `offset` - the requested id: `*`, `<ms>-*` or `<ms>-<seq>`.
///
/// `*` is expanded to the current time in milliseconds plus an auto-generated
/// sequence number. An id of `0-0`, or one that is not greater than the stream's last
/// id, is answered with an error reply. After a successful append, every reader
/// registered by a blocking XREAD is notified and the request is handed to
/// `resp_and_replicate`.
///
/// # Panics
/// Panics if `offset` is not numeric where a number is required (see `split_offset`).
async fn xadd_cmd(
    offset: &str,
    server: &mut Server,
    stream_key: &str,
    kvps: &Vec<(String, String)>,
    protocol: Protocol,
    is_rep_con: bool,
) -> Result<Protocol, DBError> {
    // Reject client writes on a replica before mutating anything. Without this the
    // entry would be stored and only the reply would say the write was refused.
    if !server.is_master() && !is_rep_con {
        return Ok(Protocol::write_on_slave_err());
    }

    let requested = if offset == "*" {
        format!("{}-*", now_in_millis() as u64)
    } else {
        offset.to_string()
    };
    let (offset_id, mut offset_seq, has_wildcard) = split_offset(requested.as_str());
    if offset_id == 0 && offset_seq == 0 && !has_wildcard {
        return Ok(Protocol::err(
            "ERR The ID specified in XADD must be greater than 0-0",
        ));
    }

    // The id is resolved under the lock and carried out of the scope, so the reply
    // reports the stored id rather than the wildcard form the client sent.
    let entry_id = {
        let mut streams = server.streams.lock().await;
        let stream = streams
            .entry(stream_key.to_string())
            .or_insert_with(BTreeMap::new);
        if let Some((last_offset, _)) = stream.last_key_value() {
            let (last_offset_id, last_offset_seq, _) = split_offset(last_offset.as_str());
            if last_offset_id > offset_id
                || (last_offset_id == offset_id && last_offset_seq >= offset_seq && !has_wildcard)
            {
                return Ok(Protocol::err("ERR The ID specified in XADD is equal or smaller than the target stream top item"));
            }
            if last_offset_id == offset_id && last_offset_seq >= offset_seq && has_wildcard {
                offset_seq = last_offset_seq + 1;
            }
        }
        let entry_id = format!("{}-{}", offset_id, offset_seq);
        stream
            .entry(entry_id.clone())
            .or_insert_with(Vec::new)
            .extend(kvps.iter().cloned());
        entry_id
    };

    {
        let mut blocker = server.stream_reader_blocker.lock().await;
        for sender in blocker.drain(..) {
            // A reader whose connection is gone has dropped its receiver; that must
            // not turn an already-applied write into an error.
            let _ = sender.send(()).await;
        }
    }

    resp_and_replicate(
        server,
        Protocol::BulkString(entry_id),
        protocol,
        is_rep_con,
    )
    .await
}
/// Answers `TYPE <key>`: `string` if the key is in storage, `stream` if it names a
/// stream, otherwise `Protocol::none()`.
async fn type_cmd(server: &mut Server, k: &String) -> Result<Protocol, DBError> {
    let stored = server.storage.lock().await.get(k);
    if stored.is_some() {
        return Ok(Protocol::SimpleString("string".to_string()));
    }
    let streams = server.streams.lock().await;
    let reply = match streams.get(k) {
        Some(_) => Protocol::SimpleString("stream".to_string()),
        None => Protocol::none(),
    };
    Ok(reply)
}
/// Answers `PSYNC`: a master replies `FULLRESYNC <replid> 0`, any other role gets the
/// "PSYNC on slave" error reply.
fn psync_cmd(server: &mut Server) -> Result<Protocol, DBError> {
    if !server.is_master() {
        return Ok(Protocol::psync_on_slave_err());
    }
    let reply = format!(
        "FULLRESYNC {} 0",
        server.option.replication.master_replid
    );
    Ok(Protocol::SimpleString(reply))
}
/// Answers `DEL <key>`: removes the key, then replies and replicates through
/// `resp_and_replicate`.
///
/// A write arriving on a replica from a normal client (`is_rep_con == false`) is
/// rejected before storage is touched.
async fn del_cmd(
    server: &mut Server,
    k: &str,
    protocol: Protocol,
    is_rep_con: bool,
) -> Result<Protocol, DBError> {
    // Without this guard the key would be deleted and only the reply would say no.
    if !server.is_master() && !is_rep_con {
        return Ok(Protocol::write_on_slave_err());
    }
    server.storage.lock().await.del(k.to_string());
    // NOTE(review): `Cmd::run` also adds the request's encoded length to the same
    // counter; confirm this extra +1 per write is intended.
    server
        .offset
        .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
    resp_and_replicate(server, Protocol::ok(), protocol, is_rep_con).await
}
/// Answers `SET <key> <value> ex <n>`: stores the value with an expiry of `x * 1000`,
/// then replies and replicates through `resp_and_replicate`.
///
/// A write arriving on a replica from a normal client (`is_rep_con == false`) is
/// rejected before storage is touched.
async fn set_ex_cmd(
    server: &mut Server,
    k: &str,
    v: &str,
    x: &u128,
    protocol: Protocol,
    is_rep_con: bool,
) -> Result<Protocol, DBError> {
    // Without this guard the value would be stored and only the reply would say no.
    if !server.is_master() && !is_rep_con {
        return Ok(Protocol::write_on_slave_err());
    }
    server
        .storage
        .lock()
        .await
        .setx(k.to_string(), v.to_string(), *x * 1000);
    // NOTE(review): `Cmd::run` also adds the request's encoded length to the same
    // counter; confirm this extra +1 per write is intended.
    server
        .offset
        .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
    resp_and_replicate(server, Protocol::ok(), protocol, is_rep_con).await
}
/// Answers `SET <key> <value> px <n>`: stores the value with an expiry of `x`, then
/// replies and replicates through `resp_and_replicate`.
///
/// A write arriving on a replica from a normal client (`is_rep_con == false`) is
/// rejected before storage is touched.
async fn set_px_cmd(
    server: &mut Server,
    k: &str,
    v: &str,
    x: &u128,
    protocol: Protocol,
    is_rep_con: bool,
) -> Result<Protocol, DBError> {
    // Without this guard the value would be stored and only the reply would say no.
    if !server.is_master() && !is_rep_con {
        return Ok(Protocol::write_on_slave_err());
    }
    server
        .storage
        .lock()
        .await
        .setx(k.to_string(), v.to_string(), *x);
    // NOTE(review): `Cmd::run` also adds the request's encoded length to the same
    // counter; confirm this extra +1 per write is intended.
    server
        .offset
        .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
    resp_and_replicate(server, Protocol::ok(), protocol, is_rep_con).await
}
/// Answers `SET <key> <value>`: stores the value without an expiry, then replies and
/// replicates through `resp_and_replicate`.
///
/// A write arriving on a replica from a normal client (`is_rep_con == false`) is
/// rejected before storage is touched.
async fn set_cmd(
    server: &mut Server,
    k: &str,
    v: &str,
    protocol: Protocol,
    is_rep_con: bool,
) -> Result<Protocol, DBError> {
    // Without this guard the value would be stored and only the reply would say no.
    if !server.is_master() && !is_rep_con {
        return Ok(Protocol::write_on_slave_err());
    }
    server
        .storage
        .lock()
        .await
        .set(k.to_string(), v.to_string());
    // NOTE(review): `Cmd::run` also adds the request's encoded length to the same
    // counter; confirm this extra +1 per write is intended.
    server
        .offset
        .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
    resp_and_replicate(server, Protocol::ok(), protocol, is_rep_con).await
}
/// Answers `GET <key>`: the stored value as a simple string, or `Protocol::Null`
/// when storage has no value for the key.
// NOTE(review): RESP clients normally expect GET to return a bulk string; confirm
// that the simple-string reply is intended.
async fn get_cmd(server: &mut Server, k: &str) -> Result<Protocol, DBError> {
    // The guard is a temporary of this statement, so the lock is released right away.
    let found = server.storage.lock().await.get(k);
    let reply = match found {
        Some(value) => Protocol::SimpleString(value),
        None => Protocol::Null,
    };
    Ok(reply)
}
/// Chooses the reply for a write command and, on a master, forwards the request to
/// the replicas.
///
/// * master: `replication` is sent through the master replication client, then
///   `resp` is returned;
/// * replica, request from a normal client: the "write on slave" error reply;
/// * replica, request from the replication connection: `resp`.
///
/// # Panics
/// Panics if the server is a master but holds no master replication client.
async fn resp_and_replicate(
    server: &mut Server,
    resp: Protocol,
    replication: Protocol,
    is_rep_con: bool,
) -> Result<Protocol, DBError> {
    if !server.is_master() {
        let reply = if is_rep_con {
            resp
        } else {
            Protocol::write_on_slave_err()
        };
        return Ok(reply);
    }
    let mut clients = server.master_repl_clients.lock().await;
    clients.as_mut().unwrap().send_command(replication).await?;
    Ok(resp)
}
/// Splits a stream entry id into `(milliseconds, sequence, has_wildcard)`.
///
/// * `"<ms>-<seq>"` yields `(ms, seq, false)`;
/// * `"<ms>-*"` or a bare `"<ms>"` yields `(ms, default_seq, true)`, where the default
///   sequence is `1` when `ms == 0` and `0` otherwise.
///
/// Anything after a second `-` is ignored.
///
/// # Panics
/// Panics when the millisecond part, or an explicit sequence part, is not a valid `u64`.
fn split_offset(offset: &str) -> (u64, u64, bool) {
    let mut parts = offset.split('-');
    // `split` always yields at least one (possibly empty) piece.
    let id_part = parts.next().unwrap_or("");
    // The panic message is built only on failure, not on every call.
    let offset_id = id_part.parse::<u64>().unwrap_or_else(|e| {
        panic!(
            "ERR The ID specified in XADD must be a number: {}: {:?}",
            offset, e
        )
    });
    match parts.next() {
        None | Some("*") => (offset_id, if offset_id == 0 { 1 } else { 0 }, true),
        Some(seq) => (offset_id, seq.parse::<u64>().unwrap(), false),
    }
}

44
src/error.rs Normal file
View File

@@ -0,0 +1,44 @@
use std::num::ParseIntError;
use tokio::sync::mpsc;
use crate::protocol::Protocol;
// todo: more error types
/// Crate-wide error type: a plain message string.
///
/// The `From` impls in this file store the source error's `Display` text, so `?` can
/// be used on I/O, integer-parsing, UTF-8 and channel-send failures.
#[derive(Debug)]
pub struct DBError(pub String);
impl From<std::io::Error> for DBError {
    /// Keeps the I/O error's `Display` text as the message.
    fn from(item: std::io::Error) -> Self {
        // `to_string` already returns an owned `String`; no extra clone is needed.
        DBError(item.to_string())
    }
}
impl From<ParseIntError> for DBError {
    /// Keeps the integer-parse error's `Display` text as the message.
    fn from(item: ParseIntError) -> Self {
        // `to_string` already returns an owned `String`; no extra clone is needed.
        DBError(item.to_string())
    }
}
impl From<std::str::Utf8Error> for DBError {
    /// Keeps the UTF-8 validation error's `Display` text as the message.
    fn from(item: std::str::Utf8Error) -> Self {
        // `to_string` already returns an owned `String`; no extra clone is needed.
        DBError(item.to_string())
    }
}
impl From<std::string::FromUtf8Error> for DBError {
    /// Keeps the byte-to-`String` conversion error's `Display` text as the message.
    fn from(item: std::string::FromUtf8Error) -> Self {
        // `to_string` already returns an owned `String`; no extra clone is needed.
        DBError(item.to_string())
    }
}
impl From<mpsc::error::SendError<(Protocol, u64)>> for DBError {
    /// Keeps the channel-send error's `Display` text as the message.
    fn from(item: mpsc::error::SendError<(Protocol, u64)>) -> Self {
        // `to_string` already returns an owned `String`; no extra clone is needed.
        DBError(item.to_string())
    }
}
impl From<tokio::sync::mpsc::error::SendError<()>> for DBError {
    /// Keeps the channel-send error's `Display` text as the message.
    fn from(item: mpsc::error::SendError<()>) -> Self {
        // `to_string` already returns an owned `String`; no extra clone is needed.
        DBError(item.to_string())
    }
}

8
src/lib.rs Normal file
View File

@@ -0,0 +1,8 @@
mod cmd;
pub mod error;
pub mod options;
mod protocol;
mod rdb;
mod replication_client;
pub mod server;
mod storage;

101
src/main.rs Normal file
View File

@@ -0,0 +1,101 @@
// #![allow(unused_imports)]
use tokio::net::TcpListener;
use redis_rs::{options::ReplicationOption, server};
use clap::Parser;
/// Command-line options for the Redis-compatible server
// The struct-level doc comment above replaces the text copied from the clap example
// ("Simple program to greet a person"), which did not describe this program.
#[derive(Parser, Debug)]
#[command(version, about, long_about = None)]
struct Args {
    /// The directory of Redis DB file
    #[arg(long)]
    dir: String,
    /// The name of the Redis DB file
    #[arg(long)]
    dbfilename: String,
    /// The port of the Redis server, default is 6379 if not specified
    #[arg(long)]
    port: Option<u16>,
    /// The address of the master Redis server, if the server is a replica. None if the server is a master.
    #[arg(long)]
    replicaof: Option<String>,
}
/// Entry point: parses the arguments, binds the listener, builds the server, performs
/// the replication handshake when started as a replica, then serves each accepted
/// connection on its own task.
#[tokio::main]
async fn main() {
    let args = Args::parse();

    let port = args.port.unwrap_or(6379);
    println!("will listen on port: {}", port);
    let bind_addr = format!("127.0.0.1:{}", port);
    let listener = TcpListener::bind(bind_addr).await.unwrap();

    // Passing --replicaof is what makes this process a replica.
    let role = if args.replicaof.is_some() {
        "slave".to_string()
    } else {
        "master".to_string()
    };
    let replication = ReplicationOption {
        role,
        master_replid: "8371b4fb1155b71f4a04d3e1bc3e18c4a990aeea".to_string(), // should be a random string but hard code for now
        master_repl_offset: 0,
        replica_of: args.replicaof,
    };
    let option = redis_rs::options::DBOption {
        dir: args.dir,
        db_file_name: args.dbfilename,
        port,
        replication,
    };

    let mut server = server::Server::new(option).await;

    // Replica only: handshake with the master, then keep consuming its commands.
    if server.is_slave() {
        let mut replica = server.clone();
        let mut master_link = server.get_follower_repl_client().await.unwrap();
        master_link.ping_master().await.unwrap();
        master_link.report_port(server.option.port).await.unwrap();
        master_link.report_sync_protocol().await.unwrap();
        master_link.start_psync(&mut replica).await.unwrap();
        tokio::spawn(async move {
            if let Err(e) = replica.handle(master_link.stream, true).await {
                println!("error: {:?}, will close the connection. Bye", e);
            }
        });
    }

    loop {
        let (stream, _) = match listener.accept().await {
            Ok(connection) => connection,
            Err(e) => {
                println!("error: {}", e);
                continue;
            }
        };
        println!("accepted new connection");
        let mut worker = server.clone();
        tokio::spawn(async move {
            if let Err(e) = worker.handle(stream, false).await {
                println!("error: {:?}, will close the connection. Bye", e);
            }
        });
    }
}

15
src/options.rs Normal file
View File

@@ -0,0 +1,15 @@
/// Server configuration; `main` fills it in from the command-line arguments.
#[derive(Clone)]
pub struct DBOption {
/// Directory of the DB file; also the value reported for `CONFIG GET dir`.
pub dir: String,
/// Name of the DB file; also the value reported for `CONFIG GET dbfilename`.
pub db_file_name: String,
/// Replication role and identifiers.
pub replication: ReplicationOption,
/// Port the server listens on.
pub port: u16,
}
/// Replication-related settings, reported by `INFO replication`.
#[derive(Clone)]
pub struct ReplicationOption {
/// `"master"` or `"slave"`, as set by `main`.
pub role: String,
/// Replication id sent in the `FULLRESYNC` reply.
pub master_replid: String,
/// Replication offset reported by `INFO replication`; `main` initialises it to 0.
pub master_repl_offset: u64,
/// The `--replicaof` argument, if given. `Server::new` turns a space in it into `:`
/// to form the master address.
pub replica_of: Option<String>,
}

176
src/protocol.rs Normal file
View File

@@ -0,0 +1,176 @@
use core::fmt;
use crate::error::DBError;
/// One RESP frame, as parsed from or written to the wire.
#[derive(Debug, Clone)]
pub enum Protocol {
/// Encoded as `+<text>\r\n`.
SimpleString(String),
/// Encoded as `$<byte length>\r\n<text>\r\n`.
BulkString(String),
/// Encoded as `$-1\r\n`.
Null,
/// Encoded as `*<count>\r\n` followed by each element's encoding.
Array(Vec<Protocol>),
}
impl fmt::Display for Protocol {
    /// Writes the decoded (human-readable) form of the frame, see `Protocol::decode`.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let text = self.decode();
        f.write_str(&text)
    }
}
impl Protocol {
/// Parses the first RESP frame at the start of `protocol`.
///
/// The leading byte selects the frame type: `+` simple string, `$` bulk string,
/// `*` array. Returns the frame and the number of bytes it occupied, type byte included.
///
/// # Errors
/// Returns a `DBError` for any other leading byte, for empty input, or when the
/// selected sub-parser fails.
pub fn from(protocol: &str) -> Result<(Self, usize), DBError> {
    let parsed = if let Some(rest) = protocol.strip_prefix('+') {
        Self::parse_simple_string_sfx(rest)
    } else if let Some(rest) = protocol.strip_prefix('$') {
        Self::parse_bulk_string_sfx(rest)
    } else if let Some(rest) = protocol.strip_prefix('*') {
        Self::parse_array_sfx(rest)
    } else {
        Err(DBError(format!(
            "[from] unsupported protocol: {:?}",
            protocol
        )))
    };
    // The sub-parsers count from after the type byte; add it back.
    parsed.map(|(frame, consumed)| (frame, consumed + 1))
}
/// Builds an array frame whose elements are bulk strings, one per input string.
pub fn from_vec(array: Vec<&str>) -> Self {
    let mut items = Vec::with_capacity(array.len());
    for text in array {
        items.push(Protocol::BulkString(text.to_owned()));
    }
    Protocol::Array(items)
}
#[inline]
pub fn ok() -> Self {
Protocol::SimpleString("ok".to_string())
}
#[inline]
pub fn err(msg: &str) -> Self {
Protocol::SimpleString(msg.to_string())
}
/// The error reply returned when a write is refused on a replica.
#[inline]
pub fn write_on_slave_err() -> Self {
    let message = "DISALLOW WRITE ON SLAVE";
    Self::err(message)
}
/// The error reply returned when PSYNC is sent to a server that is not a master.
#[inline]
pub fn psync_on_slave_err() -> Self {
    let message = "PSYNC ON SLAVE IS NOT ALLOWED";
    Self::err(message)
}
#[inline]
pub fn none() -> Self {
Self::SimpleString("none".to_string())
}
/// Returns the frame's text without any RESP framing.
///
/// Strings yield their content, `Null` yields the empty string, and an array yields
/// its decoded elements joined by single spaces.
pub fn decode(&self) -> String {
    match self {
        Protocol::Null => String::new(),
        Protocol::SimpleString(text) | Protocol::BulkString(text) => text.clone(),
        Protocol::Array(items) => {
            let parts: Vec<String> = items.iter().map(Protocol::decode).collect();
            parts.join(" ")
        }
    }
}
/// Serialises the frame to its RESP wire form.
///
/// Bulk strings are prefixed with their length in bytes; arrays are prefixed with
/// their element count and followed by each element's encoding.
pub fn encode(&self) -> String {
    match self {
        Protocol::Null => "$-1\r\n".to_string(),
        Protocol::SimpleString(text) => format!("+{}\r\n", text),
        Protocol::BulkString(text) => format!("${}\r\n{}\r\n", text.len(), text),
        Protocol::Array(items) => {
            let mut wire = format!("*{}\r\n", items.len());
            for item in items {
                wire.push_str(&item.encode());
            }
            wire
        }
    }
}
/// Parses a simple string whose `+` prefix has already been removed.
///
/// Returns the text up to the first CRLF and the number of bytes consumed, CRLF
/// included. Fails when the input contains no CRLF.
fn parse_simple_string_sfx(protocol: &str) -> Result<(Self, usize), DBError> {
    let Some(end) = protocol.find("\r\n") else {
        return Err(DBError(format!(
            "[new simple string] unsupported protocol: {:?}",
            protocol
        )));
    };
    let text = protocol[..end].to_string();
    Ok((Self::SimpleString(text), end + 2))
}
/// Parses a bulk string whose `$` prefix has already been removed.
///
/// The input has the form `<len>\r\n<payload>\r\n`. The payload is taken by its
/// declared length, so an empty payload and a payload that itself contains CRLF are
/// both accepted. Returns the frame and the number of bytes consumed.
///
/// The payload is lower-cased before it is stored.
// NOTE(review): lower-casing every bulk string makes command matching
// case-insensitive but also changes the case of keys and values; confirm intended.
///
/// # Errors
/// Returns a `DBError` when the length line is missing or not a number, when the
/// input is shorter than the declared length, or when the payload is not followed
/// by CRLF.
fn parse_bulk_string_sfx(protocol: &str) -> Result<(Self, usize), DBError> {
    let unsupported = || {
        DBError(format!(
            "[new bulk string] unsupported protocol: {:?}",
            protocol
        ))
    };
    let header_end = protocol.find("\r\n").ok_or_else(unsupported)?;
    let size = Self::parse_usize(&protocol[..header_end])?;
    let data_start = header_end + 2;
    let data_end = data_start.checked_add(size).ok_or_else(unsupported)?;
    // `get` instead of indexing: input that is too short, or a length that lands
    // inside a multi-byte character, becomes an error rather than a panic.
    let payload = protocol.get(data_start..data_end).ok_or_else(unsupported)?;
    if protocol.get(data_end..data_end + 2) != Some("\r\n") {
        return Err(DBError(format!(
            "[new bulk string] unmatched string length in prototocl {:?}",
            protocol,
        )));
    }
    // `parse_string` rejects empty input, but `$0\r\n\r\n` is a valid empty bulk string.
    let text = if payload.is_empty() {
        String::new()
    } else {
        Self::parse_string(payload)?
    };
    Ok((Protocol::BulkString(text.to_lowercase()), data_end + 2))
}
/// Parses an array whose `*` prefix has already been removed.
///
/// Reads the element count up to the first CRLF, then parses that many frames back
/// to back. Returns the array and the number of bytes consumed.
///
/// # Errors
/// Fails when the count line is missing or not a number, or when any element fails
/// to parse.
fn parse_array_sfx(s: &str) -> Result<(Self, usize), DBError> {
    let Some(header_end) = s.find("\r\n") else {
        return Err(DBError(format!(
            "[new array] unsupported protocol: {:?}",
            s
        )));
    };
    let array_len = s[..header_end].parse::<usize>()?;
    // Position of the next unread byte within `s`.
    let mut offset = header_end + 2;
    let mut items = Vec::new();
    for _ in 0..array_len {
        let (item, used) = Protocol::from(&s[offset..])?;
        offset += used;
        items.push(item);
    }
    Ok((Protocol::Array(items), offset))
}
fn parse_usize(protocol: &str) -> Result<usize, DBError> {
match protocol.len() {
0 => Err(DBError(format!("parse usize error: {:?}", protocol))),
_ => Ok(protocol
.parse::<usize>()
.map_err(|_| DBError(format!("parse usize error: {}", protocol)))?),
}
}
fn parse_string(protocol: &str) -> Result<String, DBError> {
match protocol.len() {
0 => Err(DBError(format!("parse usize error: {:?}", protocol))),
_ => Ok(protocol.to_string()),
}
}
}

201
src/rdb.rs Normal file
View File

@@ -0,0 +1,201 @@
// parse Redis RDB file format: https://rdb.fnordig.de/file_format.html
use tokio::{
fs,
io::{AsyncRead, AsyncReadExt, BufReader},
};
use crate::{error::DBError, server::Server};
use futures::pin_mut;
/// How the bytes following a length prefix must be read; chosen by `parse_len` and
/// consumed by `parse_string`.
enum StringEncoding {
/// `len` raw bytes, interpreted as UTF-8 text.
Raw,
/// An integer stored in 1 byte, rendered as decimal text.
I8,
/// An integer stored in 2 bytes (little-endian), rendered as decimal text.
I16,
/// An integer stored in 4 bytes (little-endian), rendered as decimal text.
I32,
/// LZF-compressed string; `parse_string` reports it as unsupported.
LZF,
}
// RDB file format.
// The five bytes every RDB file must start with.
const MAGIC: &[u8; 5] = b"REDIS";
// Opcode of an auxiliary field; followed by two encoded strings (name and value).
const META: u8 = 0xFA;
// Opcode of a database selector; followed by one encoded length.
const DB_SELECT: u8 = 0xFE;
// Opcode of the table-size record; followed by two encoded lengths.
const TABLE_SIZE_INFO: u8 = 0xFB;
// Opcode that ends the file; followed by 8 bytes that are read as a checksum.
pub const EOF: u8 = 0xFF;
/// Reads a complete RDB payload from `reader` and loads its string keys into the
/// server's storage.
///
/// The storage lock is held for the whole parse. Auxiliary fields, the database
/// index and the table-size hints are read and ignored; the trailing checksum is
/// read but not verified.
///
/// Key/value records are handled wherever they appear after the header. Each record
/// is an optional expiry prefix (`0xFC` milliseconds or `0xFD` seconds) followed by
/// the value type byte (`0` = string), the key and the value. The two sizes after
/// `0xFB` are only sizing hints, not "N plain records then M expiring records", so
/// they are not used to drive parsing.
///
/// # Errors
/// Returns a `DBError` on I/O failure, a bad header, an unknown opcode, or an
/// unsupported value type or string encoding.
// NOTE(review): the expiry read from the file is an absolute timestamp, while SET
// PX/EX pass a relative duration to the same `setx`; confirm which one `setx` expects.
pub async fn parse_rdb<R: AsyncRead + Unpin>(
    reader: &mut R,
    server: &mut Server,
) -> Result<(), DBError> {
    let mut storage = server.storage.lock().await;
    parse_magic(reader).await?;
    let _version = parse_version(reader).await?;
    pin_mut!(reader);
    loop {
        let op = reader.read_u8().await?;
        match op {
            META => {
                let _ = parse_aux(&mut *reader).await?;
                let _ = parse_aux(&mut *reader).await?;
                // just ignore the aux info for now
            }
            DB_SELECT => {
                let (_, _) = parse_len(&mut *reader).await?;
                // just ignore the db index for now
            }
            TABLE_SIZE_INFO => {
                // Sizing hints only; the records that follow identify themselves.
                let _table_size = parse_len(&mut *reader).await?.0;
                let _expire_table_size = parse_len(&mut *reader).await?.0;
            }
            0xFC | 0xFD => {
                // `parse_expire_entry` expects to read the expiry marker itself, so
                // put the byte just consumed back in front of the stream.
                let mut record = std::io::Cursor::new([op]).chain(&mut *reader);
                let (k, v, expire_timestamp) = parse_expire_entry(&mut record).await?;
                storage.setx(k, v, expire_timestamp);
            }
            0 => {
                // Same replay trick for the value-type byte of a record without expiry.
                let mut record = std::io::Cursor::new([op]).chain(&mut *reader);
                let (k, v) = parse_no_expire_entry(&mut record).await?;
                storage.set(k, v);
            }
            EOF => {
                // not verify crc for now
                let _crc = reader.read_u64().await?;
                break;
            }
            _ => return Err(DBError(format!("unexpected op: {}", op))),
        }
    }
    Ok(())
}
/// Loads an RDB file into the server's storage, reading it through a buffered reader.
pub async fn parse_rdb_file(f: &mut fs::File, server: &mut Server) -> Result<(), DBError> {
    let mut buffered = BufReader::new(f);
    let outcome = parse_rdb(&mut buffered, server).await;
    outcome
}
/// Reads one key/value record that has no expiry prefix: a value-type byte followed
/// by the encoded key and the encoded value.
///
/// Only value type `0` is accepted; any other type byte yields a `DBError`.
async fn parse_no_expire_entry<R: AsyncRead + Unpin>(
    input: &mut R,
) -> Result<(String, String), DBError> {
    match input.read_u8().await? {
        0 => {
            let key = parse_aux(input).await?;
            let value = parse_aux(input).await?;
            Ok((key, value))
        }
        other => Err(DBError(format!("unexpected key type: {}", other))),
    }
}
/// Reads one key/value record that starts with an expiry prefix.
///
/// `0xFC` is followed by an 8-byte little-endian timestamp in milliseconds; `0xFD`
/// by a 4-byte little-endian timestamp in seconds, which is converted to
/// milliseconds. Returns `(key, value, expiry in milliseconds)`.
///
/// # Errors
/// Returns a `DBError` for any other prefix byte, on I/O failure, or when the record
/// that follows cannot be parsed.
async fn parse_expire_entry<R: AsyncRead + Unpin>(
    input: &mut R,
) -> Result<(String, String, u128), DBError> {
    let marker = input.read_u8().await?;
    let expire_at_millis = match marker {
        // expire in milliseconds
        0xFC => u128::from(input.read_u64_le().await?),
        // expire in seconds: widen before multiplying, because a present-day Unix
        // timestamp times 1000 does not fit in `u32`.
        0xFD => u128::from(input.read_u32_le().await?) * 1000,
        _ => return Err(DBError(format!("unexpected expire type: {}", marker))),
    };
    let (k, v) = parse_no_expire_entry(input).await?;
    Ok((k, v, expire_at_millis))
}
/// Reads the 5-byte magic string and checks that it is `REDIS`.
///
/// # Errors
/// Returns a `DBError` when the input ends before 5 bytes were read, on any other
/// I/O failure, or when the bytes differ from the magic string.
async fn parse_magic<R: AsyncRead + Unpin>(input: &mut R) -> Result<(), DBError> {
    let mut magic = [0; 5];
    // `read_exact` keeps reading until the buffer is full. A single `read` may return
    // fewer bytes, especially on a socket, and that is not an error in itself.
    if let Err(e) = input.read_exact(&mut magic).await {
        return Err(if e.kind() == std::io::ErrorKind::UnexpectedEof {
            DBError("expected 5 chars for magic number".to_string())
        } else {
            e.into()
        });
    }
    if magic.as_slice() == MAGIC {
        Ok(())
    } else {
        Err(DBError(format!(
            "expected magic string {:?}, but got: {:?}",
            MAGIC, magic
        )))
    }
}
/// Reads the 4 version bytes that follow the magic string and returns them unparsed.
///
/// # Errors
/// Returns a `DBError` when the input ends before 4 bytes were read or on any other
/// I/O failure.
async fn parse_version<R: AsyncRead + Unpin>(input: &mut R) -> Result<[u8; 4], DBError> {
    let mut version = [0; 4];
    // `read_exact` keeps reading until the buffer is full. A single `read` may return
    // fewer bytes, especially on a socket, and that is not an error in itself.
    match input.read_exact(&mut version).await {
        Ok(_) => Ok(version),
        Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
            Err(DBError("expected 4 chars for redis version".to_string()))
        }
        Err(e) => Err(e.into()),
    }
}
/// Reads one encoded string: a length/encoding prefix followed by its payload.
async fn parse_aux<R: AsyncRead + Unpin>(input: &mut R) -> Result<String, DBError> {
    let (len, encoding) = parse_len(input).await?;
    parse_string(input, len, encoding).await
}
/// Reads an RDB length prefix and returns `(length, encoding of what follows)`.
///
/// The top two bits of the first byte select the form:
/// * `00` - the length is the low 6 bits of that byte;
/// * `01` - the length is 14 bits: the low 6 bits followed by the next byte;
/// * `10` - the length is the next 4 bytes, big-endian;
/// * `11` - a special string encoding named by the whole byte (`0xC0` 8-bit integer,
///   `0xC1` 16-bit integer, `0xC2` 32-bit integer, `0xC3` LZF).
///
/// # Errors
/// Returns a `DBError` on I/O failure or for an unknown special encoding.
async fn parse_len<R: AsyncRead + Unpin>(input: &mut R) -> Result<(u32, StringEncoding), DBError> {
    let first = input.read_u8().await?;
    // Matching on the two selector bits themselves keeps every arm reachable. The
    // 14-bit case used to test for 0x04 after masking with 0xC0, which never matches.
    match first >> 6 {
        0b00 => Ok((u32::from(first & 0x3F), StringEncoding::Raw)),
        0b01 => {
            let second = input.read_u8().await?;
            Ok((
                (u32::from(first & 0x3F) << 8) | u32::from(second),
                StringEncoding::Raw,
            ))
        }
        0b10 => {
            // The remaining 6 bits of the first byte are ignored.
            let len = input.read_u32().await?;
            Ok((len, StringEncoding::Raw))
        }
        _ => match first {
            0xC0 => Ok((1, StringEncoding::I8)),
            0xC1 => Ok((2, StringEncoding::I16)),
            0xC2 => Ok((4, StringEncoding::I32)),
            0xC3 => Ok((0, StringEncoding::LZF)), // not supported yet
            _ => Err(DBError(format!("unexpected string encoding: {}", first))),
        },
    }
}
/// Reads the payload that follows a length prefix and returns it as text.
///
/// * `Raw` - reads `len` bytes and validates them as UTF-8;
/// * `I8` / `I16` / `I32` - reads a signed little-endian integer of that width and
///   renders it in decimal (`len` is not used);
/// * `LZF` - not supported.
///
/// # Errors
/// Returns a `DBError` on I/O failure, when raw bytes are not valid UTF-8, or for
/// LZF-compressed strings.
async fn parse_string<R: AsyncRead + Unpin>(
    input: &mut R,
    len: u32,
    encoding: StringEncoding,
) -> Result<String, DBError> {
    match encoding {
        StringEncoding::Raw => {
            let mut s = vec![0; len as usize];
            input.read_exact(&mut s).await?;
            // File contents are untrusted: report bad UTF-8 instead of panicking.
            Ok(String::from_utf8(s)?)
        }
        // Integer-encoded strings are signed, so `-1` must not come back as `255`.
        StringEncoding::I8 => Ok(input.read_i8().await?.to_string()),
        StringEncoding::I16 => Ok(input.read_i16_le().await?.to_string()),
        StringEncoding::I32 => Ok(input.read_i32_le().await?.to_string()),
        StringEncoding::LZF => {
            // not supported yet
            Err(DBError("LZF encoding not supported yet".to_string()))
        }
    }
}

155
src/replication_client.rs Normal file
View File

@@ -0,0 +1,155 @@
use std::{num::ParseIntError, sync::Arc};
use tokio::{
io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader},
net::TcpStream,
sync::Mutex,
};
use crate::{error::DBError, protocol::Protocol, rdb, server::Server};
// Hex encoding of the RDB payload that `send_rdb_file` decodes and sends to a
// replica. The payload starts with the bytes "REDIS0011".
const EMPTY_RDB_FILE_HEX_STRING: &str = "524544495330303131fa0972656469732d76657205372e322e30fa0a72656469732d62697473c040fa056374696d65c26d08bc65fa08757365642d6d656dc2b0c41000fa08616f662d62617365c000fff06e3bfec0ff5aa2";
/// Replica-side end of the replication link: owns the TCP connection to the master
/// and drives the handshake (PING, REPLCONF, PSYNC).
pub struct FollowerReplicationClient {
/// Connection to the master. `main` hands it to `Server::handle` once the handshake
/// is done.
pub stream: TcpStream,
}
impl FollowerReplicationClient {
/// Connects to the master at `addr`.
///
/// # Panics
/// Panics if the TCP connection cannot be established.
pub async fn new(addr: String) -> FollowerReplicationClient {
    let stream = TcpStream::connect(addr).await.unwrap();
    FollowerReplicationClient { stream }
}
pub async fn ping_master(self: &mut Self) -> Result<(), DBError> {
let protocol = Protocol::Array(vec![Protocol::BulkString("PING".to_string())]);
self.stream.write_all(protocol.encode().as_bytes()).await?;
self.check_resp("PONG").await
}
pub async fn report_port(self: &mut Self, port: u16) -> Result<(), DBError> {
let protocol = Protocol::from_vec(vec![
"REPLCONF",
"listening-port",
port.to_string().as_str(),
]);
self.stream.write_all(protocol.encode().as_bytes()).await?;
self.check_resp("OK").await
}
pub async fn report_sync_protocol(self: &mut Self) -> Result<(), DBError> {
let p = Protocol::from_vec(vec!["REPLCONF", "capa", "psync2"]);
self.stream.write_all(p.encode().as_bytes()).await?;
self.check_resp("OK").await
}
pub async fn start_psync(self: &mut Self, server: &mut Server) -> Result<(), DBError> {
let p = Protocol::from_vec(vec!["PSYNC", "?", "-1"]);
self.stream.write_all(p.encode().as_bytes()).await?;
self.recv_rdb_file(server).await?;
Ok(())
}
pub async fn recv_rdb_file(self: &mut Self, server: &mut Server) -> Result<(), DBError> {
let mut reader = BufReader::new(&mut self.stream);
let mut buf = Vec::new();
let _ = reader.read_until(b'\n', &mut buf).await?;
buf.pop();
buf.pop();
let replication_info = String::from_utf8(buf)?;
let replication_info = replication_info
.split_whitespace()
.map(|x| x.to_string())
.collect::<Vec<String>>();
if replication_info.len() != 3 {
return Err(DBError(format!(
"expect 3 args but found {:?}",
replication_info
)));
}
println!(
"Get replication info: {:?} {:?} {:?}",
replication_info[0], replication_info[1], replication_info[2]
);
let c = reader.read_u8().await?;
if c != b'$' {
return Err(DBError(format!("expect $ but found {}", c)));
}
let mut buf = Vec::new();
reader.read_until(b'\n', &mut buf).await?;
buf.pop();
buf.pop();
let rdb_file_len = String::from_utf8(buf)?.parse::<usize>()?;
println!("rdb file len: {}", rdb_file_len);
// receive rdb file content
rdb::parse_rdb(&mut reader, server).await?;
Ok(())
}
/// Reads one response from the master and checks that it is exactly the simple
/// string `expected` (that is, `+<expected>\r\n`).
///
/// # Errors
/// Returns a `DBError` on I/O failure or when the bytes received differ from the
/// expected encoding.
pub async fn check_resp(&mut self, expected: &str) -> Result<(), DBError> {
    let mut buf = [0; 1024];
    let n_bytes = self.stream.read(&mut buf).await?;
    let received = &buf[..n_bytes];
    // Lossy conversion: the text is only logged, so bytes that are not valid UTF-8
    // must not panic the handshake.
    println!(
        "check resp: recv {:?}",
        String::from_utf8_lossy(received)
    );
    let expect = Protocol::SimpleString(expected.to_string()).encode();
    if expect.as_bytes() != received {
        return Err(DBError(format!(
            "expect response {:?} but found {:?}",
            expect, received
        )));
    }
    Ok(())
}
}
/// Master-side end of replication: the set of replica connections that write
/// commands are forwarded to.
#[derive(Clone)]
pub struct MasterReplicationClient {
/// Connections registered through `add_stream`; shared between clones of this client.
pub streams: Arc<Mutex<Vec<TcpStream>>>,
}
impl MasterReplicationClient {
/// Creates a client with no replica connections.
pub fn new() -> MasterReplicationClient {
    let streams = Arc::new(Mutex::new(Vec::new()));
    MasterReplicationClient { streams }
}
/// Sends the built-in RDB payload to a replica as `$<length>\r\n<bytes>`, with no
/// trailing CRLF.
///
/// # Errors
/// Returns a `DBError` when the embedded hex string cannot be decoded or a write fails.
pub async fn send_rdb_file(&mut self, stream: &mut TcpStream) -> Result<(), DBError> {
    let empty_rdb_file_bytes = (0..EMPTY_RDB_FILE_HEX_STRING.len())
        .step_by(2)
        .map(|i| u8::from_str_radix(&EMPTY_RDB_FILE_HEX_STRING[i..i + 2], 16))
        .collect::<Result<Vec<u8>, ParseIntError>>()?;
    println!("going to send rdb file");
    // `write_all` throughout: a plain `write` may send only part of the buffer and
    // leave the replica with a truncated header.
    let header = format!("${}\r\n", empty_rdb_file_bytes.len());
    stream.write_all(header.as_bytes()).await?;
    stream.write_all(&empty_rdb_file_bytes).await?;
    Ok(())
}
/// Registers a replica connection so that later commands are forwarded to it.
pub async fn add_stream(&mut self, stream: TcpStream) -> Result<(), DBError> {
    self.streams.lock().await.push(stream);
    Ok(())
}
/// Forwards one command frame to every registered replica.
///
/// # Errors
/// Stops at, and returns, the first write failure.
pub async fn send_command(&mut self, protocol: Protocol) -> Result<(), DBError> {
    // Encode once: the bytes are the same for every replica.
    let wire = protocol.encode();
    let mut streams = self.streams.lock().await;
    for stream in streams.iter_mut() {
        stream.write_all(wire.as_bytes()).await?;
    }
    Ok(())
}
}

156
src/server.rs Normal file
View File

@@ -0,0 +1,156 @@
use core::str;
use std::collections::BTreeMap;
use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::atomic::AtomicU64;
use std::sync::Arc;
use tokio::fs::OpenOptions;
use tokio::io::AsyncReadExt;
use tokio::io::AsyncWriteExt;
use tokio::sync::mpsc::Sender;
use tokio::sync::Mutex;
use crate::cmd::Cmd;
use crate::error::DBError;
use crate::options;
use crate::protocol::Protocol;
use crate::rdb;
use crate::replication_client::FollowerReplicationClient;
use crate::replication_client::MasterReplicationClient;
use crate::storage::Storage;
/// Contents of a single stream: an ordered map from a `String` key to a list
/// of `(String, String)` pairs.
/// NOTE(review): presumably entry id -> field/value pairs as given to XADD — confirm in cmd.rs.
type Stream = BTreeMap<String, Vec<(String, String)>>;
/// Shared server state.
///
/// Every mutable part sits behind an `Arc`, so clones of a `Server` refer to
/// the same underlying storage, streams, offset and replication clients.
#[derive(Clone)]
pub struct Server {
/// String key/value store.
pub storage: Arc<Mutex<Storage>>,
/// Stream data, looked up by a `String` key (presumably the stream's key name — confirm).
pub streams: Arc<Mutex<HashMap<String, Stream>>>,
/// Options the server was created with (role, `dir`, `db_file_name`, ...).
pub option: options::DBOption,
/// Counter initialised to 0 in `new`.
/// NOTE(review): presumably the replication offset — confirm against its users.
pub offset: Arc<AtomicU64>,
/// Replica connections; holds `Some` only when the server was created in the master role.
pub master_repl_clients: Arc<Mutex<Option<MasterReplicationClient>>>,
/// Senders of `()` notifications.
/// NOTE(review): presumably used to wake blocked stream readers (XREAD) — confirm in cmd.rs.
pub stream_reader_blocker: Arc<Mutex<Vec<Sender<()>>>>,
// Master address with the space of `replica_of` replaced by ':'; set only when the role is "slave".
master_addr: Option<String>,
}
impl Server {
pub async fn new(option: options::DBOption) -> Self {
let master_addr = match option.replication.role.as_str() {
"slave" => Some(
option
.replication
.replica_of
.clone()
.unwrap()
.replace(' ', ":"),
),
_ => None,
};
let is_master = option.replication.role == "master";
let mut server = Server {
storage: Arc::new(Mutex::new(Storage::new())),
streams: Arc::new(Mutex::new(HashMap::new())),
option,
master_repl_clients: if is_master {
Arc::new(Mutex::new(Some(MasterReplicationClient::new())))
} else {
Arc::new(Mutex::new(None))
},
offset: Arc::new(AtomicU64::new(0)),
stream_reader_blocker: Arc::new(Mutex::new(Vec::new())),
master_addr,
};
server.init().await.unwrap();
server
}
pub async fn init(&mut self) -> Result<(), DBError> {
// master initialization
if self.is_master() {
println!("Start as master\n");
let db_file_path =
PathBuf::from(self.option.dir.clone()).join(self.option.db_file_name.clone());
println!("will open db file path: {}", db_file_path.display());
// create empty db file if not exits
let mut file = OpenOptions::new()
.read(true)
.write(true)
.create(true)
.truncate(false)
.open(db_file_path.clone())
.await?;
if file.metadata().await?.len() != 0 {
rdb::parse_rdb_file(&mut file, self).await?;
}
}
Ok(())
}
pub async fn get_follower_repl_client(&mut self) -> Option<FollowerReplicationClient> {
if self.is_slave() {
Some(FollowerReplicationClient::new(self.master_addr.clone().unwrap()).await)
} else {
None
}
}
pub async fn handle(
&mut self,
mut stream: tokio::net::TcpStream,
is_rep_conn: bool,
) -> Result<(), DBError> {
let mut buf = [0; 512];
let mut queued_cmd: Option<Vec<(Cmd, Protocol)>> = None;
loop {
if let Ok(len) = stream.read(&mut buf).await {
if len == 0 {
println!("[handle] connection closed");
return Ok(());
}
let s = str::from_utf8(&buf[..len])?;
let (cmd, protocol) =
Cmd::from(s).unwrap_or((Cmd::Unknow, Protocol::err("unknow cmd")));
println!("got command: {:?}, protocol: {:?}", cmd, protocol);
let res = cmd
.run(self, protocol, is_rep_conn, &mut queued_cmd)
.await
.unwrap_or(Protocol::err("unknow cmd"));
print!("queued 2 cmd {:?}", queued_cmd);
// only send response to normal client, do not send response to replication client
if !is_rep_conn {
println!("going to send response {}", res.encode());
_ = stream.write(res.encode().as_bytes()).await?;
}
// send a full RDB file to slave
if self.is_master() {
if let Cmd::Psync = cmd {
let mut master_rep_client = self.master_repl_clients.lock().await;
let master_rep_client = master_rep_client.as_mut().unwrap();
master_rep_client.send_rdb_file(&mut stream).await?;
master_rep_client.add_stream(stream).await?;
break;
}
}
} else {
println!("[handle] going to break");
break;
}
}
Ok(())
}
pub fn is_slave(&self) -> bool {
self.option.replication.role == "slave"
}
pub fn is_master(&self) -> bool {
!self.is_slave()
}
}

59
src/storage.rs Normal file
View File

@@ -0,0 +1,59 @@
use std::{
collections::HashMap,
time::{SystemTime, UNIX_EPOCH},
};
/// A stored value paired with its optional absolute expiry time in
/// milliseconds since the Unix epoch; `None` means the key never expires.
pub type ValueType = (String, Option<u128>);

/// In-memory string key/value store with optional per-key expiry.
#[derive(Debug, Default)]
pub struct Storage {
    // key -> (value, absolute expiry timestamp in ms since the Unix epoch, if any)
    set: HashMap<String, ValueType>,
}

/// Returns the current wall-clock time in milliseconds since the Unix epoch.
///
/// # Panics
/// Panics if the system clock is set to a time before the Unix epoch.
#[inline]
pub fn now_in_millis() -> u128 {
    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .expect("system clock is set before the Unix epoch")
        .as_millis()
}

impl Storage {
    /// Creates an empty store.
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns `true` when `expires_at` is set and lies strictly before `now`.
    fn is_expired(expires_at: Option<u128>, now: u128) -> bool {
        expires_at.map_or(false, |deadline| now > deadline)
    }

    /// Returns a copy of the value stored under `k`.
    ///
    /// An expired key is removed on access and reported as absent.
    pub fn get(&mut self, k: &str) -> Option<String> {
        match self.set.get(k) {
            Some((_, expires_at)) if Self::is_expired(*expires_at, now_in_millis()) => {
                self.set.remove(k);
                None
            }
            Some((value, _)) => Some(value.clone()),
            None => None,
        }
    }

    /// Stores `v` under `k` without expiry, replacing any previous value.
    pub fn set(&mut self, k: String, v: String) {
        self.set.insert(k, (v, None));
    }

    /// Stores `v` under `k` so that it expires `expire_ms` milliseconds from now.
    pub fn setx(&mut self, k: String, v: String, expire_ms: u128) {
        // The TTL comes from client input; saturate rather than overflow on huge values.
        let expires_at = now_in_millis().saturating_add(expire_ms);
        self.set.insert(k, (v, Some(expires_at)));
    }

    /// Removes `k`; does nothing if the key is absent.
    pub fn del(&mut self, k: String) {
        self.set.remove(&k);
    }

    /// Returns all keys that have not expired, in unspecified order.
    ///
    /// Expired entries are skipped here so the result agrees with `get`;
    /// they are only physically removed by `get`.
    pub fn keys(&self) -> Vec<String> {
        let now = now_in_millis();
        self.set
            .iter()
            .filter(|(_, (_, expires_at))| !Self::is_expired(*expires_at, now))
            .map(|(key, _)| key.clone())
            .collect()
    }
}