This commit is contained in:
2025-08-16 10:10:24 +02:00
parent 074be114c3
commit 246304b9fa
3 changed files with 89 additions and 95 deletions

View File

@@ -415,81 +415,74 @@ impl Cmd {
} }
} }
pub async fn run( pub async fn run(self, server: &mut Server) -> Result<Protocol, DBError> {
&self,
server: &mut Server,
protocol: Protocol,
queued_cmd: &mut Option<Vec<(Cmd, Protocol)>>,
) -> Result<Protocol, DBError> {
// Handle queued commands for transactions // Handle queued commands for transactions
if queued_cmd.is_some() if server.queued_cmd.is_some()
&& !matches!(self, Cmd::Exec) && !matches!(self, Cmd::Exec)
&& !matches!(self, Cmd::Multi) && !matches!(self, Cmd::Multi)
&& !matches!(self, Cmd::Discard) && !matches!(self, Cmd::Discard)
{ {
queued_cmd let protocol = self.clone().to_protocol();
.as_mut() server.queued_cmd.as_mut().unwrap().push((self, protocol));
.unwrap()
.push((self.clone(), protocol.clone()));
return Ok(Protocol::SimpleString("QUEUED".to_string())); return Ok(Protocol::SimpleString("QUEUED".to_string()));
} }
match self { match self {
Cmd::Select(db) => select_cmd(server, *db).await, Cmd::Select(db) => select_cmd(server, db).await,
Cmd::Ping => Ok(Protocol::SimpleString("PONG".to_string())), Cmd::Ping => Ok(Protocol::SimpleString("PONG".to_string())),
Cmd::Echo(s) => Ok(Protocol::SimpleString(s.clone())), Cmd::Echo(s) => Ok(Protocol::BulkString(s)),
Cmd::Get(k) => get_cmd(server, k).await, Cmd::Get(k) => get_cmd(server, &k).await,
Cmd::Set(k, v) => set_cmd(server, k, v).await, Cmd::Set(k, v) => set_cmd(server, &k, &v).await,
Cmd::SetPx(k, v, x) => set_px_cmd(server, k, v, x).await, Cmd::SetPx(k, v, x) => set_px_cmd(server, &k, &v, &x).await,
Cmd::SetEx(k, v, x) => set_ex_cmd(server, k, v, x).await, Cmd::SetEx(k, v, x) => set_ex_cmd(server, &k, &v, &x).await,
Cmd::Del(k) => del_cmd(server, k).await, Cmd::Del(k) => del_cmd(server, &k).await,
Cmd::ConfigGet(name) => config_get_cmd(name, server), Cmd::ConfigGet(name) => config_get_cmd(&name, server),
Cmd::Keys => keys_cmd(server).await, Cmd::Keys => keys_cmd(server).await,
Cmd::Info(section) => info_cmd(server, section).await, Cmd::Info(section) => info_cmd(server, &section).await,
Cmd::Type(k) => type_cmd(server, k).await, Cmd::Type(k) => type_cmd(server, &k).await,
Cmd::Incr(key) => incr_cmd(server, key).await, Cmd::Incr(key) => incr_cmd(server, &key).await,
Cmd::Multi => { Cmd::Multi => {
*queued_cmd = Some(Vec::<(Cmd, Protocol)>::new()); server.queued_cmd = Some(Vec::<(Cmd, Protocol)>::new());
Ok(Protocol::SimpleString("OK".to_string())) Ok(Protocol::SimpleString("OK".to_string()))
} }
Cmd::Exec => exec_cmd(queued_cmd, server).await, Cmd::Exec => exec_cmd(server).await,
Cmd::Discard => { Cmd::Discard => {
if queued_cmd.is_some() { if server.queued_cmd.is_some() {
*queued_cmd = None; server.queued_cmd = None;
Ok(Protocol::SimpleString("OK".to_string())) Ok(Protocol::SimpleString("OK".to_string()))
} else { } else {
Ok(Protocol::err("ERR DISCARD without MULTI")) Ok(Protocol::err("ERR DISCARD without MULTI"))
} }
} }
// Hash commands // Hash commands
Cmd::HSet(key, pairs) => hset_cmd(server, key, pairs).await, Cmd::HSet(key, pairs) => hset_cmd(server, &key, &pairs).await,
Cmd::HGet(key, field) => hget_cmd(server, key, field).await, Cmd::HGet(key, field) => hget_cmd(server, &key, &field).await,
Cmd::HGetAll(key) => hgetall_cmd(server, key).await, Cmd::HGetAll(key) => hgetall_cmd(server, &key).await,
Cmd::HDel(key, fields) => hdel_cmd(server, key, fields).await, Cmd::HDel(key, fields) => hdel_cmd(server, &key, &fields).await,
Cmd::HExists(key, field) => hexists_cmd(server, key, field).await, Cmd::HExists(key, field) => hexists_cmd(server, &key, &field).await,
Cmd::HKeys(key) => hkeys_cmd(server, key).await, Cmd::HKeys(key) => hkeys_cmd(server, &key).await,
Cmd::HVals(key) => hvals_cmd(server, key).await, Cmd::HVals(key) => hvals_cmd(server, &key).await,
Cmd::HLen(key) => hlen_cmd(server, key).await, Cmd::HLen(key) => hlen_cmd(server, &key).await,
Cmd::HMGet(key, fields) => hmget_cmd(server, key, fields).await, Cmd::HMGet(key, fields) => hmget_cmd(server, &key, &fields).await,
Cmd::HSetNx(key, field, value) => hsetnx_cmd(server, key, field, value).await, Cmd::HSetNx(key, field, value) => hsetnx_cmd(server, &key, &field, &value).await,
Cmd::HScan(key, cursor, pattern, count) => hscan_cmd(server, key, cursor, pattern.as_deref(), count).await, Cmd::HScan(key, cursor, pattern, count) => hscan_cmd(server, &key, &cursor, pattern.as_deref(), &count).await,
Cmd::Scan(cursor, pattern, count) => scan_cmd(server, cursor, pattern.as_deref(), count).await, Cmd::Scan(cursor, pattern, count) => scan_cmd(server, &cursor, pattern.as_deref(), &count).await,
Cmd::Ttl(key) => ttl_cmd(server, key).await, Cmd::Ttl(key) => ttl_cmd(server, &key).await,
Cmd::Exists(key) => exists_cmd(server, key).await, Cmd::Exists(key) => exists_cmd(server, &key).await,
Cmd::Quit => Ok(Protocol::SimpleString("OK".to_string())), Cmd::Quit => Ok(Protocol::SimpleString("OK".to_string())),
Cmd::Client(_) => Ok(Protocol::SimpleString("OK".to_string())), Cmd::Client(_) => Ok(Protocol::SimpleString("OK".to_string())),
Cmd::ClientSetName(name) => client_setname_cmd(server, name).await, Cmd::ClientSetName(name) => client_setname_cmd(server, &name).await,
Cmd::ClientGetName => client_getname_cmd(server).await, Cmd::ClientGetName => client_getname_cmd(server).await,
// List commands // List commands
Cmd::LPush(key, elements) => lpush_cmd(server, key, elements).await, Cmd::LPush(key, elements) => lpush_cmd(server, &key, &elements).await,
Cmd::RPush(key, elements) => rpush_cmd(server, key, elements).await, Cmd::RPush(key, elements) => rpush_cmd(server, &key, &elements).await,
Cmd::LPop(key, count) => lpop_cmd(server, key, count).await, Cmd::LPop(key, count) => lpop_cmd(server, &key, &count).await,
Cmd::RPop(key, count) => rpop_cmd(server, key, count).await, Cmd::RPop(key, count) => rpop_cmd(server, &key, &count).await,
Cmd::LLen(key) => llen_cmd(server, key).await, Cmd::LLen(key) => llen_cmd(server, &key).await,
Cmd::LRem(key, count, element) => lrem_cmd(server, key, *count, element).await, Cmd::LRem(key, count, element) => lrem_cmd(server, &key, count, &element).await,
Cmd::LTrim(key, start, stop) => ltrim_cmd(server, key, *start, *stop).await, Cmd::LTrim(key, start, stop) => ltrim_cmd(server, &key, start, stop).await,
Cmd::LIndex(key, index) => lindex_cmd(server, key, *index).await, Cmd::LIndex(key, index) => lindex_cmd(server, &key, index).await,
Cmd::LRange(key, start, stop) => lrange_cmd(server, key, *start, *stop).await, Cmd::LRange(key, start, stop) => lrange_cmd(server, &key, start, stop).await,
Cmd::FlushDb => flushdb_cmd(server).await, Cmd::FlushDb => flushdb_cmd(server).await,
Cmd::Unknow(s) => { Cmd::Unknow(s) => {
println!("\x1b[31;1munknown command: {}\x1b[0m", s); println!("\x1b[31;1munknown command: {}\x1b[0m", s);
@@ -497,6 +490,17 @@ impl Cmd {
} }
} }
} }
pub fn to_protocol(self) -> Protocol {
match self {
Cmd::Select(db) => Protocol::Array(vec![Protocol::BulkString("select".to_string()), Protocol::BulkString(db.to_string())]),
Cmd::Ping => Protocol::Array(vec![Protocol::BulkString("ping".to_string())]),
Cmd::Echo(s) => Protocol::Array(vec![Protocol::BulkString("echo".to_string()), Protocol::BulkString(s)]),
Cmd::Get(k) => Protocol::Array(vec![Protocol::BulkString("get".to_string()), Protocol::BulkString(k)]),
Cmd::Set(k, v) => Protocol::Array(vec![Protocol::BulkString("set".to_string()), Protocol::BulkString(k), Protocol::BulkString(v)]),
_ => Protocol::SimpleString("...".to_string())
}
}
} }
async fn flushdb_cmd(server: &mut Server) -> Result<Protocol, DBError> { async fn flushdb_cmd(server: &mut Server) -> Result<Protocol, DBError> {
@@ -605,21 +609,19 @@ async fn rpush_cmd(server: &Server, key: &str, elements: &[String]) -> Result<Pr
} }
} }
async fn exec_cmd( async fn exec_cmd(server: &mut Server) -> Result<Protocol, DBError> {
queued_cmd: &mut Option<Vec<(Cmd, Protocol)>>, let cmds = if let Some(cmds) = server.queued_cmd.take() {
server: &mut Server, cmds
) -> Result<Protocol, DBError> {
if queued_cmd.is_some() {
let mut vec = Vec::new();
for (cmd, protocol) in queued_cmd.as_ref().unwrap() {
let res = Box::pin(cmd.run(server, protocol.clone(), &mut None)).await?;
vec.push(res);
}
*queued_cmd = None;
Ok(Protocol::Array(vec))
} else { } else {
Ok(Protocol::err("ERR EXEC without MULTI")) return Ok(Protocol::err("ERR EXEC without MULTI"));
};
let mut vec = Vec::new();
for (cmd, _) in cmds {
let res = cmd.run(server).await?;
vec.push(res);
} }
Ok(Protocol::Array(vec))
} }
async fn incr_cmd(server: &Server, key: &String) -> Result<Protocol, DBError> { async fn incr_cmd(server: &Server, key: &String) -> Result<Protocol, DBError> {
@@ -834,31 +836,25 @@ async fn hsetnx_cmd(server: &Server, key: &str, field: &str, value: &str) -> Res
} }
async fn scan_cmd(server: &Server, cursor: &u64, pattern: Option<&str>, count: &Option<u64>) -> Result<Protocol, DBError> { async fn scan_cmd(server: &Server, cursor: &u64, pattern: Option<&str>, count: &Option<u64>) -> Result<Protocol, DBError> {
match server.current_storage()?.scan(*cursor, pattern, *count) { server.current_storage()?.scan(*cursor, pattern, *count).map(|(next_cursor, keys)| {
Ok((next_cursor, keys)) => { let mut result = Vec::new();
let mut result = Vec::new(); result.push(Protocol::BulkString(next_cursor.to_string()));
result.push(Protocol::BulkString(next_cursor.to_string())); result.push(Protocol::Array(
result.push(Protocol::Array( keys.into_iter().map(Protocol::BulkString).collect(),
keys.into_iter().map(Protocol::BulkString).collect(), ));
)); Protocol::Array(result)
Ok(Protocol::Array(result)) })
}
Err(e) => Ok(Protocol::err(&e.0)),
}
} }
async fn hscan_cmd(server: &Server, key: &str, cursor: &u64, pattern: Option<&str>, count: &Option<u64>) -> Result<Protocol, DBError> { async fn hscan_cmd(server: &Server, key: &str, cursor: &u64, pattern: Option<&str>, count: &Option<u64>) -> Result<Protocol, DBError> {
match server.current_storage()?.hscan(key, *cursor, pattern, *count) { server.current_storage()?.hscan(key, *cursor, pattern, *count).map(|(next_cursor, fields)| {
Ok((next_cursor, fields)) => { let mut result = Vec::new();
let mut result = Vec::new(); result.push(Protocol::BulkString(next_cursor.to_string()));
result.push(Protocol::BulkString(next_cursor.to_string())); result.push(Protocol::Array(
result.push(Protocol::Array( fields.into_iter().map(Protocol::BulkString).collect(),
fields.into_iter().map(Protocol::BulkString).collect(), ));
)); Protocol::Array(result)
Ok(Protocol::Array(result)) })
}
Err(e) => Ok(Protocol::err(&e.0)),
}
} }
async fn ttl_cmd(server: &Server, key: &str) -> Result<Protocol, DBError> { async fn ttl_cmd(server: &Server, key: &str) -> Result<Protocol, DBError> {

View File

@@ -16,6 +16,7 @@ pub struct Server {
pub option: options::DBOption, pub option: options::DBOption,
pub client_name: Option<String>, pub client_name: Option<String>,
pub selected_db: u64, // Changed from usize to u64 pub selected_db: u64, // Changed from usize to u64
pub queued_cmd: Option<Vec<(Cmd, Protocol)>>,
} }
impl Server { impl Server {
@@ -25,6 +26,7 @@ impl Server {
option, option,
client_name: None, client_name: None,
selected_db: 0, selected_db: 0,
queued_cmd: None,
} }
} }
@@ -61,7 +63,6 @@ impl Server {
mut stream: tokio::net::TcpStream, mut stream: tokio::net::TcpStream,
) -> Result<(), DBError> { ) -> Result<(), DBError> {
let mut buf = [0; 512]; let mut buf = [0; 512];
let mut queued_cmd: Option<Vec<(Cmd, Protocol)>> = None;
loop { loop {
let len = match stream.read(&mut buf).await { let len = match stream.read(&mut buf).await {
@@ -96,16 +97,13 @@ impl Server {
// Check if this is a QUIT command before processing // Check if this is a QUIT command before processing
let is_quit = matches!(cmd, Cmd::Quit); let is_quit = matches!(cmd, Cmd::Quit);
let res = cmd let res = cmd.run(self).await.unwrap_or(Protocol::err("unknown cmd from server"));
.run(&mut self.clone(), protocol.clone(), &mut queued_cmd)
.await
.unwrap_or(Protocol::err("unknown cmd from server"));
if self.option.debug { if self.option.debug {
println!("\x1b[34;1mqueued cmd {:?}\x1b[0m", queued_cmd); println!("\x1b[34;1mqueued cmd {:?}\x1b[0m", self.queued_cmd);
println!("\x1b[32;1mgoing to send response {}\x1b[0m", res.encode()); println!("\x1b[32;1mgoing to send response {}\x1b[0m", res.encode());
} else { } else {
print!("queued cmd {:?}", queued_cmd); print!("queued cmd {:?}", self.queued_cmd);
println!("going to send response {}", res.encode()); println!("going to send response {}", res.encode());
} }

View File

@@ -736,8 +736,8 @@ impl Storage {
current_cursor += 1; current_cursor += 1;
} }
// If we've reached the end of iteration, return cursor 0 to indicate completion // If we've reached the end of the iteration, return cursor 0, otherwise return the next cursor position
let next_cursor = if iter.next().is_none() { 0 } else { current_cursor }; let next_cursor = if returned_keys < count { 0 } else { current_cursor };
Ok((next_cursor, keys)) Ok((next_cursor, keys))
} }