This commit is contained in:
2025-08-16 10:10:24 +02:00
parent 074be114c3
commit 246304b9fa
3 changed files with 89 additions and 95 deletions

View File

@@ -415,81 +415,74 @@ impl Cmd {
}
}
pub async fn run(
&self,
server: &mut Server,
protocol: Protocol,
queued_cmd: &mut Option<Vec<(Cmd, Protocol)>>,
) -> Result<Protocol, DBError> {
pub async fn run(self, server: &mut Server) -> Result<Protocol, DBError> {
// Handle queued commands for transactions
if queued_cmd.is_some()
if server.queued_cmd.is_some()
&& !matches!(self, Cmd::Exec)
&& !matches!(self, Cmd::Multi)
&& !matches!(self, Cmd::Discard)
{
queued_cmd
.as_mut()
.unwrap()
.push((self.clone(), protocol.clone()));
let protocol = self.clone().to_protocol();
server.queued_cmd.as_mut().unwrap().push((self, protocol));
return Ok(Protocol::SimpleString("QUEUED".to_string()));
}
match self {
Cmd::Select(db) => select_cmd(server, *db).await,
Cmd::Select(db) => select_cmd(server, db).await,
Cmd::Ping => Ok(Protocol::SimpleString("PONG".to_string())),
Cmd::Echo(s) => Ok(Protocol::SimpleString(s.clone())),
Cmd::Get(k) => get_cmd(server, k).await,
Cmd::Set(k, v) => set_cmd(server, k, v).await,
Cmd::SetPx(k, v, x) => set_px_cmd(server, k, v, x).await,
Cmd::SetEx(k, v, x) => set_ex_cmd(server, k, v, x).await,
Cmd::Del(k) => del_cmd(server, k).await,
Cmd::ConfigGet(name) => config_get_cmd(name, server),
Cmd::Echo(s) => Ok(Protocol::BulkString(s)),
Cmd::Get(k) => get_cmd(server, &k).await,
Cmd::Set(k, v) => set_cmd(server, &k, &v).await,
Cmd::SetPx(k, v, x) => set_px_cmd(server, &k, &v, &x).await,
Cmd::SetEx(k, v, x) => set_ex_cmd(server, &k, &v, &x).await,
Cmd::Del(k) => del_cmd(server, &k).await,
Cmd::ConfigGet(name) => config_get_cmd(&name, server),
Cmd::Keys => keys_cmd(server).await,
Cmd::Info(section) => info_cmd(server, section).await,
Cmd::Type(k) => type_cmd(server, k).await,
Cmd::Incr(key) => incr_cmd(server, key).await,
Cmd::Info(section) => info_cmd(server, &section).await,
Cmd::Type(k) => type_cmd(server, &k).await,
Cmd::Incr(key) => incr_cmd(server, &key).await,
Cmd::Multi => {
*queued_cmd = Some(Vec::<(Cmd, Protocol)>::new());
server.queued_cmd = Some(Vec::<(Cmd, Protocol)>::new());
Ok(Protocol::SimpleString("OK".to_string()))
}
Cmd::Exec => exec_cmd(queued_cmd, server).await,
Cmd::Exec => exec_cmd(server).await,
Cmd::Discard => {
if queued_cmd.is_some() {
*queued_cmd = None;
if server.queued_cmd.is_some() {
server.queued_cmd = None;
Ok(Protocol::SimpleString("OK".to_string()))
} else {
Ok(Protocol::err("ERR DISCARD without MULTI"))
}
}
// Hash commands
Cmd::HSet(key, pairs) => hset_cmd(server, key, pairs).await,
Cmd::HGet(key, field) => hget_cmd(server, key, field).await,
Cmd::HGetAll(key) => hgetall_cmd(server, key).await,
Cmd::HDel(key, fields) => hdel_cmd(server, key, fields).await,
Cmd::HExists(key, field) => hexists_cmd(server, key, field).await,
Cmd::HKeys(key) => hkeys_cmd(server, key).await,
Cmd::HVals(key) => hvals_cmd(server, key).await,
Cmd::HLen(key) => hlen_cmd(server, key).await,
Cmd::HMGet(key, fields) => hmget_cmd(server, key, fields).await,
Cmd::HSetNx(key, field, value) => hsetnx_cmd(server, key, field, value).await,
Cmd::HScan(key, cursor, pattern, count) => hscan_cmd(server, key, cursor, pattern.as_deref(), count).await,
Cmd::Scan(cursor, pattern, count) => scan_cmd(server, cursor, pattern.as_deref(), count).await,
Cmd::Ttl(key) => ttl_cmd(server, key).await,
Cmd::Exists(key) => exists_cmd(server, key).await,
Cmd::HSet(key, pairs) => hset_cmd(server, &key, &pairs).await,
Cmd::HGet(key, field) => hget_cmd(server, &key, &field).await,
Cmd::HGetAll(key) => hgetall_cmd(server, &key).await,
Cmd::HDel(key, fields) => hdel_cmd(server, &key, &fields).await,
Cmd::HExists(key, field) => hexists_cmd(server, &key, &field).await,
Cmd::HKeys(key) => hkeys_cmd(server, &key).await,
Cmd::HVals(key) => hvals_cmd(server, &key).await,
Cmd::HLen(key) => hlen_cmd(server, &key).await,
Cmd::HMGet(key, fields) => hmget_cmd(server, &key, &fields).await,
Cmd::HSetNx(key, field, value) => hsetnx_cmd(server, &key, &field, &value).await,
Cmd::HScan(key, cursor, pattern, count) => hscan_cmd(server, &key, &cursor, pattern.as_deref(), &count).await,
Cmd::Scan(cursor, pattern, count) => scan_cmd(server, &cursor, pattern.as_deref(), &count).await,
Cmd::Ttl(key) => ttl_cmd(server, &key).await,
Cmd::Exists(key) => exists_cmd(server, &key).await,
Cmd::Quit => Ok(Protocol::SimpleString("OK".to_string())),
Cmd::Client(_) => Ok(Protocol::SimpleString("OK".to_string())),
Cmd::ClientSetName(name) => client_setname_cmd(server, name).await,
Cmd::ClientSetName(name) => client_setname_cmd(server, &name).await,
Cmd::ClientGetName => client_getname_cmd(server).await,
// List commands
Cmd::LPush(key, elements) => lpush_cmd(server, key, elements).await,
Cmd::RPush(key, elements) => rpush_cmd(server, key, elements).await,
Cmd::LPop(key, count) => lpop_cmd(server, key, count).await,
Cmd::RPop(key, count) => rpop_cmd(server, key, count).await,
Cmd::LLen(key) => llen_cmd(server, key).await,
Cmd::LRem(key, count, element) => lrem_cmd(server, key, *count, element).await,
Cmd::LTrim(key, start, stop) => ltrim_cmd(server, key, *start, *stop).await,
Cmd::LIndex(key, index) => lindex_cmd(server, key, *index).await,
Cmd::LRange(key, start, stop) => lrange_cmd(server, key, *start, *stop).await,
Cmd::LPush(key, elements) => lpush_cmd(server, &key, &elements).await,
Cmd::RPush(key, elements) => rpush_cmd(server, &key, &elements).await,
Cmd::LPop(key, count) => lpop_cmd(server, &key, &count).await,
Cmd::RPop(key, count) => rpop_cmd(server, &key, &count).await,
Cmd::LLen(key) => llen_cmd(server, &key).await,
Cmd::LRem(key, count, element) => lrem_cmd(server, &key, count, &element).await,
Cmd::LTrim(key, start, stop) => ltrim_cmd(server, &key, start, stop).await,
Cmd::LIndex(key, index) => lindex_cmd(server, &key, index).await,
Cmd::LRange(key, start, stop) => lrange_cmd(server, &key, start, stop).await,
Cmd::FlushDb => flushdb_cmd(server).await,
Cmd::Unknow(s) => {
println!("\x1b[31;1munknown command: {}\x1b[0m", s);
@@ -497,6 +490,17 @@ impl Cmd {
}
}
}
pub fn to_protocol(self) -> Protocol {
match self {
Cmd::Select(db) => Protocol::Array(vec![Protocol::BulkString("select".to_string()), Protocol::BulkString(db.to_string())]),
Cmd::Ping => Protocol::Array(vec![Protocol::BulkString("ping".to_string())]),
Cmd::Echo(s) => Protocol::Array(vec![Protocol::BulkString("echo".to_string()), Protocol::BulkString(s)]),
Cmd::Get(k) => Protocol::Array(vec![Protocol::BulkString("get".to_string()), Protocol::BulkString(k)]),
Cmd::Set(k, v) => Protocol::Array(vec![Protocol::BulkString("set".to_string()), Protocol::BulkString(k), Protocol::BulkString(v)]),
_ => Protocol::SimpleString("...".to_string())
}
}
}
async fn flushdb_cmd(server: &mut Server) -> Result<Protocol, DBError> {
@@ -605,21 +609,19 @@ async fn rpush_cmd(server: &Server, key: &str, elements: &[String]) -> Result<Pr
}
}
async fn exec_cmd(
queued_cmd: &mut Option<Vec<(Cmd, Protocol)>>,
server: &mut Server,
) -> Result<Protocol, DBError> {
if queued_cmd.is_some() {
let mut vec = Vec::new();
for (cmd, protocol) in queued_cmd.as_ref().unwrap() {
let res = Box::pin(cmd.run(server, protocol.clone(), &mut None)).await?;
vec.push(res);
}
*queued_cmd = None;
Ok(Protocol::Array(vec))
async fn exec_cmd(server: &mut Server) -> Result<Protocol, DBError> {
let cmds = if let Some(cmds) = server.queued_cmd.take() {
cmds
} else {
Ok(Protocol::err("ERR EXEC without MULTI"))
return Ok(Protocol::err("ERR EXEC without MULTI"));
};
let mut vec = Vec::new();
for (cmd, _) in cmds {
let res = cmd.run(server).await?;
vec.push(res);
}
Ok(Protocol::Array(vec))
}
async fn incr_cmd(server: &Server, key: &String) -> Result<Protocol, DBError> {
@@ -834,31 +836,25 @@ async fn hsetnx_cmd(server: &Server, key: &str, field: &str, value: &str) -> Res
}
async fn scan_cmd(server: &Server, cursor: &u64, pattern: Option<&str>, count: &Option<u64>) -> Result<Protocol, DBError> {
match server.current_storage()?.scan(*cursor, pattern, *count) {
Ok((next_cursor, keys)) => {
let mut result = Vec::new();
result.push(Protocol::BulkString(next_cursor.to_string()));
result.push(Protocol::Array(
keys.into_iter().map(Protocol::BulkString).collect(),
));
Ok(Protocol::Array(result))
}
Err(e) => Ok(Protocol::err(&e.0)),
}
server.current_storage()?.scan(*cursor, pattern, *count).map(|(next_cursor, keys)| {
let mut result = Vec::new();
result.push(Protocol::BulkString(next_cursor.to_string()));
result.push(Protocol::Array(
keys.into_iter().map(Protocol::BulkString).collect(),
));
Protocol::Array(result)
})
}
async fn hscan_cmd(server: &Server, key: &str, cursor: &u64, pattern: Option<&str>, count: &Option<u64>) -> Result<Protocol, DBError> {
match server.current_storage()?.hscan(key, *cursor, pattern, *count) {
Ok((next_cursor, fields)) => {
let mut result = Vec::new();
result.push(Protocol::BulkString(next_cursor.to_string()));
result.push(Protocol::Array(
fields.into_iter().map(Protocol::BulkString).collect(),
));
Ok(Protocol::Array(result))
}
Err(e) => Ok(Protocol::err(&e.0)),
}
server.current_storage()?.hscan(key, *cursor, pattern, *count).map(|(next_cursor, fields)| {
let mut result = Vec::new();
result.push(Protocol::BulkString(next_cursor.to_string()));
result.push(Protocol::Array(
fields.into_iter().map(Protocol::BulkString).collect(),
));
Protocol::Array(result)
})
}
async fn ttl_cmd(server: &Server, key: &str) -> Result<Protocol, DBError> {

View File

@@ -16,6 +16,7 @@ pub struct Server {
pub option: options::DBOption,
pub client_name: Option<String>,
pub selected_db: u64, // Changed from usize to u64
pub queued_cmd: Option<Vec<(Cmd, Protocol)>>,
}
impl Server {
@@ -25,6 +26,7 @@ impl Server {
option,
client_name: None,
selected_db: 0,
queued_cmd: None,
}
}
@@ -61,7 +63,6 @@ impl Server {
mut stream: tokio::net::TcpStream,
) -> Result<(), DBError> {
let mut buf = [0; 512];
let mut queued_cmd: Option<Vec<(Cmd, Protocol)>> = None;
loop {
let len = match stream.read(&mut buf).await {
@@ -96,16 +97,13 @@ impl Server {
// Check if this is a QUIT command before processing
let is_quit = matches!(cmd, Cmd::Quit);
let res = cmd
.run(&mut self.clone(), protocol.clone(), &mut queued_cmd)
.await
.unwrap_or(Protocol::err("unknown cmd from server"));
let res = cmd.run(self).await.unwrap_or(Protocol::err("unknown cmd from server"));
if self.option.debug {
println!("\x1b[34;1mqueued cmd {:?}\x1b[0m", queued_cmd);
println!("\x1b[34;1mqueued cmd {:?}\x1b[0m", self.queued_cmd);
println!("\x1b[32;1mgoing to send response {}\x1b[0m", res.encode());
} else {
print!("queued cmd {:?}", queued_cmd);
print!("queued cmd {:?}", self.queued_cmd);
println!("going to send response {}", res.encode());
}

View File

@@ -736,8 +736,8 @@ impl Storage {
current_cursor += 1;
}
// If we've reached the end of iteration, return cursor 0 to indicate completion
let next_cursor = if iter.next().is_none() { 0 } else { current_cursor };
// If we've reached the end of the iteration, return cursor 0, otherwise return the next cursor position
let next_cursor = if returned_keys < count { 0 } else { current_cursor };
Ok((next_cursor, keys))
}