...
This commit is contained in:
621
src/cmd.rs
Normal file
621
src/cmd.rs
Normal file
@@ -0,0 +1,621 @@
|
||||
use std::{collections::BTreeMap, ops::Bound, time::Duration, u64};
|
||||
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
use crate::{error::DBError, protocol::Protocol, server::Server, storage::now_in_millis};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum Cmd {
|
||||
Ping,
|
||||
Echo(String),
|
||||
Get(String),
|
||||
Set(String, String),
|
||||
SetPx(String, String, u128),
|
||||
SetEx(String, String, u128),
|
||||
Keys,
|
||||
ConfigGet(String),
|
||||
Info(Option<String>),
|
||||
Del(String),
|
||||
Replconf(String),
|
||||
Psync,
|
||||
Type(String),
|
||||
Xadd(String, String, Vec<(String, String)>),
|
||||
Xrange(String, String, String),
|
||||
Xread(Vec<String>, Vec<String>, Option<u64>),
|
||||
Incr(String),
|
||||
Multi,
|
||||
Exec,
|
||||
Unknow,
|
||||
Discard,
|
||||
}
|
||||
|
||||
impl Cmd {
|
||||
pub fn from(s: &str) -> Result<(Self, Protocol), DBError> {
|
||||
let protocol = Protocol::from(s)?;
|
||||
match protocol.clone().0 {
|
||||
Protocol::Array(p) => {
|
||||
let cmd = p.into_iter().map(|x| x.decode()).collect::<Vec<_>>();
|
||||
if cmd.is_empty() {
|
||||
return Err(DBError("cmd length is 0".to_string()));
|
||||
}
|
||||
Ok((
|
||||
match cmd[0].as_str() {
|
||||
"echo" => Cmd::Echo(cmd[1].clone()),
|
||||
"ping" => Cmd::Ping,
|
||||
"get" => Cmd::Get(cmd[1].clone()),
|
||||
"set" => {
|
||||
if cmd.len() == 5 && cmd[3] == "px" {
|
||||
Cmd::SetPx(cmd[1].clone(), cmd[2].clone(), cmd[4].parse().unwrap())
|
||||
} else if cmd.len() == 5 && cmd[3] == "ex" {
|
||||
Cmd::SetEx(cmd[1].clone(), cmd[2].clone(), cmd[4].parse().unwrap())
|
||||
} else if cmd.len() == 3 {
|
||||
Cmd::Set(cmd[1].clone(), cmd[2].clone())
|
||||
} else {
|
||||
return Err(DBError(format!("unsupported cmd {:?}", cmd)));
|
||||
}
|
||||
}
|
||||
"config" => {
|
||||
if cmd.len() != 3 || cmd[1] != "get" {
|
||||
return Err(DBError(format!("unsupported cmd {:?}", cmd)));
|
||||
} else {
|
||||
Cmd::ConfigGet(cmd[2].clone())
|
||||
}
|
||||
}
|
||||
"keys" => {
|
||||
if cmd.len() != 2 || cmd[1] != "*" {
|
||||
return Err(DBError(format!("unsupported cmd {:?}", cmd)));
|
||||
} else {
|
||||
Cmd::Keys
|
||||
}
|
||||
}
|
||||
"info" => {
|
||||
let section = if cmd.len() == 2 {
|
||||
Some(cmd[1].clone())
|
||||
} else {
|
||||
None
|
||||
};
|
||||
Cmd::Info(section)
|
||||
}
|
||||
"replconf" => {
|
||||
if cmd.len() < 3 {
|
||||
return Err(DBError(format!("unsupported cmd {:?}", cmd)));
|
||||
}
|
||||
Cmd::Replconf(cmd[1].clone())
|
||||
}
|
||||
"psync" => {
|
||||
if cmd.len() != 3 {
|
||||
return Err(DBError(format!("unsupported cmd {:?}", cmd)));
|
||||
}
|
||||
Cmd::Psync
|
||||
}
|
||||
"del" => {
|
||||
if cmd.len() != 2 {
|
||||
return Err(DBError(format!("unsupported cmd {:?}", cmd)));
|
||||
}
|
||||
Cmd::Del(cmd[1].clone())
|
||||
}
|
||||
"type" => {
|
||||
if cmd.len() != 2 {
|
||||
return Err(DBError(format!("unsupported cmd {:?}", cmd)));
|
||||
}
|
||||
Cmd::Type(cmd[1].clone())
|
||||
}
|
||||
"xadd" => {
|
||||
if cmd.len() < 5 {
|
||||
return Err(DBError(format!("unsupported cmd {:?}", cmd)));
|
||||
}
|
||||
|
||||
let mut key_value = Vec::<(String, String)>::new();
|
||||
let mut i = 3;
|
||||
while i < cmd.len() - 1 {
|
||||
key_value.push((cmd[i].clone(), cmd[i + 1].clone()));
|
||||
i += 2;
|
||||
}
|
||||
Cmd::Xadd(cmd[1].clone(), cmd[2].clone(), key_value)
|
||||
}
|
||||
"xrange" => {
|
||||
if cmd.len() != 4 {
|
||||
return Err(DBError(format!("unsupported cmd {:?}", cmd)));
|
||||
}
|
||||
Cmd::Xrange(cmd[1].clone(), cmd[2].clone(), cmd[3].clone())
|
||||
}
|
||||
"xread" => {
|
||||
if cmd.len() < 4 || cmd.len() % 2 != 0 {
|
||||
return Err(DBError(format!("unsupported cmd {:?}", cmd)));
|
||||
}
|
||||
let mut offset = 2;
|
||||
// block cmd
|
||||
let mut block = None;
|
||||
if cmd[1] == "block" {
|
||||
offset += 2;
|
||||
if let Ok(block_time) = cmd[2].parse() {
|
||||
block = Some(block_time);
|
||||
} else {
|
||||
return Err(DBError(format!("unsupported cmd {:?}", cmd)));
|
||||
}
|
||||
}
|
||||
let cmd2 = &cmd[offset..];
|
||||
let len2 = cmd2.len() / 2;
|
||||
Cmd::Xread(cmd2[0..len2].to_vec(), cmd2[len2..].to_vec(), block)
|
||||
}
|
||||
"incr" => {
|
||||
if cmd.len() != 2 {
|
||||
return Err(DBError(format!("unsupported cmd {:?}", cmd)));
|
||||
}
|
||||
Cmd::Incr(cmd[1].clone())
|
||||
}
|
||||
"multi" => {
|
||||
if cmd.len() != 1 {
|
||||
return Err(DBError(format!("unsupported cmd {:?}", cmd)));
|
||||
}
|
||||
Cmd::Multi
|
||||
}
|
||||
"exec" => {
|
||||
if cmd.len() != 1 {
|
||||
return Err(DBError(format!("unsupported cmd {:?}", cmd)));
|
||||
}
|
||||
Cmd::Exec
|
||||
}
|
||||
"discard" => Cmd::Discard,
|
||||
_ => Cmd::Unknow,
|
||||
},
|
||||
protocol.0,
|
||||
))
|
||||
}
|
||||
_ => Err(DBError(format!(
|
||||
"fail to parse as cmd for {:?}",
|
||||
protocol.0
|
||||
))),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn run(
|
||||
&self,
|
||||
server: &mut Server,
|
||||
protocol: Protocol,
|
||||
is_rep_con: bool,
|
||||
queued_cmd: &mut Option<Vec<(Cmd, Protocol)>>,
|
||||
) -> Result<Protocol, DBError> {
|
||||
// return if the command is a write command
|
||||
let p = protocol.clone();
|
||||
if queued_cmd.is_some()
|
||||
&& !matches!(self, Cmd::Exec)
|
||||
&& !matches!(self, Cmd::Multi)
|
||||
&& !matches!(self, Cmd::Discard)
|
||||
{
|
||||
queued_cmd
|
||||
.as_mut()
|
||||
.unwrap()
|
||||
.push((self.clone(), protocol.clone()));
|
||||
return Ok(Protocol::SimpleString("QUEUED".to_string()));
|
||||
}
|
||||
let ret = match self {
|
||||
Cmd::Ping => Ok(Protocol::SimpleString("PONG".to_string())),
|
||||
Cmd::Echo(s) => Ok(Protocol::SimpleString(s.clone())),
|
||||
Cmd::Get(k) => get_cmd(server, k).await,
|
||||
Cmd::Set(k, v) => set_cmd(server, k, v, protocol, is_rep_con).await,
|
||||
Cmd::SetPx(k, v, x) => set_px_cmd(server, k, v, x, protocol, is_rep_con).await,
|
||||
Cmd::SetEx(k, v, x) => set_ex_cmd(server, k, v, x, protocol, is_rep_con).await,
|
||||
Cmd::Del(k) => del_cmd(server, k, protocol, is_rep_con).await,
|
||||
Cmd::ConfigGet(name) => config_get_cmd(name, server),
|
||||
Cmd::Keys => keys_cmd(server).await,
|
||||
Cmd::Info(section) => info_cmd(section, server),
|
||||
Cmd::Replconf(sub_cmd) => replconf_cmd(sub_cmd, server),
|
||||
Cmd::Psync => psync_cmd(server),
|
||||
Cmd::Type(k) => type_cmd(server, k).await,
|
||||
Cmd::Xadd(stream_key, offset, kvps) => {
|
||||
xadd_cmd(
|
||||
offset.as_str(),
|
||||
server,
|
||||
stream_key.as_str(),
|
||||
kvps,
|
||||
protocol,
|
||||
is_rep_con,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
Cmd::Xrange(stream_key, start, end) => xrange_cmd(server, stream_key, start, end).await,
|
||||
Cmd::Xread(stream_keys, starts, block) => {
|
||||
xread_cmd(starts, server, stream_keys, block).await
|
||||
}
|
||||
Cmd::Incr(key) => incr_cmd(server, key).await,
|
||||
Cmd::Multi => {
|
||||
*queued_cmd = Some(Vec::<(Cmd, Protocol)>::new());
|
||||
Ok(Protocol::SimpleString("ok".to_string()))
|
||||
}
|
||||
Cmd::Exec => exec_cmd(queued_cmd, server, is_rep_con).await,
|
||||
Cmd::Discard => {
|
||||
if queued_cmd.is_some() {
|
||||
*queued_cmd = None;
|
||||
Ok(Protocol::SimpleString("ok".to_string()))
|
||||
} else {
|
||||
Ok(Protocol::err("ERR Discard without MULTI"))
|
||||
}
|
||||
}
|
||||
Cmd::Unknow => Ok(Protocol::err("unknow cmd")),
|
||||
};
|
||||
if ret.is_ok() {
|
||||
server.offset.fetch_add(
|
||||
p.encode().len() as u64,
|
||||
std::sync::atomic::Ordering::Relaxed,
|
||||
);
|
||||
}
|
||||
ret
|
||||
}
|
||||
}
|
||||
|
||||
async fn exec_cmd(
|
||||
queued_cmd: &mut Option<Vec<(Cmd, Protocol)>>,
|
||||
server: &mut Server,
|
||||
is_rep_con: bool,
|
||||
) -> Result<Protocol, DBError> {
|
||||
if queued_cmd.is_some() {
|
||||
let mut vec = Vec::new();
|
||||
for (cmd, protocol) in queued_cmd.as_ref().unwrap() {
|
||||
let res = Box::pin(cmd.run(server, protocol.clone(), is_rep_con, &mut None)).await?;
|
||||
vec.push(res);
|
||||
}
|
||||
*queued_cmd = None;
|
||||
Ok(Protocol::Array(vec))
|
||||
} else {
|
||||
Ok(Protocol::err("ERR EXEC without MULTI"))
|
||||
}
|
||||
}
|
||||
|
||||
async fn incr_cmd(server: &mut Server, key: &String) -> Result<Protocol, DBError> {
|
||||
let mut storage = server.storage.lock().await;
|
||||
let v = storage.get(key);
|
||||
// return 1 if key is missing
|
||||
let v = v.map_or("1".to_string(), |v| v);
|
||||
|
||||
if let Ok(x) = v.parse::<u64>() {
|
||||
let v = (x + 1).to_string();
|
||||
storage.set(key.clone(), v.clone());
|
||||
Ok(Protocol::SimpleString(v))
|
||||
} else {
|
||||
Ok(Protocol::err("ERR value is not an integer or out of range"))
|
||||
}
|
||||
}
|
||||
|
||||
fn config_get_cmd(name: &String, server: &mut Server) -> Result<Protocol, DBError> {
|
||||
match name.as_str() {
|
||||
"dir" => Ok(Protocol::Array(vec![
|
||||
Protocol::BulkString(name.clone()),
|
||||
Protocol::BulkString(server.option.dir.clone()),
|
||||
])),
|
||||
"dbfilename" => Ok(Protocol::Array(vec![
|
||||
Protocol::BulkString(name.clone()),
|
||||
Protocol::BulkString(server.option.db_file_name.clone()),
|
||||
])),
|
||||
_ => Err(DBError(format!("unsupported config {:?}", name))),
|
||||
}
|
||||
}
|
||||
|
||||
async fn keys_cmd(server: &mut Server) -> Result<Protocol, DBError> {
|
||||
let keys = { server.storage.lock().await.keys() };
|
||||
Ok(Protocol::Array(
|
||||
keys.into_iter().map(Protocol::BulkString).collect(),
|
||||
))
|
||||
}
|
||||
|
||||
fn info_cmd(section: &Option<String>, server: &mut Server) -> Result<Protocol, DBError> {
|
||||
match section {
|
||||
Some(s) => match s.as_str() {
|
||||
"replication" => Ok(Protocol::BulkString(format!(
|
||||
"role:{}\nmaster_replid:{}\nmaster_repl_offset:{}\n",
|
||||
server.option.replication.role,
|
||||
server.option.replication.master_replid,
|
||||
server.option.replication.master_repl_offset
|
||||
))),
|
||||
_ => Err(DBError(format!("unsupported section {:?}", s))),
|
||||
},
|
||||
None => Ok(Protocol::BulkString("default".to_string())),
|
||||
}
|
||||
}
|
||||
|
||||
async fn xread_cmd(
|
||||
starts: &[String],
|
||||
server: &mut Server,
|
||||
stream_keys: &[String],
|
||||
block_millis: &Option<u64>,
|
||||
) -> Result<Protocol, DBError> {
|
||||
if let Some(t) = block_millis {
|
||||
if t > &0 {
|
||||
tokio::time::sleep(Duration::from_millis(*t)).await;
|
||||
} else {
|
||||
let (sender, mut receiver) = mpsc::channel(4);
|
||||
{
|
||||
let mut blocker = server.stream_reader_blocker.lock().await;
|
||||
blocker.push(sender.clone());
|
||||
}
|
||||
while let Some(_) = receiver.recv().await {
|
||||
println!("get new xadd cmd, release block");
|
||||
// break;
|
||||
}
|
||||
}
|
||||
}
|
||||
let streams = server.streams.lock().await;
|
||||
let mut ret = Vec::new();
|
||||
for (i, stream_key) in stream_keys.iter().enumerate() {
|
||||
let stream = streams.get(stream_key);
|
||||
if let Some(s) = stream {
|
||||
let (offset_id, mut offset_seq, _) = split_offset(starts[i].as_str());
|
||||
offset_seq += 1;
|
||||
let start = format!("{}-{}", offset_id, offset_seq);
|
||||
let end = format!("{}-{}", u64::MAX - 1, 0);
|
||||
|
||||
// query stream range
|
||||
let range = s.range::<String, _>((Bound::Included(&start), Bound::Included(&end)));
|
||||
let mut array = Vec::new();
|
||||
for (k, v) in range {
|
||||
array.push(Protocol::BulkString(k.clone()));
|
||||
array.push(Protocol::from_vec(
|
||||
v.iter()
|
||||
.flat_map(|(a, b)| vec![a.as_str(), b.as_str()])
|
||||
.collect(),
|
||||
))
|
||||
}
|
||||
ret.push(Protocol::BulkString(stream_key.clone()));
|
||||
ret.push(Protocol::Array(array));
|
||||
}
|
||||
}
|
||||
Ok(Protocol::Array(ret))
|
||||
}
|
||||
|
||||
fn replconf_cmd(sub_cmd: &str, server: &mut Server) -> Result<Protocol, DBError> {
|
||||
match sub_cmd {
|
||||
"getack" => Ok(Protocol::from_vec(vec![
|
||||
"REPLCONF",
|
||||
"ACK",
|
||||
server
|
||||
.offset
|
||||
.load(std::sync::atomic::Ordering::Relaxed)
|
||||
.to_string()
|
||||
.as_str(),
|
||||
])),
|
||||
_ => Ok(Protocol::SimpleString("OK".to_string())),
|
||||
}
|
||||
}
|
||||
|
||||
async fn xrange_cmd(
|
||||
server: &mut Server,
|
||||
stream_key: &String,
|
||||
start: &String,
|
||||
end: &String,
|
||||
) -> Result<Protocol, DBError> {
|
||||
let streams = server.streams.lock().await;
|
||||
let stream = streams.get(stream_key);
|
||||
Ok(stream.map_or(Protocol::none(), |s| {
|
||||
// support query with '-'
|
||||
let start = if start == "-" {
|
||||
"0".to_string()
|
||||
} else {
|
||||
start.clone()
|
||||
};
|
||||
|
||||
// support query with '+'
|
||||
let end = if end == "+" {
|
||||
u64::MAX.to_string()
|
||||
} else {
|
||||
end.clone()
|
||||
};
|
||||
|
||||
// query stream range
|
||||
let range = s.range::<String, _>((Bound::Included(&start), Bound::Included(&end)));
|
||||
let mut array = Vec::new();
|
||||
for (k, v) in range {
|
||||
array.push(Protocol::BulkString(k.clone()));
|
||||
array.push(Protocol::from_vec(
|
||||
v.iter()
|
||||
.flat_map(|(a, b)| vec![a.as_str(), b.as_str()])
|
||||
.collect(),
|
||||
))
|
||||
}
|
||||
println!("after xrange: {:?}", array);
|
||||
Protocol::Array(array)
|
||||
}))
|
||||
}
|
||||
|
||||
async fn xadd_cmd(
|
||||
offset: &str,
|
||||
server: &mut Server,
|
||||
stream_key: &str,
|
||||
kvps: &Vec<(String, String)>,
|
||||
protocol: Protocol,
|
||||
is_rep_con: bool,
|
||||
) -> Result<Protocol, DBError> {
|
||||
let mut offset = offset.to_string();
|
||||
if offset == "*" {
|
||||
offset = format!("{}-*", now_in_millis() as u64);
|
||||
}
|
||||
let (offset_id, mut offset_seq, has_wildcard) = split_offset(offset.as_str());
|
||||
if offset_id == 0 && offset_seq == 0 && !has_wildcard {
|
||||
return Ok(Protocol::err(
|
||||
"ERR The ID specified in XADD must be greater than 0-0",
|
||||
));
|
||||
}
|
||||
{
|
||||
let mut streams = server.streams.lock().await;
|
||||
let stream = streams
|
||||
.entry(stream_key.to_string())
|
||||
.or_insert_with(BTreeMap::new);
|
||||
|
||||
if let Some((last_offset, _)) = stream.last_key_value() {
|
||||
let (last_offset_id, last_offset_seq, _) = split_offset(last_offset.as_str());
|
||||
if last_offset_id > offset_id
|
||||
|| (last_offset_id == offset_id && last_offset_seq >= offset_seq && !has_wildcard)
|
||||
{
|
||||
return Ok(Protocol::err("ERR The ID specified in XADD is equal or smaller than the target stream top item"));
|
||||
}
|
||||
|
||||
if last_offset_id == offset_id && last_offset_seq >= offset_seq && has_wildcard {
|
||||
offset_seq = last_offset_seq + 1;
|
||||
}
|
||||
}
|
||||
|
||||
let offset = format!("{}-{}", offset_id, offset_seq);
|
||||
|
||||
let s = stream.entry(offset.clone()).or_insert_with(Vec::new);
|
||||
for (key, value) in kvps {
|
||||
s.push((key.clone(), value.clone()));
|
||||
}
|
||||
}
|
||||
{
|
||||
let mut blocker = server.stream_reader_blocker.lock().await;
|
||||
for sender in blocker.iter() {
|
||||
sender.send(()).await?;
|
||||
}
|
||||
blocker.clear();
|
||||
}
|
||||
resp_and_replicate(
|
||||
server,
|
||||
Protocol::BulkString(offset.to_string()),
|
||||
protocol,
|
||||
is_rep_con,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn type_cmd(server: &mut Server, k: &String) -> Result<Protocol, DBError> {
|
||||
let v = { server.storage.lock().await.get(k) };
|
||||
if v.is_some() {
|
||||
return Ok(Protocol::SimpleString("string".to_string()));
|
||||
}
|
||||
let streams = server.streams.lock().await;
|
||||
let v = streams.get(k);
|
||||
Ok(v.map_or(Protocol::none(), |_| {
|
||||
Protocol::SimpleString("stream".to_string())
|
||||
}))
|
||||
}
|
||||
|
||||
fn psync_cmd(server: &mut Server) -> Result<Protocol, DBError> {
|
||||
if server.is_master() {
|
||||
Ok(Protocol::SimpleString(format!(
|
||||
"FULLRESYNC {} 0",
|
||||
server.option.replication.master_replid
|
||||
)))
|
||||
} else {
|
||||
Ok(Protocol::psync_on_slave_err())
|
||||
}
|
||||
}
|
||||
|
||||
async fn del_cmd(
|
||||
server: &mut Server,
|
||||
k: &str,
|
||||
protocol: Protocol,
|
||||
is_rep_con: bool,
|
||||
) -> Result<Protocol, DBError> {
|
||||
// offset
|
||||
let _ = {
|
||||
let mut s = server.storage.lock().await;
|
||||
s.del(k.to_string());
|
||||
server
|
||||
.offset
|
||||
.fetch_add(1, std::sync::atomic::Ordering::Relaxed)
|
||||
};
|
||||
resp_and_replicate(server, Protocol::ok(), protocol, is_rep_con).await
|
||||
}
|
||||
|
||||
async fn set_ex_cmd(
|
||||
server: &mut Server,
|
||||
k: &str,
|
||||
v: &str,
|
||||
x: &u128,
|
||||
protocol: Protocol,
|
||||
is_rep_con: bool,
|
||||
) -> Result<Protocol, DBError> {
|
||||
// offset
|
||||
let _ = {
|
||||
let mut s = server.storage.lock().await;
|
||||
s.setx(k.to_string(), v.to_string(), *x * 1000);
|
||||
server
|
||||
.offset
|
||||
.fetch_add(1, std::sync::atomic::Ordering::Relaxed)
|
||||
};
|
||||
resp_and_replicate(server, Protocol::ok(), protocol, is_rep_con).await
|
||||
}
|
||||
|
||||
async fn set_px_cmd(
|
||||
server: &mut Server,
|
||||
k: &str,
|
||||
v: &str,
|
||||
x: &u128,
|
||||
protocol: Protocol,
|
||||
is_rep_con: bool,
|
||||
) -> Result<Protocol, DBError> {
|
||||
// offset
|
||||
let _ = {
|
||||
let mut s = server.storage.lock().await;
|
||||
s.setx(k.to_string(), v.to_string(), *x);
|
||||
server
|
||||
.offset
|
||||
.fetch_add(1, std::sync::atomic::Ordering::Relaxed)
|
||||
};
|
||||
resp_and_replicate(server, Protocol::ok(), protocol, is_rep_con).await
|
||||
}
|
||||
|
||||
async fn set_cmd(
|
||||
server: &mut Server,
|
||||
k: &str,
|
||||
v: &str,
|
||||
protocol: Protocol,
|
||||
is_rep_con: bool,
|
||||
) -> Result<Protocol, DBError> {
|
||||
// offset
|
||||
let _ = {
|
||||
let mut s = server.storage.lock().await;
|
||||
s.set(k.to_string(), v.to_string());
|
||||
server
|
||||
.offset
|
||||
.fetch_add(1, std::sync::atomic::Ordering::Relaxed)
|
||||
+ 1
|
||||
};
|
||||
resp_and_replicate(server, Protocol::ok(), protocol, is_rep_con).await
|
||||
}
|
||||
|
||||
async fn get_cmd(server: &mut Server, k: &str) -> Result<Protocol, DBError> {
|
||||
let v = {
|
||||
let mut s = server.storage.lock().await;
|
||||
s.get(k)
|
||||
};
|
||||
Ok(v.map_or(Protocol::Null, Protocol::SimpleString))
|
||||
}
|
||||
|
||||
async fn resp_and_replicate(
|
||||
server: &mut Server,
|
||||
resp: Protocol,
|
||||
replication: Protocol,
|
||||
is_rep_con: bool,
|
||||
) -> Result<Protocol, DBError> {
|
||||
if server.is_master() {
|
||||
server
|
||||
.master_repl_clients
|
||||
.lock()
|
||||
.await
|
||||
.as_mut()
|
||||
.unwrap()
|
||||
.send_command(replication)
|
||||
.await?;
|
||||
Ok(resp)
|
||||
} else if !is_rep_con {
|
||||
Ok(Protocol::write_on_slave_err())
|
||||
} else {
|
||||
Ok(resp)
|
||||
}
|
||||
}
|
||||
|
||||
fn split_offset(offset: &str) -> (u64, u64, bool) {
|
||||
let offset_split = offset.split('-').collect::<Vec<_>>();
|
||||
let offset_id = offset_split[0].parse::<u64>().expect(&format!(
|
||||
"ERR The ID specified in XADD must be a number: {}",
|
||||
offset
|
||||
));
|
||||
|
||||
if offset_split.len() == 1 || offset_split[1] == "*" {
|
||||
return (offset_id, if offset_id == 0 { 1 } else { 0 }, true);
|
||||
}
|
||||
|
||||
let offset_seq = offset_split[1].parse::<u64>().unwrap();
|
||||
(offset_id, offset_seq, false)
|
||||
}
|
44
src/error.rs
Normal file
44
src/error.rs
Normal file
@@ -0,0 +1,44 @@
|
||||
use std::num::ParseIntError;
|
||||
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
use crate::protocol::Protocol;
|
||||
|
||||
// todo: more error types
|
||||
#[derive(Debug)]
|
||||
pub struct DBError(pub String);
|
||||
|
||||
impl From<std::io::Error> for DBError {
|
||||
fn from(item: std::io::Error) -> Self {
|
||||
DBError(item.to_string().clone())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<ParseIntError> for DBError {
|
||||
fn from(item: ParseIntError) -> Self {
|
||||
DBError(item.to_string().clone())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<std::str::Utf8Error> for DBError {
|
||||
fn from(item: std::str::Utf8Error) -> Self {
|
||||
DBError(item.to_string().clone())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<std::string::FromUtf8Error> for DBError {
|
||||
fn from(item: std::string::FromUtf8Error) -> Self {
|
||||
DBError(item.to_string().clone())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<mpsc::error::SendError<(Protocol, u64)>> for DBError {
|
||||
fn from(item: mpsc::error::SendError<(Protocol, u64)>) -> Self {
|
||||
DBError(item.to_string().clone())
|
||||
}
|
||||
}
|
||||
impl From<tokio::sync::mpsc::error::SendError<()>> for DBError {
|
||||
fn from(item: mpsc::error::SendError<()>) -> Self {
|
||||
DBError(item.to_string().clone())
|
||||
}
|
||||
}
|
8
src/lib.rs
Normal file
8
src/lib.rs
Normal file
@@ -0,0 +1,8 @@
|
||||
mod cmd;
|
||||
pub mod error;
|
||||
pub mod options;
|
||||
mod protocol;
|
||||
mod rdb;
|
||||
mod replication_client;
|
||||
pub mod server;
|
||||
mod storage;
|
101
src/main.rs
Normal file
101
src/main.rs
Normal file
@@ -0,0 +1,101 @@
|
||||
// #![allow(unused_imports)]
|
||||
|
||||
use tokio::net::TcpListener;
|
||||
|
||||
use redis_rs::{options::ReplicationOption, server};
|
||||
|
||||
use clap::Parser;
|
||||
|
||||
/// Simple program to greet a person
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(version, about, long_about = None)]
|
||||
struct Args {
|
||||
/// The directory of Redis DB file
|
||||
#[arg(long)]
|
||||
dir: String,
|
||||
|
||||
/// The name of the Redis DB file
|
||||
#[arg(long)]
|
||||
dbfilename: String,
|
||||
|
||||
/// The port of the Redis server, default is 6379 if not specified
|
||||
#[arg(long)]
|
||||
port: Option<u16>,
|
||||
|
||||
/// The address of the master Redis server, if the server is a replica. None if the server is a master.
|
||||
#[arg(long)]
|
||||
replicaof: Option<String>,
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() {
|
||||
// parse args
|
||||
let args = Args::parse();
|
||||
|
||||
// bind port
|
||||
let port = args.port.unwrap_or(6379);
|
||||
println!("will listen on port: {}", port);
|
||||
let listener = TcpListener::bind(format!("127.0.0.1:{}", port))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// new DB option
|
||||
let option = redis_rs::options::DBOption {
|
||||
dir: args.dir,
|
||||
db_file_name: args.dbfilename,
|
||||
port,
|
||||
replication: ReplicationOption {
|
||||
role: if let Some(_) = args.replicaof {
|
||||
"slave".to_string()
|
||||
} else {
|
||||
"master".to_string()
|
||||
},
|
||||
master_replid: "8371b4fb1155b71f4a04d3e1bc3e18c4a990aeea".to_string(), // should be a random string but hard code for now
|
||||
master_repl_offset: 0,
|
||||
replica_of: args.replicaof,
|
||||
},
|
||||
};
|
||||
|
||||
// new server
|
||||
let mut server = server::Server::new(option).await;
|
||||
|
||||
//start receive replication cmds for slave
|
||||
if server.is_slave() {
|
||||
let mut sc = server.clone();
|
||||
|
||||
let mut follower_repl_client = server.get_follower_repl_client().await.unwrap();
|
||||
follower_repl_client.ping_master().await.unwrap();
|
||||
follower_repl_client
|
||||
.report_port(server.option.port)
|
||||
.await
|
||||
.unwrap();
|
||||
follower_repl_client.report_sync_protocol().await.unwrap();
|
||||
follower_repl_client.start_psync(&mut sc).await.unwrap();
|
||||
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = sc.handle(follower_repl_client.stream, true).await {
|
||||
println!("error: {:?}, will close the connection. Bye", e);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// accept new connections
|
||||
loop {
|
||||
let stream = listener.accept().await;
|
||||
match stream {
|
||||
Ok((stream, _)) => {
|
||||
println!("accepted new connection");
|
||||
|
||||
let mut sc = server.clone();
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = sc.handle(stream, false).await {
|
||||
println!("error: {:?}, will close the connection. Bye", e);
|
||||
}
|
||||
});
|
||||
}
|
||||
Err(e) => {
|
||||
println!("error: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
15
src/options.rs
Normal file
15
src/options.rs
Normal file
@@ -0,0 +1,15 @@
|
||||
#[derive(Clone)]
|
||||
pub struct DBOption {
|
||||
pub dir: String,
|
||||
pub db_file_name: String,
|
||||
pub replication: ReplicationOption,
|
||||
pub port: u16,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct ReplicationOption {
|
||||
pub role: String,
|
||||
pub master_replid: String,
|
||||
pub master_repl_offset: u64,
|
||||
pub replica_of: Option<String>,
|
||||
}
|
176
src/protocol.rs
Normal file
176
src/protocol.rs
Normal file
@@ -0,0 +1,176 @@
|
||||
use core::fmt;
|
||||
|
||||
use crate::error::DBError;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum Protocol {
|
||||
SimpleString(String),
|
||||
BulkString(String),
|
||||
Null,
|
||||
Array(Vec<Protocol>),
|
||||
}
|
||||
|
||||
impl fmt::Display for Protocol {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
write!(f, "{}", self.decode().as_str())
|
||||
}
|
||||
}
|
||||
|
||||
impl Protocol {
|
||||
pub fn from(protocol: &str) -> Result<(Self, usize), DBError> {
|
||||
let ret = match protocol.chars().nth(0) {
|
||||
Some('+') => Self::parse_simple_string_sfx(&protocol[1..]),
|
||||
Some('$') => Self::parse_bulk_string_sfx(&protocol[1..]),
|
||||
Some('*') => Self::parse_array_sfx(&protocol[1..]),
|
||||
_ => Err(DBError(format!(
|
||||
"[from] unsupported protocol: {:?}",
|
||||
protocol
|
||||
))),
|
||||
};
|
||||
match ret {
|
||||
Ok((p, s)) => Ok((p, s + 1)),
|
||||
Err(e) => Err(e),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn from_vec(array: Vec<&str>) -> Self {
|
||||
let array = array
|
||||
.into_iter()
|
||||
.map(|x| Protocol::BulkString(x.to_string()))
|
||||
.collect();
|
||||
Protocol::Array(array)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn ok() -> Self {
|
||||
Protocol::SimpleString("ok".to_string())
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn err(msg: &str) -> Self {
|
||||
Protocol::SimpleString(msg.to_string())
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn write_on_slave_err() -> Self {
|
||||
Self::err("DISALLOW WRITE ON SLAVE")
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn psync_on_slave_err() -> Self {
|
||||
Self::err("PSYNC ON SLAVE IS NOT ALLOWED")
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn none() -> Self {
|
||||
Self::SimpleString("none".to_string())
|
||||
}
|
||||
|
||||
pub fn decode(&self) -> String {
|
||||
match self {
|
||||
Protocol::SimpleString(s) => s.to_string(),
|
||||
Protocol::BulkString(s) => s.to_string(),
|
||||
Protocol::Null => "".to_string(),
|
||||
Protocol::Array(s) => s.iter().map(|x| x.decode()).collect::<Vec<_>>().join(" "),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn encode(&self) -> String {
|
||||
match self {
|
||||
Protocol::SimpleString(s) => format!("+{}\r\n", s),
|
||||
Protocol::BulkString(s) => format!("${}\r\n{}\r\n", s.len(), s),
|
||||
Protocol::Array(ss) => {
|
||||
format!("*{}\r\n", ss.len())
|
||||
+ ss.iter()
|
||||
.map(|x| x.encode())
|
||||
.collect::<Vec<_>>()
|
||||
.join("")
|
||||
.as_str()
|
||||
}
|
||||
Protocol::Null => "$-1\r\n".to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_simple_string_sfx(protocol: &str) -> Result<(Self, usize), DBError> {
|
||||
match protocol.find("\r\n") {
|
||||
Some(x) => Ok((Self::SimpleString(protocol[..x].to_string()), x + 2)),
|
||||
_ => Err(DBError(format!(
|
||||
"[new simple string] unsupported protocol: {:?}",
|
||||
protocol
|
||||
))),
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_bulk_string_sfx(protocol: &str) -> Result<(Self, usize), DBError> {
|
||||
if let Some(len) = protocol.find("\r\n") {
|
||||
let size = Self::parse_usize(&protocol[..len])?;
|
||||
if let Some(data_len) = protocol[len + 2..].find("\r\n") {
|
||||
let s = Self::parse_string(&protocol[len + 2..len + 2 + data_len])?;
|
||||
if size != s.len() {
|
||||
Err(DBError(format!(
|
||||
"[new bulk string] unmatched string length in prototocl {:?}",
|
||||
protocol,
|
||||
)))
|
||||
} else {
|
||||
Ok((
|
||||
Protocol::BulkString(s.to_lowercase()),
|
||||
len + 2 + data_len + 2,
|
||||
))
|
||||
}
|
||||
} else {
|
||||
Err(DBError(format!(
|
||||
"[new bulk string] unsupported protocol: {:?}",
|
||||
protocol
|
||||
)))
|
||||
}
|
||||
} else {
|
||||
Err(DBError(format!(
|
||||
"[new bulk string] unsupported protocol: {:?}",
|
||||
protocol
|
||||
)))
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_array_sfx(s: &str) -> Result<(Self, usize), DBError> {
|
||||
let mut offset = 0;
|
||||
match s.find("\r\n") {
|
||||
Some(x) => {
|
||||
let array_len = s[..x].parse::<usize>()?;
|
||||
offset += x + 2;
|
||||
let mut vec = vec![];
|
||||
for _ in 0..array_len {
|
||||
match Protocol::from(&s[offset..]) {
|
||||
Ok((p, len)) => {
|
||||
offset += len;
|
||||
vec.push(p);
|
||||
}
|
||||
Err(e) => {
|
||||
return Err(e);
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok((Protocol::Array(vec), offset))
|
||||
}
|
||||
_ => Err(DBError(format!(
|
||||
"[new array] unsupported protocol: {:?}",
|
||||
s
|
||||
))),
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_usize(protocol: &str) -> Result<usize, DBError> {
|
||||
match protocol.len() {
|
||||
0 => Err(DBError(format!("parse usize error: {:?}", protocol))),
|
||||
_ => Ok(protocol
|
||||
.parse::<usize>()
|
||||
.map_err(|_| DBError(format!("parse usize error: {}", protocol)))?),
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_string(protocol: &str) -> Result<String, DBError> {
|
||||
match protocol.len() {
|
||||
0 => Err(DBError(format!("parse usize error: {:?}", protocol))),
|
||||
_ => Ok(protocol.to_string()),
|
||||
}
|
||||
}
|
||||
}
|
201
src/rdb.rs
Normal file
201
src/rdb.rs
Normal file
@@ -0,0 +1,201 @@
|
||||
// parse Redis RDB file format: https://rdb.fnordig.de/file_format.html
|
||||
|
||||
use tokio::{
|
||||
fs,
|
||||
io::{AsyncRead, AsyncReadExt, BufReader},
|
||||
};
|
||||
|
||||
use crate::{error::DBError, server::Server};
|
||||
|
||||
use futures::pin_mut;
|
||||
|
||||
enum StringEncoding {
|
||||
Raw,
|
||||
I8,
|
||||
I16,
|
||||
I32,
|
||||
LZF,
|
||||
}
|
||||
|
||||
// RDB file format.
|
||||
const MAGIC: &[u8; 5] = b"REDIS";
|
||||
const META: u8 = 0xFA;
|
||||
const DB_SELECT: u8 = 0xFE;
|
||||
const TABLE_SIZE_INFO: u8 = 0xFB;
|
||||
pub const EOF: u8 = 0xFF;
|
||||
|
||||
pub async fn parse_rdb<R: AsyncRead + Unpin>(
|
||||
reader: &mut R,
|
||||
server: &mut Server,
|
||||
) -> Result<(), DBError> {
|
||||
let mut storage = server.storage.lock().await;
|
||||
parse_magic(reader).await?;
|
||||
let _version = parse_version(reader).await?;
|
||||
pin_mut!(reader);
|
||||
loop {
|
||||
let op = reader.read_u8().await?;
|
||||
match op {
|
||||
META => {
|
||||
let _ = parse_aux(&mut *reader).await?;
|
||||
let _ = parse_aux(&mut *reader).await?;
|
||||
// just ignore the aux info for now
|
||||
}
|
||||
DB_SELECT => {
|
||||
let (_, _) = parse_len(&mut *reader).await?;
|
||||
// just ignore the db index for now
|
||||
}
|
||||
TABLE_SIZE_INFO => {
|
||||
let size_no_expire = parse_len(&mut *reader).await?.0;
|
||||
let size_expire = parse_len(&mut *reader).await?.0;
|
||||
for _ in 0..size_no_expire {
|
||||
let (k, v) = parse_no_expire_entry(&mut *reader).await?;
|
||||
storage.set(k, v);
|
||||
}
|
||||
for _ in 0..size_expire {
|
||||
let (k, v, expire_timestamp) = parse_expire_entry(&mut *reader).await?;
|
||||
storage.setx(k, v, expire_timestamp);
|
||||
}
|
||||
}
|
||||
EOF => {
|
||||
// not verify crc for now
|
||||
let _crc = reader.read_u64().await?;
|
||||
break;
|
||||
}
|
||||
_ => return Err(DBError(format!("unexpected op: {}", op))),
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn parse_rdb_file(f: &mut fs::File, server: &mut Server) -> Result<(), DBError> {
|
||||
let mut reader = BufReader::new(f);
|
||||
parse_rdb(&mut reader, server).await
|
||||
}
|
||||
|
||||
async fn parse_no_expire_entry<R: AsyncRead + Unpin>(
|
||||
input: &mut R,
|
||||
) -> Result<(String, String), DBError> {
|
||||
let b = input.read_u8().await?;
|
||||
if b != 0 {
|
||||
return Err(DBError(format!("unexpected key type: {}", b)));
|
||||
}
|
||||
let k = parse_aux(input).await?;
|
||||
let v = parse_aux(input).await?;
|
||||
Ok((k, v))
|
||||
}
|
||||
|
||||
async fn parse_expire_entry<R: AsyncRead + Unpin>(
|
||||
input: &mut R,
|
||||
) -> Result<(String, String, u128), DBError> {
|
||||
let b = input.read_u8().await?;
|
||||
match b {
|
||||
0xFC => {
|
||||
// expire in milliseconds
|
||||
let expire_stamp = input.read_u64_le().await?;
|
||||
let (k, v) = parse_no_expire_entry(input).await?;
|
||||
Ok((k, v, expire_stamp as u128))
|
||||
}
|
||||
0xFD => {
|
||||
// expire in seconds
|
||||
let expire_timestamp = input.read_u32_le().await?;
|
||||
let (k, v) = parse_no_expire_entry(input).await?;
|
||||
Ok((k, v, (expire_timestamp * 1000) as u128))
|
||||
}
|
||||
_ => return Err(DBError(format!("unexpected expire type: {}", b))),
|
||||
}
|
||||
}
|
||||
|
||||
async fn parse_magic<R: AsyncRead + Unpin>(input: &mut R) -> Result<(), DBError> {
|
||||
let mut magic = [0; 5];
|
||||
let size_read = input.read(&mut magic).await?;
|
||||
if size_read != 5 {
|
||||
Err(DBError("expected 5 chars for magic number".to_string()))
|
||||
} else if magic.as_slice() == MAGIC {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(DBError(format!(
|
||||
"expected magic string {:?}, but got: {:?}",
|
||||
MAGIC, magic
|
||||
)))
|
||||
}
|
||||
}
|
||||
|
||||
async fn parse_version<R: AsyncRead + Unpin>(input: &mut R) -> Result<[u8; 4], DBError> {
|
||||
let mut version = [0; 4];
|
||||
let size_read = input.read(&mut version).await?;
|
||||
if size_read != 4 {
|
||||
Err(DBError("expected 4 chars for redis version".to_string()))
|
||||
} else {
|
||||
Ok(version)
|
||||
}
|
||||
}
|
||||
|
||||
async fn parse_aux<R: AsyncRead + Unpin>(input: &mut R) -> Result<String, DBError> {
|
||||
let (len, encoding) = parse_len(input).await?;
|
||||
let s = parse_string(input, len, encoding).await?;
|
||||
Ok(s)
|
||||
}
|
||||
|
||||
async fn parse_len<R: AsyncRead + Unpin>(input: &mut R) -> Result<(u32, StringEncoding), DBError> {
|
||||
let first = input.read_u8().await?;
|
||||
match first & 0xC0 {
|
||||
0x00 => {
|
||||
// The size is the remaining 6 bits of the byte.
|
||||
Ok((first as u32, StringEncoding::Raw))
|
||||
}
|
||||
0x04 => {
|
||||
// The size is the next 14 bits of the byte.
|
||||
let second = input.read_u8().await?;
|
||||
Ok((
|
||||
(((first & 0x3F) as u32) << 8 | second as u32) as u32,
|
||||
StringEncoding::Raw,
|
||||
))
|
||||
}
|
||||
0x80 => {
|
||||
//Ignore the remaining 6 bits of the first byte. The size is the next 4 bytes, in big-endian
|
||||
let second = input.read_u32().await?;
|
||||
Ok((second, StringEncoding::Raw))
|
||||
}
|
||||
0xC0 => {
|
||||
// The remaining 6 bits specify a type of string encoding.
|
||||
match first {
|
||||
0xC0 => Ok((1, StringEncoding::I8)),
|
||||
0xC1 => Ok((2, StringEncoding::I16)),
|
||||
0xC2 => Ok((4, StringEncoding::I32)),
|
||||
0xC3 => Ok((0, StringEncoding::LZF)), // not supported yet
|
||||
_ => Err(DBError(format!("unexpected string encoding: {}", first))),
|
||||
}
|
||||
}
|
||||
_ => Err(DBError(format!("unexpected len prefix: {}", first))),
|
||||
}
|
||||
}
|
||||
|
||||
async fn parse_string<R: AsyncRead + Unpin>(
|
||||
input: &mut R,
|
||||
len: u32,
|
||||
encoding: StringEncoding,
|
||||
) -> Result<String, DBError> {
|
||||
match encoding {
|
||||
StringEncoding::Raw => {
|
||||
let mut s = vec![0; len as usize];
|
||||
input.read_exact(&mut s).await?;
|
||||
Ok(String::from_utf8(s).unwrap())
|
||||
}
|
||||
StringEncoding::I8 => {
|
||||
let b = input.read_u8().await?;
|
||||
Ok(b.to_string())
|
||||
}
|
||||
StringEncoding::I16 => {
|
||||
let b = input.read_u16_le().await?;
|
||||
Ok(b.to_string())
|
||||
}
|
||||
StringEncoding::I32 => {
|
||||
let b = input.read_u32_le().await?;
|
||||
Ok(b.to_string())
|
||||
}
|
||||
StringEncoding::LZF => {
|
||||
// not supported yet
|
||||
Err(DBError("LZF encoding not supported yet".to_string()))
|
||||
}
|
||||
}
|
||||
}
|
155
src/replication_client.rs
Normal file
155
src/replication_client.rs
Normal file
@@ -0,0 +1,155 @@
|
||||
use std::{num::ParseIntError, sync::Arc};
|
||||
|
||||
use tokio::{
|
||||
io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader},
|
||||
net::TcpStream,
|
||||
sync::Mutex,
|
||||
};
|
||||
|
||||
use crate::{error::DBError, protocol::Protocol, rdb, server::Server};
|
||||
|
||||
const EMPTY_RDB_FILE_HEX_STRING: &str = "524544495330303131fa0972656469732d76657205372e322e30fa0a72656469732d62697473c040fa056374696d65c26d08bc65fa08757365642d6d656dc2b0c41000fa08616f662d62617365c000fff06e3bfec0ff5aa2";
|
||||
|
||||
pub struct FollowerReplicationClient {
|
||||
pub stream: TcpStream,
|
||||
}
|
||||
|
||||
impl FollowerReplicationClient {
|
||||
pub async fn new(addr: String) -> FollowerReplicationClient {
|
||||
FollowerReplicationClient {
|
||||
stream: TcpStream::connect(addr).await.unwrap(),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn ping_master(self: &mut Self) -> Result<(), DBError> {
|
||||
let protocol = Protocol::Array(vec![Protocol::BulkString("PING".to_string())]);
|
||||
self.stream.write_all(protocol.encode().as_bytes()).await?;
|
||||
|
||||
self.check_resp("PONG").await
|
||||
}
|
||||
|
||||
pub async fn report_port(self: &mut Self, port: u16) -> Result<(), DBError> {
|
||||
let protocol = Protocol::from_vec(vec![
|
||||
"REPLCONF",
|
||||
"listening-port",
|
||||
port.to_string().as_str(),
|
||||
]);
|
||||
self.stream.write_all(protocol.encode().as_bytes()).await?;
|
||||
|
||||
self.check_resp("OK").await
|
||||
}
|
||||
|
||||
pub async fn report_sync_protocol(self: &mut Self) -> Result<(), DBError> {
|
||||
let p = Protocol::from_vec(vec!["REPLCONF", "capa", "psync2"]);
|
||||
self.stream.write_all(p.encode().as_bytes()).await?;
|
||||
self.check_resp("OK").await
|
||||
}
|
||||
|
||||
pub async fn start_psync(self: &mut Self, server: &mut Server) -> Result<(), DBError> {
|
||||
let p = Protocol::from_vec(vec!["PSYNC", "?", "-1"]);
|
||||
self.stream.write_all(p.encode().as_bytes()).await?;
|
||||
self.recv_rdb_file(server).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn recv_rdb_file(self: &mut Self, server: &mut Server) -> Result<(), DBError> {
|
||||
let mut reader = BufReader::new(&mut self.stream);
|
||||
|
||||
let mut buf = Vec::new();
|
||||
let _ = reader.read_until(b'\n', &mut buf).await?;
|
||||
buf.pop();
|
||||
buf.pop();
|
||||
|
||||
let replication_info = String::from_utf8(buf)?;
|
||||
let replication_info = replication_info
|
||||
.split_whitespace()
|
||||
.map(|x| x.to_string())
|
||||
.collect::<Vec<String>>();
|
||||
if replication_info.len() != 3 {
|
||||
return Err(DBError(format!(
|
||||
"expect 3 args but found {:?}",
|
||||
replication_info
|
||||
)));
|
||||
}
|
||||
println!(
|
||||
"Get replication info: {:?} {:?} {:?}",
|
||||
replication_info[0], replication_info[1], replication_info[2]
|
||||
);
|
||||
|
||||
let c = reader.read_u8().await?;
|
||||
if c != b'$' {
|
||||
return Err(DBError(format!("expect $ but found {}", c)));
|
||||
}
|
||||
let mut buf = Vec::new();
|
||||
reader.read_until(b'\n', &mut buf).await?;
|
||||
buf.pop();
|
||||
buf.pop();
|
||||
let rdb_file_len = String::from_utf8(buf)?.parse::<usize>()?;
|
||||
println!("rdb file len: {}", rdb_file_len);
|
||||
|
||||
// receive rdb file content
|
||||
rdb::parse_rdb(&mut reader, server).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn check_resp(&mut self, expected: &str) -> Result<(), DBError> {
|
||||
let mut buf = [0; 1024];
|
||||
let n_bytes = self.stream.read(&mut buf).await?;
|
||||
println!(
|
||||
"check resp: recv {:?}",
|
||||
String::from_utf8(buf[..n_bytes].to_vec()).unwrap()
|
||||
);
|
||||
let expect = Protocol::SimpleString(expected.to_string()).encode();
|
||||
if expect.as_bytes() != &buf[..n_bytes] {
|
||||
return Err(DBError(format!(
|
||||
"expect response {:?} but found {:?}",
|
||||
expect,
|
||||
&buf[..n_bytes]
|
||||
)));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct MasterReplicationClient {
|
||||
pub streams: Arc<Mutex<Vec<TcpStream>>>,
|
||||
}
|
||||
|
||||
impl MasterReplicationClient {
|
||||
pub fn new() -> MasterReplicationClient {
|
||||
MasterReplicationClient {
|
||||
streams: Arc::new(Mutex::new(Vec::new())),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn send_rdb_file(&mut self, stream: &mut TcpStream) -> Result<(), DBError> {
|
||||
let empty_rdb_file_bytes = (0..EMPTY_RDB_FILE_HEX_STRING.len())
|
||||
.step_by(2)
|
||||
.map(|i| u8::from_str_radix(&EMPTY_RDB_FILE_HEX_STRING[i..i + 2], 16))
|
||||
.collect::<Result<Vec<u8>, ParseIntError>>()?;
|
||||
|
||||
println!("going to send rdb file");
|
||||
_ = stream.write("$".as_bytes()).await?;
|
||||
_ = stream
|
||||
.write(empty_rdb_file_bytes.len().to_string().as_bytes())
|
||||
.await?;
|
||||
_ = stream.write_all("\r\n".as_bytes()).await?;
|
||||
_ = stream.write_all(&empty_rdb_file_bytes).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn add_stream(&mut self, stream: TcpStream) -> Result<(), DBError> {
|
||||
let mut streams = self.streams.lock().await;
|
||||
streams.push(stream);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn send_command(&mut self, protocol: Protocol) -> Result<(), DBError> {
|
||||
let mut streams = self.streams.lock().await;
|
||||
for stream in streams.iter_mut() {
|
||||
stream.write_all(protocol.encode().as_bytes()).await?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
156
src/server.rs
Normal file
156
src/server.rs
Normal file
@@ -0,0 +1,156 @@
|
||||
use core::str;
|
||||
use std::collections::BTreeMap;
|
||||
use std::collections::HashMap;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::atomic::AtomicU64;
|
||||
use std::sync::Arc;
|
||||
use tokio::fs::OpenOptions;
|
||||
use tokio::io::AsyncReadExt;
|
||||
use tokio::io::AsyncWriteExt;
|
||||
use tokio::sync::mpsc::Sender;
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
use crate::cmd::Cmd;
|
||||
use crate::error::DBError;
|
||||
use crate::options;
|
||||
use crate::protocol::Protocol;
|
||||
use crate::rdb;
|
||||
use crate::replication_client::FollowerReplicationClient;
|
||||
use crate::replication_client::MasterReplicationClient;
|
||||
use crate::storage::Storage;
|
||||
|
||||
type Stream = BTreeMap<String, Vec<(String, String)>>;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Server {
|
||||
pub storage: Arc<Mutex<Storage>>,
|
||||
pub streams: Arc<Mutex<HashMap<String, Stream>>>,
|
||||
pub option: options::DBOption,
|
||||
pub offset: Arc<AtomicU64>,
|
||||
pub master_repl_clients: Arc<Mutex<Option<MasterReplicationClient>>>,
|
||||
pub stream_reader_blocker: Arc<Mutex<Vec<Sender<()>>>>,
|
||||
master_addr: Option<String>,
|
||||
}
|
||||
|
||||
impl Server {
|
||||
pub async fn new(option: options::DBOption) -> Self {
|
||||
let master_addr = match option.replication.role.as_str() {
|
||||
"slave" => Some(
|
||||
option
|
||||
.replication
|
||||
.replica_of
|
||||
.clone()
|
||||
.unwrap()
|
||||
.replace(' ', ":"),
|
||||
),
|
||||
_ => None,
|
||||
};
|
||||
|
||||
let is_master = option.replication.role == "master";
|
||||
|
||||
let mut server = Server {
|
||||
storage: Arc::new(Mutex::new(Storage::new())),
|
||||
streams: Arc::new(Mutex::new(HashMap::new())),
|
||||
option,
|
||||
master_repl_clients: if is_master {
|
||||
Arc::new(Mutex::new(Some(MasterReplicationClient::new())))
|
||||
} else {
|
||||
Arc::new(Mutex::new(None))
|
||||
},
|
||||
offset: Arc::new(AtomicU64::new(0)),
|
||||
stream_reader_blocker: Arc::new(Mutex::new(Vec::new())),
|
||||
master_addr,
|
||||
};
|
||||
|
||||
server.init().await.unwrap();
|
||||
server
|
||||
}
|
||||
|
||||
pub async fn init(&mut self) -> Result<(), DBError> {
|
||||
// master initialization
|
||||
if self.is_master() {
|
||||
println!("Start as master\n");
|
||||
let db_file_path =
|
||||
PathBuf::from(self.option.dir.clone()).join(self.option.db_file_name.clone());
|
||||
println!("will open db file path: {}", db_file_path.display());
|
||||
|
||||
// create empty db file if not exits
|
||||
let mut file = OpenOptions::new()
|
||||
.read(true)
|
||||
.write(true)
|
||||
.create(true)
|
||||
.truncate(false)
|
||||
.open(db_file_path.clone())
|
||||
.await?;
|
||||
|
||||
if file.metadata().await?.len() != 0 {
|
||||
rdb::parse_rdb_file(&mut file, self).await?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn get_follower_repl_client(&mut self) -> Option<FollowerReplicationClient> {
|
||||
if self.is_slave() {
|
||||
Some(FollowerReplicationClient::new(self.master_addr.clone().unwrap()).await)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn handle(
|
||||
&mut self,
|
||||
mut stream: tokio::net::TcpStream,
|
||||
is_rep_conn: bool,
|
||||
) -> Result<(), DBError> {
|
||||
let mut buf = [0; 512];
|
||||
let mut queued_cmd: Option<Vec<(Cmd, Protocol)>> = None;
|
||||
loop {
|
||||
if let Ok(len) = stream.read(&mut buf).await {
|
||||
if len == 0 {
|
||||
println!("[handle] connection closed");
|
||||
return Ok(());
|
||||
}
|
||||
let s = str::from_utf8(&buf[..len])?;
|
||||
let (cmd, protocol) =
|
||||
Cmd::from(s).unwrap_or((Cmd::Unknow, Protocol::err("unknow cmd")));
|
||||
println!("got command: {:?}, protocol: {:?}", cmd, protocol);
|
||||
|
||||
let res = cmd
|
||||
.run(self, protocol, is_rep_conn, &mut queued_cmd)
|
||||
.await
|
||||
.unwrap_or(Protocol::err("unknow cmd"));
|
||||
print!("queued 2 cmd {:?}", queued_cmd);
|
||||
|
||||
// only send response to normal client, do not send response to replication client
|
||||
if !is_rep_conn {
|
||||
println!("going to send response {}", res.encode());
|
||||
_ = stream.write(res.encode().as_bytes()).await?;
|
||||
}
|
||||
|
||||
// send a full RDB file to slave
|
||||
if self.is_master() {
|
||||
if let Cmd::Psync = cmd {
|
||||
let mut master_rep_client = self.master_repl_clients.lock().await;
|
||||
let master_rep_client = master_rep_client.as_mut().unwrap();
|
||||
master_rep_client.send_rdb_file(&mut stream).await?;
|
||||
master_rep_client.add_stream(stream).await?;
|
||||
break;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
println!("[handle] going to break");
|
||||
break;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn is_slave(&self) -> bool {
|
||||
self.option.replication.role == "slave"
|
||||
}
|
||||
|
||||
pub fn is_master(&self) -> bool {
|
||||
!self.is_slave()
|
||||
}
|
||||
}
|
59
src/storage.rs
Normal file
59
src/storage.rs
Normal file
@@ -0,0 +1,59 @@
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
time::{SystemTime, UNIX_EPOCH},
|
||||
};
|
||||
|
||||
pub type ValueType = (String, Option<u128>);
|
||||
|
||||
pub struct Storage {
|
||||
// key -> (value, (insert/update time, expire milli seconds))
|
||||
set: HashMap<String, ValueType>,
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn now_in_millis() -> u128 {
|
||||
let start = SystemTime::now();
|
||||
let duration_since_epoch = start.duration_since(UNIX_EPOCH).unwrap();
|
||||
duration_since_epoch.as_millis()
|
||||
}
|
||||
|
||||
impl Storage {
|
||||
pub fn new() -> Self {
|
||||
Storage {
|
||||
set: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get(self: &mut Self, k: &str) -> Option<String> {
|
||||
match self.set.get(k) {
|
||||
Some((ss, expire_timestamp)) => match expire_timestamp {
|
||||
Some(expire_time_stamp) => {
|
||||
if now_in_millis() > *expire_time_stamp {
|
||||
self.set.remove(k);
|
||||
None
|
||||
} else {
|
||||
Some(ss.clone())
|
||||
}
|
||||
}
|
||||
_ => Some(ss.clone()),
|
||||
},
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn set(self: &mut Self, k: String, v: String) {
|
||||
self.set.insert(k, (v, None));
|
||||
}
|
||||
|
||||
pub fn setx(self: &mut Self, k: String, v: String, expire_ms: u128) {
|
||||
self.set.insert(k, (v, Some(expire_ms + now_in_millis())));
|
||||
}
|
||||
|
||||
pub fn del(self: &mut Self, k: String) {
|
||||
self.set.remove(&k);
|
||||
}
|
||||
|
||||
pub fn keys(self: &Self) -> Vec<String> {
|
||||
self.set.keys().map(|x| x.clone()).collect()
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user