This commit is contained in:
2025-08-16 08:25:25 +02:00
parent 0f6e595000
commit 7bcb673361
15 changed files with 522 additions and 388 deletions

View File

@@ -33,7 +33,10 @@ pub enum Cmd {
Ttl(String),
Exists(String),
Quit,
Unknow,
Client(Vec<String>),
ClientSetName(String),
ClientGetName,
Unknow(String),
}
impl Cmd {
@@ -274,7 +277,30 @@ impl Cmd {
}
Cmd::Quit
}
_ => Cmd::Unknow,
"client" => {
if cmd.len() > 1 {
match cmd[1].to_lowercase().as_str() {
"setname" => {
if cmd.len() == 3 {
Cmd::ClientSetName(cmd[2].clone())
} else {
return Err(DBError("wrong number of arguments for 'client setname' command".to_string()));
}
}
"getname" => {
if cmd.len() == 2 {
Cmd::ClientGetName
} else {
return Err(DBError("wrong number of arguments for 'client getname' command".to_string()));
}
}
_ => Cmd::Client(cmd[1..].to_vec()),
}
} else {
Cmd::Client(vec![])
}
}
_ => Cmd::Unknow(cmd[0].clone()),
},
protocol.0,
))
@@ -288,7 +314,7 @@ impl Cmd {
pub async fn run(
&self,
server: &Server,
server: &mut Server,
protocol: Protocol,
queued_cmd: &mut Option<Vec<(Cmd, Protocol)>>,
) -> Result<Protocol, DBError> {
@@ -347,14 +373,20 @@ impl Cmd {
Cmd::Ttl(key) => ttl_cmd(server, key).await,
Cmd::Exists(key) => exists_cmd(server, key).await,
Cmd::Quit => Ok(Protocol::SimpleString("OK".to_string())),
Cmd::Unknow => Ok(Protocol::err("unknown cmd")),
Cmd::Client(_) => Ok(Protocol::SimpleString("OK".to_string())),
Cmd::ClientSetName(name) => client_setname_cmd(server, name).await,
Cmd::ClientGetName => client_getname_cmd(server).await,
Cmd::Unknow(s) => {
println!("\x1b[31;1munknown command: {}\x1b[0m", s);
Ok(Protocol::err(&format!("ERR unknown command '{}'", s)))
}
}
}
}
async fn exec_cmd(
queued_cmd: &mut Option<Vec<(Cmd, Protocol)>>,
server: &Server,
server: &mut Server,
) -> Result<Protocol, DBError> {
if queued_cmd.is_some() {
let mut vec = Vec::new();
@@ -593,3 +625,15 @@ async fn exists_cmd(server: &Server, key: &str) -> Result<Protocol, DBError> {
Err(e) => Ok(Protocol::err(&e.0)),
}
}
async fn client_setname_cmd(server: &mut Server, name: &str) -> Result<Protocol, DBError> {
server.client_name = Some(name.to_string());
Ok(Protocol::SimpleString("OK".to_string()))
}
async fn client_getname_cmd(server: &Server) -> Result<Protocol, DBError> {
match &server.client_name {
Some(name) => Ok(Protocol::BulkString(name.clone())),
None => Ok(Protocol::Null),
}
}

View File

@@ -4,7 +4,6 @@ use tokio::sync::mpsc;
use redb;
use bincode;
use crate::protocol::Protocol;
// todo: more error types
#[derive(Debug)]

View File

@@ -18,6 +18,10 @@ struct Args {
/// The port of the Redis server, default is 6379 if not specified
#[arg(long)]
port: Option<u16>,
/// Enable debug mode
#[arg(long)]
debug: bool,
}
#[tokio::main]
@@ -36,11 +40,15 @@ async fn main() {
let option = redis_rs::options::DBOption {
dir: args.dir,
port,
debug: args.debug,
};
// new server
let server = server::Server::new(option).await;
// Add a small delay to ensure the port is ready
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
// accept new connections
loop {
let stream = listener.accept().await;

View File

@@ -2,4 +2,5 @@
pub struct DBOption {
pub dir: String,
pub port: u16,
pub debug: bool,
}

View File

@@ -159,18 +159,21 @@ impl Protocol {
}
fn parse_usize(protocol: &str) -> Result<usize, DBError> {
match protocol.len() {
0 => Err(DBError(format!("parse usize error: {:?}", protocol))),
_ => Ok(protocol
if protocol.is_empty() {
Err(DBError("Cannot parse usize from empty string".to_string()))
} else {
protocol
.parse::<usize>()
.map_err(|_| DBError(format!("parse usize error: {}", protocol)))?),
.map_err(|_| DBError(format!("Failed to parse usize from: {}", protocol)))
}
}
fn parse_string(protocol: &str) -> Result<String, DBError> {
match protocol.len() {
0 => Err(DBError(format!("parse usize error: {:?}", protocol))),
_ => Ok(protocol.to_string()),
if protocol.is_empty() {
// Allow empty strings, but handle appropriately
Ok("".to_string())
} else {
Ok(protocol.to_string())
}
}
}

View File

@@ -14,6 +14,7 @@ use crate::storage::Storage;
pub struct Server {
pub storage: Arc<Storage>,
pub option: options::DBOption,
pub client_name: Option<String>,
}
impl Server {
@@ -28,6 +29,7 @@ impl Server {
Server {
storage: Arc::new(storage),
option,
client_name: None,
}
}
@@ -46,20 +48,37 @@ impl Server {
}
let s = str::from_utf8(&buf[..len])?;
let (cmd, protocol) =
Cmd::from(s).unwrap_or((Cmd::Unknow, Protocol::err("unknow cmd")));
println!("got command: {:?}, protocol: {:?}", cmd, protocol);
let (cmd, protocol) = match Cmd::from(s) {
Ok((cmd, protocol)) => (cmd, protocol),
Err(e) => {
println!("\x1b[31;1mprotocol error: {:?}\x1b[0m", e);
(Cmd::Unknow("protocol_error".to_string()), Protocol::err(&format!("protocol error: {}", e.0)))
}
};
if self.option.debug {
println!("\x1b[34;1mgot command: {:?}, protocol: {:?}\x1b[0m", cmd, protocol);
} else {
println!("got command: {:?}, protocol: {:?}", cmd, protocol);
}
// Check if this is a QUIT command before processing
let is_quit = matches!(cmd, Cmd::Quit);
let res = cmd
.run(self, protocol, &mut queued_cmd)
.run(&mut self.clone(), protocol.clone(), &mut queued_cmd)
.await
.unwrap_or(Protocol::err("unknow cmd"));
print!("queued cmd {:?}", queued_cmd);
.unwrap_or(Protocol::err("unknown cmd from server"));
if self.option.debug {
println!("\x1b[34;1mqueued cmd {:?}\x1b[0m", queued_cmd);
} else {
print!("queued cmd {:?}", queued_cmd);
}
println!("going to send response {}", res.encode());
if self.option.debug {
println!("\x1b[32;1mgoing to send response {}\x1b[0m", res.encode());
} else {
println!("going to send response {}", res.encode());
}
_ = stream.write(res.encode().as_bytes()).await?;
// If this was a QUIT command, close the connection

View File

@@ -3,7 +3,7 @@ use std::{
time::{SystemTime, UNIX_EPOCH},
};
use redb::{Database, Error, ReadableTable, Table, TableDefinition, WriteTransaction, ReadTransaction};
use redb::{Database, ReadableTable, TableDefinition};
use serde::{Deserialize, Serialize};
use crate::error::DBError;
@@ -493,7 +493,6 @@ impl Storage {
// Stop if we've returned enough keys
if returned_keys >= count {
current_cursor += 1;
break;
}
}
@@ -502,7 +501,7 @@ impl Storage {
}
// If we've reached the end of iteration, return cursor 0 to indicate completion
let next_cursor = if returned_keys < count { 0 } else { current_cursor };
let next_cursor = if iter.next().is_none() { 0 } else { current_cursor };
Ok((next_cursor, keys))
}
@@ -563,7 +562,6 @@ impl Storage {
returned_fields += 1;
if returned_fields >= count {
current_cursor += 1;
break;
}
}
@@ -571,7 +569,7 @@ impl Storage {
current_cursor += 1;
}
let next_cursor = if returned_fields < count { 0 } else { current_cursor };
let next_cursor = if iter.next().is_none() { 0 } else { current_cursor };
Ok((next_cursor, fields))
}
Some(_) => Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())),